Merge branch 'main' into feat/agent-node-v2

This commit is contained in:
Novice 2025-12-30 10:20:42 +08:00
commit ccabdbc83b
No known key found for this signature in database
GPG Key ID: EE3F68E3105DAAAB
232 changed files with 18692 additions and 2696 deletions

8
.claude/settings.json Normal file
View File

@ -0,0 +1,8 @@
{
"enabledPlugins": {
"feature-dev@claude-plugins-official": true,
"context7@claude-plugins-official": true,
"typescript-lsp@claude-plugins-official": true,
"pyright-lsp@claude-plugins-official": true
}
}

View File

@ -1,19 +0,0 @@
{
"permissions": {
"allow": [],
"deny": []
},
"env": {
"__comment": "Environment variables for MCP servers. Override in .claude/settings.local.json with actual values.",
"GITHUB_PERSONAL_ACCESS_TOKEN": "ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
},
"enabledMcpjsonServers": [
"context7",
"sequential-thinking",
"github",
"fetch",
"playwright",
"ide"
],
"enableAllProjectMcpServers": true
}

View File

@ -0,0 +1,483 @@
---
name: component-refactoring
description: Refactor high-complexity React components in Dify frontend. Use when `pnpm analyze-component --json` shows complexity > 50 or lineCount > 300, when the user asks for code splitting, hook extraction, or complexity reduction, or when `pnpm analyze-component` warns to refactor before testing; avoid for simple/well-structured components, third-party wrappers, or when the user explicitly wants testing without refactoring.
---
# Dify Component Refactoring Skill
Refactor high-complexity React components in the Dify frontend codebase with the patterns and workflow below.
> **Complexity Threshold**: Components with complexity > 50 (measured by `pnpm analyze-component`) should be refactored before testing.
## Quick Reference
### Commands (run from `web/`)
Use paths relative to `web/` (e.g., `app/components/...`).
Use `refactor-component` for refactoring prompts and `analyze-component` for testing prompts and metrics.
```bash
cd web
# Generate refactoring prompt
pnpm refactor-component <path>
# Output refactoring analysis as JSON
pnpm refactor-component <path> --json
# Generate testing prompt (after refactoring)
pnpm analyze-component <path>
# Output testing analysis as JSON
pnpm analyze-component <path> --json
```
### Complexity Analysis
```bash
# Analyze component complexity
pnpm analyze-component <path> --json
# Key metrics to check:
# - complexity: normalized score 0-100 (target < 50)
# - maxComplexity: highest single function complexity
# - lineCount: total lines (target < 300)
```
### Complexity Score Interpretation
| Score | Level | Action |
|-------|-------|--------|
| 0-25 | 🟢 Simple | Ready for testing |
| 26-50 | 🟡 Medium | Consider minor refactoring |
| 51-75 | 🟠 Complex | **Refactor before testing** |
| 76-100 | 🔴 Very Complex | **Must refactor** |
## Core Refactoring Patterns
### Pattern 1: Extract Custom Hooks
**When**: Component has complex state management, multiple `useState`/`useEffect`, or business logic mixed with UI.
**Dify Convention**: Place hooks in a `hooks/` subdirectory or alongside the component as `use-<feature>.ts`.
```typescript
// ❌ Before: Complex state logic in component
const Configuration: FC = () => {
const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
const [datasetConfigs, setDatasetConfigs] = useState<DatasetConfigs>(...)
const [completionParams, setCompletionParams] = useState<FormValue>({})
// 50+ lines of state management logic...
return <div>...</div>
}
// ✅ After: Extract to custom hook
// hooks/use-model-config.ts
export const useModelConfig = (appId: string) => {
const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
const [completionParams, setCompletionParams] = useState<FormValue>({})
// Related state management logic here
return { modelConfig, setModelConfig, completionParams, setCompletionParams }
}
// Component becomes cleaner
const Configuration: FC = () => {
const { modelConfig, setModelConfig } = useModelConfig(appId)
return <div>...</div>
}
```
**Dify Examples**:
- `web/app/components/app/configuration/hooks/use-advanced-prompt-config.ts`
- `web/app/components/app/configuration/debug/hooks.tsx`
- `web/app/components/workflow/hooks/use-workflow.ts`
### Pattern 2: Extract Sub-Components
**When**: Single component has multiple UI sections, conditional rendering blocks, or repeated patterns.
**Dify Convention**: Place sub-components in subdirectories or as separate files in the same directory.
```typescript
// ❌ Before: Monolithic JSX with multiple sections
const AppInfo = () => {
return (
<div>
{/* 100 lines of header UI */}
{/* 100 lines of operations UI */}
{/* 100 lines of modals */}
</div>
)
}
// ✅ After: Split into focused components
// app-info/
// ├── index.tsx (orchestration only)
// ├── app-header.tsx (header UI)
// ├── app-operations.tsx (operations UI)
// └── app-modals.tsx (modal management)
const AppInfo = () => {
const { showModal, setShowModal } = useAppInfoModals()
return (
<div>
<AppHeader appDetail={appDetail} />
<AppOperations onAction={handleAction} />
<AppModals show={showModal} onClose={() => setShowModal(null)} />
</div>
)
}
```
**Dify Examples**:
- `web/app/components/app/configuration/` directory structure
- `web/app/components/workflow/nodes/` per-node organization
### Pattern 3: Simplify Conditional Logic
**When**: Deep nesting (> 3 levels), complex ternaries, or multiple `if/else` chains.
```typescript
// ❌ Before: Deeply nested conditionals
const Template = useMemo(() => {
if (appDetail?.mode === AppModeEnum.CHAT) {
switch (locale) {
case LanguagesSupported[1]:
return <TemplateChatZh />
case LanguagesSupported[7]:
return <TemplateChatJa />
default:
return <TemplateChatEn />
}
}
if (appDetail?.mode === AppModeEnum.ADVANCED_CHAT) {
// Another 15 lines...
}
// More conditions...
}, [appDetail, locale])
// ✅ After: Use lookup tables + early returns
const TEMPLATE_MAP = {
[AppModeEnum.CHAT]: {
[LanguagesSupported[1]]: TemplateChatZh,
[LanguagesSupported[7]]: TemplateChatJa,
default: TemplateChatEn,
},
[AppModeEnum.ADVANCED_CHAT]: {
[LanguagesSupported[1]]: TemplateAdvancedChatZh,
// ...
},
}
const Template = useMemo(() => {
const modeTemplates = TEMPLATE_MAP[appDetail?.mode]
if (!modeTemplates) return null
const TemplateComponent = modeTemplates[locale] || modeTemplates.default
return <TemplateComponent appDetail={appDetail} />
}, [appDetail, locale])
```
### Pattern 4: Extract API/Data Logic
**When**: Component directly handles API calls, data transformation, or complex async operations.
**Dify Convention**: Use `@tanstack/react-query` hooks from `web/service/use-*.ts` or create custom data hooks.
```typescript
// ❌ Before: API logic in component
const MCPServiceCard = () => {
const [basicAppConfig, setBasicAppConfig] = useState({})
useEffect(() => {
if (isBasicApp && appId) {
(async () => {
const res = await fetchAppDetail({ url: '/apps', id: appId })
setBasicAppConfig(res?.model_config || {})
})()
}
}, [appId, isBasicApp])
// More API-related logic...
}
// ✅ After: Extract to data hook using React Query
// use-app-config.ts
import { useQuery } from '@tanstack/react-query'
import { get } from '@/service/base'
const NAME_SPACE = 'appConfig'
export const useAppConfig = (appId: string, isBasicApp: boolean) => {
return useQuery({
enabled: isBasicApp && !!appId,
queryKey: [NAME_SPACE, 'detail', appId],
queryFn: () => get<AppDetailResponse>(`/apps/${appId}`),
select: data => data?.model_config || {},
})
}
// Component becomes cleaner
const MCPServiceCard = () => {
const { data: config, isLoading } = useAppConfig(appId, isBasicApp)
// UI only
}
```
**React Query Best Practices in Dify**:
- Define `NAME_SPACE` for query key organization
- Use `enabled` option for conditional fetching
- Use `select` for data transformation
- Export invalidation hooks: `useInvalidXxx`
**Dify Examples**:
- `web/service/use-workflow.ts`
- `web/service/use-common.ts`
- `web/service/knowledge/use-dataset.ts`
- `web/service/knowledge/use-document.ts`
### Pattern 5: Extract Modal/Dialog Management
**When**: Component manages multiple modals with complex open/close states.
**Dify Convention**: Modals should be extracted with their state management.
```typescript
// ❌ Before: Multiple modal states in component
const AppInfo = () => {
const [showEditModal, setShowEditModal] = useState(false)
const [showDuplicateModal, setShowDuplicateModal] = useState(false)
const [showConfirmDelete, setShowConfirmDelete] = useState(false)
const [showSwitchModal, setShowSwitchModal] = useState(false)
const [showImportDSLModal, setShowImportDSLModal] = useState(false)
// 5+ more modal states...
}
// ✅ After: Extract to modal management hook
type ModalType = 'edit' | 'duplicate' | 'delete' | 'switch' | 'import' | null
const useAppInfoModals = () => {
const [activeModal, setActiveModal] = useState<ModalType>(null)
const openModal = useCallback((type: ModalType) => setActiveModal(type), [])
const closeModal = useCallback(() => setActiveModal(null), [])
return {
activeModal,
openModal,
closeModal,
isOpen: (type: ModalType) => activeModal === type,
}
}
```
### Pattern 6: Extract Form Logic
**When**: Complex form validation, submission handling, or field transformation.
**Dify Convention**: Use `@tanstack/react-form` patterns from `web/app/components/base/form/`.
```typescript
// ✅ Use existing form infrastructure
import { useAppForm } from '@/app/components/base/form'
const ConfigForm = () => {
const form = useAppForm({
defaultValues: { name: '', description: '' },
onSubmit: handleSubmit,
})
return <form.Provider>...</form.Provider>
}
```
## Dify-Specific Refactoring Guidelines
### 1. Context Provider Extraction
**When**: Component provides complex context values with multiple states.
```typescript
// ❌ Before: Large context value object
const value = {
appId, isAPIKeySet, isTrailFinished, mode, modelModeType,
promptMode, isAdvancedMode, isAgent, isOpenAI, isFunctionCall,
// 50+ more properties...
}
return <ConfigContext.Provider value={value}>...</ConfigContext.Provider>
// ✅ After: Split into domain-specific contexts
<ModelConfigProvider value={modelConfigValue}>
<DatasetConfigProvider value={datasetConfigValue}>
<UIConfigProvider value={uiConfigValue}>
{children}
</UIConfigProvider>
</DatasetConfigProvider>
</ModelConfigProvider>
```
**Dify Reference**: `web/context/` directory structure
### 2. Workflow Node Components
**When**: Refactoring workflow node components (`web/app/components/workflow/nodes/`).
**Conventions**:
- Keep node logic in `use-interactions.ts`
- Extract panel UI to separate files
- Use `_base` components for common patterns
```
nodes/<node-type>/
├── index.tsx # Node registration
├── node.tsx # Node visual component
├── panel.tsx # Configuration panel
├── use-interactions.ts # Node-specific hooks
└── types.ts # Type definitions
```
### 3. Configuration Components
**When**: Refactoring app configuration components.
**Conventions**:
- Separate config sections into subdirectories
- Use existing patterns from `web/app/components/app/configuration/`
- Keep feature toggles in dedicated components
### 4. Tool/Plugin Components
**When**: Refactoring tool-related components (`web/app/components/tools/`).
**Conventions**:
- Follow existing modal patterns
- Use service hooks from `web/service/use-tools.ts`
- Keep provider-specific logic isolated
## Refactoring Workflow
### Step 1: Generate Refactoring Prompt
```bash
pnpm refactor-component <path>
```
This command will:
- Analyze component complexity and features
- Identify specific refactoring actions needed
- Generate a prompt for AI assistant (auto-copied to clipboard on macOS)
- Provide detailed requirements based on detected patterns
### Step 2: Analyze Details
```bash
pnpm analyze-component <path> --json
```
Identify:
- Total complexity score
- Max function complexity
- Line count
- Features detected (state, effects, API, etc.)
### Step 3: Plan
Create a refactoring plan based on detected features:
| Detected Feature | Refactoring Action |
|------------------|-------------------|
| `hasState: true` + `hasEffects: true` | Extract custom hook |
| `hasAPI: true` | Extract data/service hook |
| `hasEvents: true` (many) | Extract event handlers |
| `lineCount > 300` | Split into sub-components |
| `maxComplexity > 50` | Simplify conditional logic |
### Step 4: Execute Incrementally
1. **Extract one piece at a time**
2. **Run lint, type-check, and tests after each extraction**
3. **Verify functionality before next step**
```
For each extraction:
┌────────────────────────────────────────┐
│ 1. Extract code │
│ 2. Run: pnpm lint:fix │
│ 3. Run: pnpm type-check:tsgo │
│ 4. Run: pnpm test │
│ 5. Test functionality manually │
│ 6. PASS? → Next extraction │
│ FAIL? → Fix before continuing │
└────────────────────────────────────────┘
```
### Step 5: Verify
After refactoring:
```bash
# Re-run refactor command to verify improvements
pnpm refactor-component <path>
# If complexity < 25 and lines < 200, you'll see:
# ✅ COMPONENT IS WELL-STRUCTURED
# For detailed metrics:
pnpm analyze-component <path> --json
# Target metrics:
# - complexity < 50
# - lineCount < 300
# - maxComplexity < 30
```
## Common Mistakes to Avoid
### ❌ Over-Engineering
```typescript
// ❌ Too many tiny hooks
const useButtonText = () => useState('Click')
const useButtonDisabled = () => useState(false)
const useButtonLoading = () => useState(false)
// ✅ Cohesive hook with related state
const useButtonState = () => {
const [text, setText] = useState('Click')
const [disabled, setDisabled] = useState(false)
const [loading, setLoading] = useState(false)
return { text, setText, disabled, setDisabled, loading, setLoading }
}
```
### ❌ Breaking Existing Patterns
- Follow existing directory structures
- Maintain naming conventions
- Preserve export patterns for compatibility
### ❌ Premature Abstraction
- Only extract when there's clear complexity benefit
- Don't create abstractions for single-use code
- Keep refactored code in the same domain area
## References
### Dify Codebase Examples
- **Hook extraction**: `web/app/components/app/configuration/hooks/`
- **Component splitting**: `web/app/components/app/configuration/`
- **Service hooks**: `web/service/use-*.ts`
- **Workflow patterns**: `web/app/components/workflow/hooks/`
- **Form patterns**: `web/app/components/base/form/`
### Related Skills
- `frontend-testing` - For testing refactored components
- `web/testing/testing.md` - Testing specification

View File

@ -0,0 +1,493 @@
# Complexity Reduction Patterns
This document provides patterns for reducing cognitive complexity in Dify React components.
## Understanding Complexity
### SonarJS Cognitive Complexity
The `pnpm analyze-component` tool uses SonarJS cognitive complexity metrics:
- **Total Complexity**: Sum of all functions' complexity in the file
- **Max Complexity**: Highest single function complexity
### What Increases Complexity
| Pattern | Complexity Impact |
|---------|-------------------|
| `if/else` | +1 per branch |
| Nested conditions | +1 per nesting level |
| `switch/case` | +1 per case |
| `for/while/do` | +1 per loop |
| `&&`/`||` chains | +1 per operator |
| Nested callbacks | +1 per nesting level |
| `try/catch` | +1 per catch |
| Ternary expressions | +1 per nesting |
## Pattern 1: Replace Conditionals with Lookup Tables
**Before** (complexity: ~15):
```typescript
const Template = useMemo(() => {
if (appDetail?.mode === AppModeEnum.CHAT) {
switch (locale) {
case LanguagesSupported[1]:
return <TemplateChatZh appDetail={appDetail} />
case LanguagesSupported[7]:
return <TemplateChatJa appDetail={appDetail} />
default:
return <TemplateChatEn appDetail={appDetail} />
}
}
if (appDetail?.mode === AppModeEnum.ADVANCED_CHAT) {
switch (locale) {
case LanguagesSupported[1]:
return <TemplateAdvancedChatZh appDetail={appDetail} />
case LanguagesSupported[7]:
return <TemplateAdvancedChatJa appDetail={appDetail} />
default:
return <TemplateAdvancedChatEn appDetail={appDetail} />
}
}
if (appDetail?.mode === AppModeEnum.WORKFLOW) {
// Similar pattern...
}
return null
}, [appDetail, locale])
```
**After** (complexity: ~3):
```typescript
// Define lookup table outside component
const TEMPLATE_MAP: Record<AppModeEnum, Record<string, FC<TemplateProps>>> = {
[AppModeEnum.CHAT]: {
[LanguagesSupported[1]]: TemplateChatZh,
[LanguagesSupported[7]]: TemplateChatJa,
default: TemplateChatEn,
},
[AppModeEnum.ADVANCED_CHAT]: {
[LanguagesSupported[1]]: TemplateAdvancedChatZh,
[LanguagesSupported[7]]: TemplateAdvancedChatJa,
default: TemplateAdvancedChatEn,
},
[AppModeEnum.WORKFLOW]: {
[LanguagesSupported[1]]: TemplateWorkflowZh,
[LanguagesSupported[7]]: TemplateWorkflowJa,
default: TemplateWorkflowEn,
},
// ...
}
// Clean component logic
const Template = useMemo(() => {
if (!appDetail?.mode) return null
const templates = TEMPLATE_MAP[appDetail.mode]
if (!templates) return null
const TemplateComponent = templates[locale] ?? templates.default
return <TemplateComponent appDetail={appDetail} />
}, [appDetail, locale])
```
## Pattern 2: Use Early Returns
**Before** (complexity: ~10):
```typescript
const handleSubmit = () => {
if (isValid) {
if (hasChanges) {
if (isConnected) {
submitData()
} else {
showConnectionError()
}
} else {
showNoChangesMessage()
}
} else {
showValidationError()
}
}
```
**After** (complexity: ~4):
```typescript
const handleSubmit = () => {
if (!isValid) {
showValidationError()
return
}
if (!hasChanges) {
showNoChangesMessage()
return
}
if (!isConnected) {
showConnectionError()
return
}
submitData()
}
```
## Pattern 3: Extract Complex Conditions
**Before** (complexity: high):
```typescript
const canPublish = (() => {
if (mode !== AppModeEnum.COMPLETION) {
if (!isAdvancedMode)
return true
if (modelModeType === ModelModeType.completion) {
if (!hasSetBlockStatus.history || !hasSetBlockStatus.query)
return false
return true
}
return true
}
return !promptEmpty
})()
```
**After** (complexity: lower):
```typescript
// Extract to named functions
const canPublishInCompletionMode = () => !promptEmpty
const canPublishInChatMode = () => {
if (!isAdvancedMode) return true
if (modelModeType !== ModelModeType.completion) return true
return hasSetBlockStatus.history && hasSetBlockStatus.query
}
// Clean main logic
const canPublish = mode === AppModeEnum.COMPLETION
? canPublishInCompletionMode()
: canPublishInChatMode()
```
## Pattern 4: Replace Chained Ternaries
**Before** (complexity: ~5):
```typescript
const statusText = serverActivated
? t('status.running')
: serverPublished
? t('status.inactive')
: appUnpublished
? t('status.unpublished')
: t('status.notConfigured')
```
**After** (complexity: ~2):
```typescript
const getStatusText = () => {
if (serverActivated) return t('status.running')
if (serverPublished) return t('status.inactive')
if (appUnpublished) return t('status.unpublished')
return t('status.notConfigured')
}
const statusText = getStatusText()
```
Or use lookup:
```typescript
const STATUS_TEXT_MAP = {
running: 'status.running',
inactive: 'status.inactive',
unpublished: 'status.unpublished',
notConfigured: 'status.notConfigured',
} as const
const getStatusKey = (): keyof typeof STATUS_TEXT_MAP => {
if (serverActivated) return 'running'
if (serverPublished) return 'inactive'
if (appUnpublished) return 'unpublished'
return 'notConfigured'
}
const statusText = t(STATUS_TEXT_MAP[getStatusKey()])
```
## Pattern 5: Flatten Nested Loops
**Before** (complexity: high):
```typescript
const processData = (items: Item[]) => {
const results: ProcessedItem[] = []
for (const item of items) {
if (item.isValid) {
for (const child of item.children) {
if (child.isActive) {
for (const prop of child.properties) {
if (prop.value !== null) {
results.push({
itemId: item.id,
childId: child.id,
propValue: prop.value,
})
}
}
}
}
}
}
return results
}
```
**After** (complexity: lower):
```typescript
// Use functional approach
const processData = (items: Item[]) => {
return items
.filter(item => item.isValid)
.flatMap(item =>
item.children
.filter(child => child.isActive)
.flatMap(child =>
child.properties
.filter(prop => prop.value !== null)
.map(prop => ({
itemId: item.id,
childId: child.id,
propValue: prop.value,
}))
)
)
}
```
## Pattern 6: Extract Event Handler Logic
**Before** (complexity: high in component):
```typescript
const Component = () => {
const handleSelect = (data: DataSet[]) => {
if (isEqual(data.map(item => item.id), dataSets.map(item => item.id))) {
hideSelectDataSet()
return
}
formattingChangedDispatcher()
let newDatasets = data
if (data.find(item => !item.name)) {
const newSelected = produce(data, (draft) => {
data.forEach((item, index) => {
if (!item.name) {
const newItem = dataSets.find(i => i.id === item.id)
if (newItem)
draft[index] = newItem
}
})
})
setDataSets(newSelected)
newDatasets = newSelected
}
else {
setDataSets(data)
}
hideSelectDataSet()
// 40 more lines of logic...
}
return <div>...</div>
}
```
**After** (complexity: lower):
```typescript
// Extract to hook or utility
const useDatasetSelection = (dataSets: DataSet[], setDataSets: SetState<DataSet[]>) => {
const normalizeSelection = (data: DataSet[]) => {
const hasUnloadedItem = data.some(item => !item.name)
if (!hasUnloadedItem) return data
return produce(data, (draft) => {
data.forEach((item, index) => {
if (!item.name) {
const existing = dataSets.find(i => i.id === item.id)
if (existing) draft[index] = existing
}
})
})
}
const hasSelectionChanged = (newData: DataSet[]) => {
return !isEqual(
newData.map(item => item.id),
dataSets.map(item => item.id)
)
}
return { normalizeSelection, hasSelectionChanged }
}
// Component becomes cleaner
const Component = () => {
const { normalizeSelection, hasSelectionChanged } = useDatasetSelection(dataSets, setDataSets)
const handleSelect = (data: DataSet[]) => {
if (!hasSelectionChanged(data)) {
hideSelectDataSet()
return
}
formattingChangedDispatcher()
const normalized = normalizeSelection(data)
setDataSets(normalized)
hideSelectDataSet()
}
return <div>...</div>
}
```
## Pattern 7: Reduce Boolean Logic Complexity
**Before** (complexity: ~8):
```typescript
const toggleDisabled = hasInsufficientPermissions
|| appUnpublished
|| missingStartNode
|| triggerModeDisabled
|| (isAdvancedApp && !currentWorkflow?.graph)
|| (isBasicApp && !basicAppConfig.updated_at)
```
**After** (complexity: ~3):
```typescript
// Extract meaningful boolean functions
const isAppReady = () => {
if (isAdvancedApp) return !!currentWorkflow?.graph
return !!basicAppConfig.updated_at
}
const hasRequiredPermissions = () => {
return isCurrentWorkspaceEditor && !hasInsufficientPermissions
}
const canToggle = () => {
if (!hasRequiredPermissions()) return false
if (!isAppReady()) return false
if (missingStartNode) return false
if (triggerModeDisabled) return false
return true
}
const toggleDisabled = !canToggle()
```
## Pattern 8: Simplify useMemo/useCallback Dependencies
**Before** (complexity: multiple recalculations):
```typescript
const payload = useMemo(() => {
let parameters: Parameter[] = []
let outputParameters: OutputParameter[] = []
if (!published) {
parameters = (inputs || []).map((item) => ({
name: item.variable,
description: '',
form: 'llm',
required: item.required,
type: item.type,
}))
outputParameters = (outputs || []).map((item) => ({
name: item.variable,
description: '',
type: item.value_type,
}))
}
else if (detail && detail.tool) {
parameters = (inputs || []).map((item) => ({
// Complex transformation...
}))
outputParameters = (outputs || []).map((item) => ({
// Complex transformation...
}))
}
return {
icon: detail?.icon || icon,
label: detail?.label || name,
// ...more fields
}
}, [detail, published, workflowAppId, icon, name, description, inputs, outputs])
```
**After** (complexity: separated concerns):
```typescript
// Separate transformations
const useParameterTransform = (inputs: InputVar[], detail?: ToolDetail, published?: boolean) => {
return useMemo(() => {
if (!published) {
return inputs.map(item => ({
name: item.variable,
description: '',
form: 'llm',
required: item.required,
type: item.type,
}))
}
if (!detail?.tool) return []
return inputs.map(item => ({
name: item.variable,
required: item.required,
type: item.type === 'paragraph' ? 'string' : item.type,
description: detail.tool.parameters.find(p => p.name === item.variable)?.llm_description || '',
form: detail.tool.parameters.find(p => p.name === item.variable)?.form || 'llm',
}))
}, [inputs, detail, published])
}
// Component uses hook
const parameters = useParameterTransform(inputs, detail, published)
const outputParameters = useOutputTransform(outputs, detail, published)
const payload = useMemo(() => ({
icon: detail?.icon || icon,
label: detail?.label || name,
parameters,
outputParameters,
// ...
}), [detail, icon, name, parameters, outputParameters])
```
## Target Metrics After Refactoring
| Metric | Target |
|--------|--------|
| Total Complexity | < 50 |
| Max Function Complexity | < 30 |
| Function Length | < 30 lines |
| Nesting Depth | ≤ 3 levels |
| Conditional Chains | ≤ 3 conditions |

View File

@ -0,0 +1,477 @@
# Component Splitting Patterns
This document provides detailed guidance on splitting large components into smaller, focused components in Dify.
## When to Split Components
Split a component when you identify:
1. **Multiple UI sections** - Distinct visual areas with minimal coupling that can be composed independently
1. **Conditional rendering blocks** - Large `{condition && <JSX />}` blocks
1. **Repeated patterns** - Similar UI structures used multiple times
1. **300+ lines** - Component exceeds manageable size
1. **Modal clusters** - Multiple modals rendered in one component
## Splitting Strategies
### Strategy 1: Section-Based Splitting
Identify visual sections and extract each as a component.
```typescript
// ❌ Before: Monolithic component (500+ lines)
const ConfigurationPage = () => {
return (
<div>
{/* Header Section - 50 lines */}
<div className="header">
<h1>{t('configuration.title')}</h1>
<div className="actions">
{isAdvancedMode && <Badge>Advanced</Badge>}
<ModelParameterModal ... />
<AppPublisher ... />
</div>
</div>
{/* Config Section - 200 lines */}
<div className="config">
<Config />
</div>
{/* Debug Section - 150 lines */}
<div className="debug">
<Debug ... />
</div>
{/* Modals Section - 100 lines */}
{showSelectDataSet && <SelectDataSet ... />}
{showHistoryModal && <EditHistoryModal ... />}
{showUseGPT4Confirm && <Confirm ... />}
</div>
)
}
// ✅ After: Split into focused components
// configuration/
// ├── index.tsx (orchestration)
// ├── configuration-header.tsx
// ├── configuration-content.tsx
// ├── configuration-debug.tsx
// └── configuration-modals.tsx
// configuration-header.tsx
interface ConfigurationHeaderProps {
isAdvancedMode: boolean
onPublish: () => void
}
const ConfigurationHeader: FC<ConfigurationHeaderProps> = ({
isAdvancedMode,
onPublish,
}) => {
const { t } = useTranslation()
return (
<div className="header">
<h1>{t('configuration.title')}</h1>
<div className="actions">
{isAdvancedMode && <Badge>Advanced</Badge>}
<ModelParameterModal ... />
<AppPublisher onPublish={onPublish} />
</div>
</div>
)
}
// index.tsx (orchestration only)
const ConfigurationPage = () => {
const { modelConfig, setModelConfig } = useModelConfig()
const { activeModal, openModal, closeModal } = useModalState()
return (
<div>
<ConfigurationHeader
isAdvancedMode={isAdvancedMode}
onPublish={handlePublish}
/>
<ConfigurationContent
modelConfig={modelConfig}
onConfigChange={setModelConfig}
/>
{!isMobile && (
<ConfigurationDebug
inputs={inputs}
onSetting={handleSetting}
/>
)}
<ConfigurationModals
activeModal={activeModal}
onClose={closeModal}
/>
</div>
)
}
```
### Strategy 2: Conditional Block Extraction
Extract large conditional rendering blocks.
```typescript
// ❌ Before: Large conditional blocks
const AppInfo = () => {
return (
<div>
{expand ? (
<div className="expanded">
{/* 100 lines of expanded view */}
</div>
) : (
<div className="collapsed">
{/* 50 lines of collapsed view */}
</div>
)}
</div>
)
}
// ✅ After: Separate view components
const AppInfoExpanded: FC<AppInfoViewProps> = ({ appDetail, onAction }) => {
return (
<div className="expanded">
{/* Clean, focused expanded view */}
</div>
)
}
const AppInfoCollapsed: FC<AppInfoViewProps> = ({ appDetail, onAction }) => {
return (
<div className="collapsed">
{/* Clean, focused collapsed view */}
</div>
)
}
const AppInfo = () => {
return (
<div>
{expand
? <AppInfoExpanded appDetail={appDetail} onAction={handleAction} />
: <AppInfoCollapsed appDetail={appDetail} onAction={handleAction} />
}
</div>
)
}
```
### Strategy 3: Modal Extraction
Extract modals with their trigger logic.
```typescript
// ❌ Before: Multiple modals in one component
const AppInfo = () => {
const [showEdit, setShowEdit] = useState(false)
const [showDuplicate, setShowDuplicate] = useState(false)
const [showDelete, setShowDelete] = useState(false)
const [showSwitch, setShowSwitch] = useState(false)
const onEdit = async (data) => { /* 20 lines */ }
const onDuplicate = async (data) => { /* 20 lines */ }
const onDelete = async () => { /* 15 lines */ }
return (
<div>
{/* Main content */}
{showEdit && <EditModal onConfirm={onEdit} onClose={() => setShowEdit(false)} />}
{showDuplicate && <DuplicateModal onConfirm={onDuplicate} onClose={() => setShowDuplicate(false)} />}
{showDelete && <DeleteConfirm onConfirm={onDelete} onClose={() => setShowDelete(false)} />}
{showSwitch && <SwitchModal ... />}
</div>
)
}
// ✅ After: Modal manager component
// app-info-modals.tsx
type ModalType = 'edit' | 'duplicate' | 'delete' | 'switch' | null
interface AppInfoModalsProps {
appDetail: AppDetail
activeModal: ModalType
onClose: () => void
onSuccess: () => void
}
const AppInfoModals: FC<AppInfoModalsProps> = ({
appDetail,
activeModal,
onClose,
onSuccess,
}) => {
const handleEdit = async (data) => { /* logic */ }
const handleDuplicate = async (data) => { /* logic */ }
const handleDelete = async () => { /* logic */ }
return (
<>
{activeModal === 'edit' && (
<EditModal
appDetail={appDetail}
onConfirm={handleEdit}
onClose={onClose}
/>
)}
{activeModal === 'duplicate' && (
<DuplicateModal
appDetail={appDetail}
onConfirm={handleDuplicate}
onClose={onClose}
/>
)}
{activeModal === 'delete' && (
<DeleteConfirm
onConfirm={handleDelete}
onClose={onClose}
/>
)}
{activeModal === 'switch' && (
<SwitchModal
appDetail={appDetail}
onClose={onClose}
/>
)}
</>
)
}
// Parent component
const AppInfo = () => {
const { activeModal, openModal, closeModal } = useModalState()
return (
<div>
{/* Main content with openModal triggers */}
<Button onClick={() => openModal('edit')}>Edit</Button>
<AppInfoModals
appDetail={appDetail}
activeModal={activeModal}
onClose={closeModal}
onSuccess={handleSuccess}
/>
</div>
)
}
```
### Strategy 4: List Item Extraction
Extract repeated item rendering.
```typescript
// ❌ Before: Inline item rendering
const OperationsList = () => {
return (
<div>
{operations.map(op => (
<div key={op.id} className="operation-item">
<span className="icon">{op.icon}</span>
<span className="title">{op.title}</span>
<span className="description">{op.description}</span>
<button onClick={() => op.onClick()}>
{op.actionLabel}
</button>
{op.badge && <Badge>{op.badge}</Badge>}
{/* More complex rendering... */}
</div>
))}
</div>
)
}
// ✅ After: Extracted item component
interface OperationItemProps {
operation: Operation
onAction: (id: string) => void
}
const OperationItem: FC<OperationItemProps> = ({ operation, onAction }) => {
return (
<div className="operation-item">
<span className="icon">{operation.icon}</span>
<span className="title">{operation.title}</span>
<span className="description">{operation.description}</span>
<button onClick={() => onAction(operation.id)}>
{operation.actionLabel}
</button>
{operation.badge && <Badge>{operation.badge}</Badge>}
</div>
)
}
const OperationsList = () => {
const handleAction = useCallback((id: string) => {
const op = operations.find(o => o.id === id)
op?.onClick()
}, [operations])
return (
<div>
{operations.map(op => (
<OperationItem
key={op.id}
operation={op}
onAction={handleAction}
/>
))}
</div>
)
}
```
## Directory Structure Patterns
### Pattern A: Flat Structure (Simple Components)
For components with 2-3 sub-components:
```
component-name/
├── index.tsx # Main component
├── sub-component-a.tsx
├── sub-component-b.tsx
└── types.ts # Shared types
```
### Pattern B: Nested Structure (Complex Components)
For components with many sub-components:
```
component-name/
├── index.tsx # Main orchestration
├── types.ts # Shared types
├── hooks/
│ ├── use-feature-a.ts
│ └── use-feature-b.ts
├── components/
│ ├── header/
│ │ └── index.tsx
│ ├── content/
│ │ └── index.tsx
│ └── modals/
│ └── index.tsx
└── utils/
└── helpers.ts
```
### Pattern C: Feature-Based Structure (Dify Standard)
Following Dify's existing patterns:
```
configuration/
├── index.tsx # Main page component
├── base/ # Base/shared components
│ ├── feature-panel/
│ ├── group-name/
│ └── operation-btn/
├── config/ # Config section
│ ├── index.tsx
│ ├── agent/
│ └── automatic/
├── dataset-config/ # Dataset section
│ ├── index.tsx
│ ├── card-item/
│ └── params-config/
├── debug/ # Debug section
│ ├── index.tsx
│ └── hooks.tsx
└── hooks/ # Shared hooks
└── use-advanced-prompt-config.ts
```
## Props Design
### Minimal Props Principle
Pass only what's needed:
```typescript
// ❌ Bad: Passing entire objects when only some fields needed
<ConfigHeader appDetail={appDetail} modelConfig={modelConfig} />
// ✅ Good: Destructure to minimum required
<ConfigHeader
appName={appDetail.name}
isAdvancedMode={modelConfig.isAdvanced}
onPublish={handlePublish}
/>
```
### Callback Props Pattern
Use callbacks for child-to-parent communication:
```typescript
// Parent
const Parent = () => {
const [value, setValue] = useState('')
return (
<Child
value={value}
onChange={setValue}
onSubmit={handleSubmit}
/>
)
}
// Child
interface ChildProps {
value: string
onChange: (value: string) => void
onSubmit: () => void
}
const Child: FC<ChildProps> = ({ value, onChange, onSubmit }) => {
return (
<div>
<input value={value} onChange={e => onChange(e.target.value)} />
<button onClick={onSubmit}>Submit</button>
</div>
)
}
```
### Render Props for Flexibility
When sub-components need parent context:
```typescript
interface ListProps<T> {
items: T[]
renderItem: (item: T, index: number) => React.ReactNode
renderEmpty?: () => React.ReactNode
}
function List<T>({ items, renderItem, renderEmpty }: ListProps<T>) {
if (items.length === 0 && renderEmpty) {
return <>{renderEmpty()}</>
}
return (
<div>
{items.map((item, index) => renderItem(item, index))}
</div>
)
}
// Usage
<List
items={operations}
renderItem={(op, i) => <OperationItem key={i} operation={op} />}
renderEmpty={() => <EmptyState message="No operations" />}
/>
```

View File

@ -0,0 +1,317 @@
# Hook Extraction Patterns
This document provides detailed guidance on extracting custom hooks from complex components in Dify.
## When to Extract Hooks
Extract a custom hook when you identify:
1. **Coupled state groups** - Multiple `useState` hooks that are always used together
1. **Complex effects** - `useEffect` with multiple dependencies or cleanup logic
1. **Business logic** - Data transformations, validations, or calculations
1. **Reusable patterns** - Logic that appears in multiple components
## Extraction Process
### Step 1: Identify State Groups
Look for state variables that are logically related:
```typescript
// ❌ These belong together - extract to hook
const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
const [completionParams, setCompletionParams] = useState<FormValue>({})
const [modelModeType, setModelModeType] = useState<ModelModeType>(...)
// These are model-related state that should be in useModelConfig()
```
### Step 2: Identify Related Effects
Find effects that modify the grouped state:
```typescript
// ❌ These effects belong with the state above
useEffect(() => {
if (hasFetchedDetail && !modelModeType) {
const mode = currModel?.model_properties.mode
if (mode) {
const newModelConfig = produce(modelConfig, (draft) => {
draft.mode = mode
})
setModelConfig(newModelConfig)
}
}
}, [textGenerationModelList, hasFetchedDetail, modelModeType, currModel])
```
### Step 3: Create the Hook
```typescript
// hooks/use-model-config.ts
import type { FormValue } from '@/app/components/header/account-setting/model-provider-page/declarations'
import type { ModelConfig } from '@/models/debug'
import { produce } from 'immer'
import { useEffect, useState } from 'react'
import { ModelModeType } from '@/types/app'
interface UseModelConfigParams {
initialConfig?: Partial<ModelConfig>
currModel?: { model_properties?: { mode?: ModelModeType } }
hasFetchedDetail: boolean
}
interface UseModelConfigReturn {
modelConfig: ModelConfig
setModelConfig: (config: ModelConfig) => void
completionParams: FormValue
setCompletionParams: (params: FormValue) => void
modelModeType: ModelModeType
}
export const useModelConfig = ({
initialConfig,
currModel,
hasFetchedDetail,
}: UseModelConfigParams): UseModelConfigReturn => {
const [modelConfig, setModelConfig] = useState<ModelConfig>({
provider: 'langgenius/openai/openai',
model_id: 'gpt-3.5-turbo',
mode: ModelModeType.unset,
// ... default values
...initialConfig,
})
const [completionParams, setCompletionParams] = useState<FormValue>({})
const modelModeType = modelConfig.mode
// Fill old app data missing model mode
useEffect(() => {
if (hasFetchedDetail && !modelModeType) {
const mode = currModel?.model_properties?.mode
if (mode) {
setModelConfig(produce(modelConfig, (draft) => {
draft.mode = mode
}))
}
}
}, [hasFetchedDetail, modelModeType, currModel])
return {
modelConfig,
setModelConfig,
completionParams,
setCompletionParams,
modelModeType,
}
}
```
### Step 4: Update Component
```typescript
// Before: 50+ lines of state management
const Configuration: FC = () => {
const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
// ... lots of related state and effects
}
// After: Clean component
const Configuration: FC = () => {
const {
modelConfig,
setModelConfig,
completionParams,
setCompletionParams,
modelModeType,
} = useModelConfig({
currModel,
hasFetchedDetail,
})
// Component now focuses on UI
}
```
## Naming Conventions
### Hook Names
- Use `use` prefix: `useModelConfig`, `useDatasetConfig`
- Be specific: `useAdvancedPromptConfig` not `usePrompt`
- Include domain: `useWorkflowVariables`, `useMCPServer`
### File Names
- Kebab-case: `use-model-config.ts`
- Place in `hooks/` subdirectory when multiple hooks exist
- Place alongside component for single-use hooks
### Return Type Names
- Suffix with `Return`: `UseModelConfigReturn`
- Suffix params with `Params`: `UseModelConfigParams`
## Common Hook Patterns in Dify
### 1. Data Fetching Hook (React Query)
```typescript
// Pattern: Use @tanstack/react-query for data fetching
import { useQuery, useQueryClient } from '@tanstack/react-query'
import { get } from '@/service/base'
import { useInvalid } from '@/service/use-base'
const NAME_SPACE = 'appConfig'
// Query keys for cache management
export const appConfigQueryKeys = {
detail: (appId: string) => [NAME_SPACE, 'detail', appId] as const,
}
// Main data hook
export const useAppConfig = (appId: string) => {
return useQuery({
enabled: !!appId,
queryKey: appConfigQueryKeys.detail(appId),
queryFn: () => get<AppDetailResponse>(`/apps/${appId}`),
select: data => data?.model_config || null,
})
}
// Invalidation hook for refreshing data
export const useInvalidAppConfig = () => {
return useInvalid([NAME_SPACE])
}
// Usage in component
const Component = () => {
const { data: config, isLoading, error, refetch } = useAppConfig(appId)
const invalidAppConfig = useInvalidAppConfig()
const handleRefresh = () => {
invalidAppConfig() // Invalidates cache and triggers refetch
}
return <div>...</div>
}
```
### 2. Form State Hook
```typescript
// Pattern: Form state + validation + submission
export const useConfigForm = (initialValues: ConfigFormValues) => {
const [values, setValues] = useState(initialValues)
const [errors, setErrors] = useState<Record<string, string>>({})
const [isSubmitting, setIsSubmitting] = useState(false)
const validate = useCallback(() => {
const newErrors: Record<string, string> = {}
if (!values.name) newErrors.name = 'Name is required'
setErrors(newErrors)
return Object.keys(newErrors).length === 0
}, [values])
const handleChange = useCallback((field: string, value: any) => {
setValues(prev => ({ ...prev, [field]: value }))
}, [])
const handleSubmit = useCallback(async (onSubmit: (values: ConfigFormValues) => Promise<void>) => {
if (!validate()) return
setIsSubmitting(true)
try {
await onSubmit(values)
} finally {
setIsSubmitting(false)
}
}, [values, validate])
return { values, errors, isSubmitting, handleChange, handleSubmit }
}
```
### 3. Modal State Hook
```typescript
// Pattern: Multiple modal management
type ModalType = 'edit' | 'delete' | 'duplicate' | null
export const useModalState = () => {
const [activeModal, setActiveModal] = useState<ModalType>(null)
const [modalData, setModalData] = useState<any>(null)
const openModal = useCallback((type: ModalType, data?: any) => {
setActiveModal(type)
setModalData(data)
}, [])
const closeModal = useCallback(() => {
setActiveModal(null)
setModalData(null)
}, [])
return {
activeModal,
modalData,
openModal,
closeModal,
isOpen: useCallback((type: ModalType) => activeModal === type, [activeModal]),
}
}
```
### 4. Toggle/Boolean Hook
```typescript
// Pattern: Boolean state with convenience methods
export const useToggle = (initialValue = false) => {
const [value, setValue] = useState(initialValue)
const toggle = useCallback(() => setValue(v => !v), [])
const setTrue = useCallback(() => setValue(true), [])
const setFalse = useCallback(() => setValue(false), [])
return [value, { toggle, setTrue, setFalse, set: setValue }] as const
}
// Usage
const [isExpanded, { toggle, setTrue: expand, setFalse: collapse }] = useToggle()
```
## Testing Extracted Hooks
After extraction, test hooks in isolation:
```typescript
// use-model-config.spec.ts
import { renderHook, act } from '@testing-library/react'
import { useModelConfig } from './use-model-config'
describe('useModelConfig', () => {
it('should initialize with default values', () => {
const { result } = renderHook(() => useModelConfig({
hasFetchedDetail: false,
}))
expect(result.current.modelConfig.provider).toBe('langgenius/openai/openai')
expect(result.current.modelModeType).toBe(ModelModeType.unset)
})
it('should update model config', () => {
const { result } = renderHook(() => useModelConfig({
hasFetchedDetail: true,
}))
act(() => {
result.current.setModelConfig({
...result.current.modelConfig,
model_id: 'gpt-4',
})
})
expect(result.current.modelConfig.model_id).toBe('gpt-4')
})
})
```

View File

@ -1,13 +1,13 @@
---
name: Dify Frontend Testing
description: Generate Jest + React Testing Library tests for Dify frontend components, hooks, and utilities. Triggers on testing, spec files, coverage, Jest, RTL, unit tests, integration tests, or write/review test requests.
name: frontend-testing
description: Generate Vitest + React Testing Library tests for Dify frontend components, hooks, and utilities. Triggers on testing, spec files, coverage, Vitest, RTL, unit tests, integration tests, or write/review test requests.
---
# Dify Frontend Testing Skill
This skill enables Claude to generate high-quality, comprehensive frontend tests for the Dify project following established conventions and best practices.
> **⚠️ Authoritative Source**: This skill is derived from `web/testing/testing.md`. When in doubt, always refer to that document as the canonical specification.
> **⚠️ Authoritative Source**: This skill is derived from `web/testing/testing.md`. Use Vitest mock/timer APIs (`vi.*`).
## When to Apply This Skill
@ -15,7 +15,7 @@ Apply this skill when the user:
- Asks to **write tests** for a component, hook, or utility
- Asks to **review existing tests** for completeness
- Mentions **Jest**, **React Testing Library**, **RTL**, or **spec files**
- Mentions **Vitest**, **React Testing Library**, **RTL**, or **spec files**
- Requests **test coverage** improvement
- Uses `pnpm analyze-component` output as context
- Mentions **testing**, **unit tests**, or **integration tests** for frontend code
@ -33,9 +33,9 @@ Apply this skill when the user:
| Tool | Version | Purpose |
|------|---------|---------|
| Jest | 29.7 | Test runner |
| Vitest | 4.0.16 | Test runner |
| React Testing Library | 16.0 | Component testing |
| happy-dom | - | Test environment |
| jsdom | - | Test environment |
| nock | 14.0 | HTTP mocking |
| TypeScript | 5.x | Type safety |
@ -46,13 +46,13 @@ Apply this skill when the user:
pnpm test
# Watch mode
pnpm test -- --watch
pnpm test:watch
# Run specific file
pnpm test -- path/to/file.spec.tsx
pnpm test path/to/file.spec.tsx
# Generate coverage report
pnpm test -- --coverage
pnpm test:coverage
# Analyze component complexity
pnpm analyze-component <path>
@ -77,9 +77,9 @@ import Component from './index'
// import { ChildComponent } from './child-component'
// ✅ Mock external dependencies only
jest.mock('@/service/api')
jest.mock('next/navigation', () => ({
useRouter: () => ({ push: jest.fn() }),
vi.mock('@/service/api')
vi.mock('next/navigation', () => ({
useRouter: () => ({ push: vi.fn() }),
usePathname: () => '/test',
}))
@ -88,7 +88,7 @@ let mockSharedState = false
describe('ComponentName', () => {
beforeEach(() => {
jest.clearAllMocks() // ✅ Reset mocks BEFORE each test
vi.clearAllMocks() // ✅ Reset mocks BEFORE each test
mockSharedState = false // ✅ Reset shared state
})
@ -117,7 +117,7 @@ describe('ComponentName', () => {
// User Interactions
describe('User Interactions', () => {
it('should handle click events', () => {
const handleClick = jest.fn()
const handleClick = vi.fn()
render(<Component onClick={handleClick} />)
fireEvent.click(screen.getByRole('button'))
@ -155,7 +155,7 @@ describe('ComponentName', () => {
For each file:
┌────────────────────────────────────────┐
│ 1. Write test │
│ 2. Run: pnpm test -- <file>.spec.tsx │
│ 2. Run: pnpm test <file>.spec.tsx
│ 3. PASS? → Mark complete, next file │
│ FAIL? → Fix first, then continue │
└────────────────────────────────────────┘
@ -178,7 +178,7 @@ Process in this order for multi-file testing:
- **500+ lines**: Consider splitting before testing
- **Many dependencies**: Extract logic into hooks first
> 📖 See `guides/workflow.md` for complete workflow details and todo list format.
> 📖 See `references/workflow.md` for complete workflow details and todo list format.
## Testing Strategy
@ -289,17 +289,18 @@ For each test file generated, aim for:
- ✅ **>95%** branch coverage
- ✅ **>95%** line coverage
> **Note**: For multi-file directories, process one file at a time with full coverage each. See `guides/workflow.md`.
> **Note**: For multi-file directories, process one file at a time with full coverage each. See `references/workflow.md`.
## Detailed Guides
For more detailed information, refer to:
- `guides/workflow.md` - **Incremental testing workflow** (MUST READ for multi-file testing)
- `guides/mocking.md` - Mock patterns and best practices
- `guides/async-testing.md` - Async operations and API calls
- `guides/domain-components.md` - Workflow, Dataset, Configuration testing
- `guides/common-patterns.md` - Frequently used testing patterns
- `references/workflow.md` - **Incremental testing workflow** (MUST READ for multi-file testing)
- `references/mocking.md` - Mock patterns and best practices
- `references/async-testing.md` - Async operations and API calls
- `references/domain-components.md` - Workflow, Dataset, Configuration testing
- `references/common-patterns.md` - Frequently used testing patterns
- `references/checklist.md` - Test generation checklist and validation steps
## Authoritative References
@ -315,7 +316,7 @@ For more detailed information, refer to:
### Project Configuration
- `web/jest.config.ts` - Jest configuration
- `web/jest.setup.ts` - Test environment setup
- `web/testing/analyze-component.js` - Component analysis tool
- `web/__mocks__/react-i18next.ts` - Shared i18n mock (auto-loaded by Jest, no explicit mock needed; override locally only for custom translations)
- `web/vitest.config.ts` - Vitest configuration
- `web/vitest.setup.ts` - Test environment setup
- `web/scripts/analyze-component.js` - Component analysis tool
- Modules are not mocked automatically. Global mocks live in `web/vitest.setup.ts` (for example `react-i18next`, `next/image`); mock other modules like `ky` or `mime` locally in test files.

View File

@ -23,14 +23,14 @@ import userEvent from '@testing-library/user-event'
// ============================================================================
// Mocks
// ============================================================================
// WHY: Mocks must be hoisted to top of file (Jest requirement).
// WHY: Mocks must be hoisted to top of file (Vitest requirement).
// They run BEFORE imports, so keep them before component imports.
// i18n (automatically mocked)
// WHY: Shared mock at web/__mocks__/react-i18next.ts is auto-loaded by Jest
// WHY: Global mock in web/vitest.setup.ts is auto-loaded by Vitest setup
// No explicit mock needed - it returns translation keys as-is
// Override only if custom translations are required:
// jest.mock('react-i18next', () => ({
// vi.mock('react-i18next', () => ({
// useTranslation: () => ({
// t: (key: string) => {
// const customTranslations: Record<string, string> = {
@ -43,17 +43,17 @@ import userEvent from '@testing-library/user-event'
// Router (if component uses useRouter, usePathname, useSearchParams)
// WHY: Isolates tests from Next.js routing, enables testing navigation behavior
// const mockPush = jest.fn()
// jest.mock('next/navigation', () => ({
// const mockPush = vi.fn()
// vi.mock('next/navigation', () => ({
// useRouter: () => ({ push: mockPush }),
// usePathname: () => '/test-path',
// }))
// API services (if component fetches data)
// WHY: Prevents real network calls, enables testing all states (loading/success/error)
// jest.mock('@/service/api')
// vi.mock('@/service/api')
// import * as api from '@/service/api'
// const mockedApi = api as jest.Mocked<typeof api>
// const mockedApi = vi.mocked(api)
// Shared mock state (for portal/dropdown components)
// WHY: Portal components like PortalToFollowElem need shared state between
@ -98,7 +98,7 @@ describe('ComponentName', () => {
// - Prevents mock call history from leaking between tests
// - MUST be beforeEach (not afterEach) to reset BEFORE assertions like toHaveBeenCalledTimes
beforeEach(() => {
jest.clearAllMocks()
vi.clearAllMocks()
// Reset shared mock state if used (CRITICAL for portal/dropdown tests)
// mockOpenState = false
})
@ -155,7 +155,7 @@ describe('ComponentName', () => {
// - userEvent simulates real user behavior (focus, hover, then click)
// - fireEvent is lower-level, doesn't trigger all browser events
// const user = userEvent.setup()
// const handleClick = jest.fn()
// const handleClick = vi.fn()
// render(<ComponentName onClick={handleClick} />)
//
// await user.click(screen.getByRole('button'))
@ -165,7 +165,7 @@ describe('ComponentName', () => {
it('should call onChange when value changes', async () => {
// const user = userEvent.setup()
// const handleChange = jest.fn()
// const handleChange = vi.fn()
// render(<ComponentName onChange={handleChange} />)
//
// await user.type(screen.getByRole('textbox'), 'new value')
@ -198,7 +198,7 @@ describe('ComponentName', () => {
})
// --------------------------------------------------------------------------
// Async Operations (if component fetches data - useSWR, useQuery, fetch)
// Async Operations (if component fetches data - useQuery, fetch)
// --------------------------------------------------------------------------
// WHY: Async operations have 3 states users experience: loading, success, error
describe('Async Operations', () => {

View File

@ -15,9 +15,9 @@ import { renderHook, act, waitFor } from '@testing-library/react'
// ============================================================================
// API services (if hook fetches data)
// jest.mock('@/service/api')
// vi.mock('@/service/api')
// import * as api from '@/service/api'
// const mockedApi = api as jest.Mocked<typeof api>
// const mockedApi = vi.mocked(api)
// ============================================================================
// Test Helpers
@ -38,7 +38,7 @@ import { renderHook, act, waitFor } from '@testing-library/react'
describe('useHookName', () => {
beforeEach(() => {
jest.clearAllMocks()
vi.clearAllMocks()
})
// --------------------------------------------------------------------------
@ -145,7 +145,7 @@ describe('useHookName', () => {
// --------------------------------------------------------------------------
describe('Side Effects', () => {
it('should call callback when value changes', () => {
// const callback = jest.fn()
// const callback = vi.fn()
// const { result } = renderHook(() => useHookName({ onChange: callback }))
//
// act(() => {
@ -156,9 +156,9 @@ describe('useHookName', () => {
})
it('should cleanup on unmount', () => {
// const cleanup = jest.fn()
// jest.spyOn(window, 'addEventListener')
// jest.spyOn(window, 'removeEventListener')
// const cleanup = vi.fn()
// vi.spyOn(window, 'addEventListener')
// vi.spyOn(window, 'removeEventListener')
//
// const { unmount } = renderHook(() => useHookName())
//

View File

@ -49,7 +49,7 @@ import userEvent from '@testing-library/user-event'
it('should submit form', async () => {
const user = userEvent.setup()
const onSubmit = jest.fn()
const onSubmit = vi.fn()
render(<Form onSubmit={onSubmit} />)
@ -77,15 +77,15 @@ it('should submit form', async () => {
```typescript
describe('Debounced Search', () => {
beforeEach(() => {
jest.useFakeTimers()
vi.useFakeTimers()
})
afterEach(() => {
jest.useRealTimers()
vi.useRealTimers()
})
it('should debounce search input', async () => {
const onSearch = jest.fn()
const onSearch = vi.fn()
render(<SearchInput onSearch={onSearch} debounceMs={300} />)
// Type in the input
@ -95,7 +95,7 @@ describe('Debounced Search', () => {
expect(onSearch).not.toHaveBeenCalled()
// Advance timers
jest.advanceTimersByTime(300)
vi.advanceTimersByTime(300)
// Now search is called
expect(onSearch).toHaveBeenCalledWith('query')
@ -107,8 +107,8 @@ describe('Debounced Search', () => {
```typescript
it('should retry on failure', async () => {
jest.useFakeTimers()
const fetchData = jest.fn()
vi.useFakeTimers()
const fetchData = vi.fn()
.mockRejectedValueOnce(new Error('Network error'))
.mockResolvedValueOnce({ data: 'success' })
@ -120,7 +120,7 @@ it('should retry on failure', async () => {
})
// Advance timer for retry
jest.advanceTimersByTime(1000)
vi.advanceTimersByTime(1000)
// Second call succeeds
await waitFor(() => {
@ -128,7 +128,7 @@ it('should retry on failure', async () => {
expect(screen.getByText('success')).toBeInTheDocument()
})
jest.useRealTimers()
vi.useRealTimers()
})
```
@ -136,19 +136,19 @@ it('should retry on failure', async () => {
```typescript
// Run all pending timers
jest.runAllTimers()
vi.runAllTimers()
// Run only pending timers (not new ones created during execution)
jest.runOnlyPendingTimers()
vi.runOnlyPendingTimers()
// Advance by specific time
jest.advanceTimersByTime(1000)
vi.advanceTimersByTime(1000)
// Get current fake time
jest.now()
Date.now()
// Clear all timers
jest.clearAllTimers()
vi.clearAllTimers()
```
## API Testing Patterns
@ -158,7 +158,7 @@ jest.clearAllTimers()
```typescript
describe('DataFetcher', () => {
beforeEach(() => {
jest.clearAllMocks()
vi.clearAllMocks()
})
it('should show loading state', () => {
@ -241,7 +241,7 @@ it('should submit form and show success', async () => {
```typescript
it('should fetch data on mount', async () => {
const fetchData = jest.fn().mockResolvedValue({ data: 'test' })
const fetchData = vi.fn().mockResolvedValue({ data: 'test' })
render(<ComponentWithEffect fetchData={fetchData} />)
@ -255,7 +255,7 @@ it('should fetch data on mount', async () => {
```typescript
it('should refetch when id changes', async () => {
const fetchData = jest.fn().mockResolvedValue({ data: 'test' })
const fetchData = vi.fn().mockResolvedValue({ data: 'test' })
const { rerender } = render(<ComponentWithEffect id="1" fetchData={fetchData} />)
@ -276,8 +276,8 @@ it('should refetch when id changes', async () => {
```typescript
it('should cleanup subscription on unmount', () => {
const subscribe = jest.fn()
const unsubscribe = jest.fn()
const subscribe = vi.fn()
const unsubscribe = vi.fn()
subscribe.mockReturnValue(unsubscribe)
const { unmount } = render(<SubscriptionComponent subscribe={subscribe} />)
@ -332,14 +332,14 @@ expect(description).toBeInTheDocument()
```typescript
// Bad - fake timers don't work well with real Promises
jest.useFakeTimers()
vi.useFakeTimers()
await waitFor(() => {
expect(screen.getByText('Data')).toBeInTheDocument()
}) // May timeout!
// Good - use runAllTimers or advanceTimersByTime
jest.useFakeTimers()
vi.useFakeTimers()
render(<Component />)
jest.runAllTimers()
vi.runAllTimers()
expect(screen.getByText('Data')).toBeInTheDocument()
```

View File

@ -74,9 +74,9 @@ Use this checklist when generating or reviewing tests for Dify frontend componen
### Mocks
- [ ] **DO NOT mock base components** (`@/app/components/base/*`)
- [ ] `jest.clearAllMocks()` in `beforeEach` (not `afterEach`)
- [ ] `vi.clearAllMocks()` in `beforeEach` (not `afterEach`)
- [ ] Shared mock state reset in `beforeEach`
- [ ] i18n uses shared mock (auto-loaded); only override locally for custom translations
- [ ] i18n uses global mock (auto-loaded in `web/vitest.setup.ts`); only override locally for custom translations
- [ ] Router mocks match actual Next.js API
- [ ] Mocks reflect actual component conditional behavior
- [ ] Only mock: API services, complex context providers, third-party libs
@ -114,15 +114,15 @@ For the current file being tested:
**Run these checks after EACH test file, not just at the end:**
- [ ] Run `pnpm test -- path/to/file.spec.tsx` - **MUST PASS before next file**
- [ ] Run `pnpm test path/to/file.spec.tsx` - **MUST PASS before next file**
- [ ] Fix any failures immediately
- [ ] Mark file as complete in todo list
- [ ] Only then proceed to next file
### After All Files Complete
- [ ] Run full directory test: `pnpm test -- path/to/directory/`
- [ ] Check coverage report: `pnpm test -- --coverage`
- [ ] Run full directory test: `pnpm test path/to/directory/`
- [ ] Check coverage report: `pnpm test:coverage`
- [ ] Run `pnpm lint:fix` on all test files
- [ ] Run `pnpm type-check:tsgo`
@ -132,10 +132,10 @@ For the current file being tested:
```typescript
// ❌ Mock doesn't match actual behavior
jest.mock('./Component', () => () => <div>Mocked</div>)
vi.mock('./Component', () => () => <div>Mocked</div>)
// ✅ Mock matches actual conditional logic
jest.mock('./Component', () => ({ isOpen }: any) =>
vi.mock('./Component', () => ({ isOpen }: any) =>
isOpen ? <div>Content</div> : null
)
```
@ -145,7 +145,7 @@ jest.mock('./Component', () => ({ isOpen }: any) =>
```typescript
// ❌ Shared state not reset
let mockState = false
jest.mock('./useHook', () => () => mockState)
vi.mock('./useHook', () => () => mockState)
// ✅ Reset in beforeEach
beforeEach(() => {
@ -186,16 +186,16 @@ Always test these scenarios:
```bash
# Run specific test
pnpm test -- path/to/file.spec.tsx
pnpm test path/to/file.spec.tsx
# Run with coverage
pnpm test -- --coverage path/to/file.spec.tsx
pnpm test:coverage path/to/file.spec.tsx
# Watch mode
pnpm test -- --watch path/to/file.spec.tsx
pnpm test:watch path/to/file.spec.tsx
# Update snapshots (use sparingly)
pnpm test -- -u path/to/file.spec.tsx
pnpm test -u path/to/file.spec.tsx
# Analyze component
pnpm analyze-component path/to/component.tsx

View File

@ -126,7 +126,7 @@ describe('Counter', () => {
describe('ControlledInput', () => {
it('should call onChange with new value', async () => {
const user = userEvent.setup()
const handleChange = jest.fn()
const handleChange = vi.fn()
render(<ControlledInput value="" onChange={handleChange} />)
@ -136,7 +136,7 @@ describe('ControlledInput', () => {
})
it('should display controlled value', () => {
render(<ControlledInput value="controlled" onChange={jest.fn()} />)
render(<ControlledInput value="controlled" onChange={vi.fn()} />)
expect(screen.getByRole('textbox')).toHaveValue('controlled')
})
@ -195,7 +195,7 @@ describe('ItemList', () => {
it('should handle item selection', async () => {
const user = userEvent.setup()
const onSelect = jest.fn()
const onSelect = vi.fn()
render(<ItemList items={items} onSelect={onSelect} />)
@ -217,20 +217,20 @@ describe('ItemList', () => {
```typescript
describe('Modal', () => {
it('should not render when closed', () => {
render(<Modal isOpen={false} onClose={jest.fn()} />)
render(<Modal isOpen={false} onClose={vi.fn()} />)
expect(screen.queryByRole('dialog')).not.toBeInTheDocument()
})
it('should render when open', () => {
render(<Modal isOpen={true} onClose={jest.fn()} />)
render(<Modal isOpen={true} onClose={vi.fn()} />)
expect(screen.getByRole('dialog')).toBeInTheDocument()
})
it('should call onClose when clicking overlay', async () => {
const user = userEvent.setup()
const handleClose = jest.fn()
const handleClose = vi.fn()
render(<Modal isOpen={true} onClose={handleClose} />)
@ -241,7 +241,7 @@ describe('Modal', () => {
it('should call onClose when pressing Escape', async () => {
const user = userEvent.setup()
const handleClose = jest.fn()
const handleClose = vi.fn()
render(<Modal isOpen={true} onClose={handleClose} />)
@ -254,7 +254,7 @@ describe('Modal', () => {
const user = userEvent.setup()
render(
<Modal isOpen={true} onClose={jest.fn()}>
<Modal isOpen={true} onClose={vi.fn()}>
<button>First</button>
<button>Second</button>
</Modal>
@ -279,7 +279,7 @@ describe('Modal', () => {
describe('LoginForm', () => {
it('should submit valid form', async () => {
const user = userEvent.setup()
const onSubmit = jest.fn()
const onSubmit = vi.fn()
render(<LoginForm onSubmit={onSubmit} />)
@ -296,7 +296,7 @@ describe('LoginForm', () => {
it('should show validation errors', async () => {
const user = userEvent.setup()
render(<LoginForm onSubmit={jest.fn()} />)
render(<LoginForm onSubmit={vi.fn()} />)
// Submit empty form
await user.click(screen.getByRole('button', { name: /sign in/i }))
@ -308,7 +308,7 @@ describe('LoginForm', () => {
it('should validate email format', async () => {
const user = userEvent.setup()
render(<LoginForm onSubmit={jest.fn()} />)
render(<LoginForm onSubmit={vi.fn()} />)
await user.type(screen.getByLabelText(/email/i), 'invalid-email')
await user.click(screen.getByRole('button', { name: /sign in/i }))
@ -318,7 +318,7 @@ describe('LoginForm', () => {
it('should disable submit button while submitting', async () => {
const user = userEvent.setup()
const onSubmit = jest.fn(() => new Promise(resolve => setTimeout(resolve, 100)))
const onSubmit = vi.fn(() => new Promise(resolve => setTimeout(resolve, 100)))
render(<LoginForm onSubmit={onSubmit} />)
@ -407,7 +407,7 @@ it('test 1', () => {
// Good - cleanup is automatic with RTL, but reset mocks
beforeEach(() => {
jest.clearAllMocks()
vi.clearAllMocks()
})
```

View File

@ -23,7 +23,7 @@ import NodeConfigPanel from './node-config-panel'
import { createMockNode, createMockWorkflowContext } from '@/__mocks__/workflow'
// Mock workflow context
jest.mock('@/app/components/workflow/hooks', () => ({
vi.mock('@/app/components/workflow/hooks', () => ({
useWorkflowStore: () => mockWorkflowStore,
useNodesInteractions: () => mockNodesInteractions,
}))
@ -31,21 +31,21 @@ jest.mock('@/app/components/workflow/hooks', () => ({
let mockWorkflowStore = {
nodes: [],
edges: [],
updateNode: jest.fn(),
updateNode: vi.fn(),
}
let mockNodesInteractions = {
handleNodeSelect: jest.fn(),
handleNodeDelete: jest.fn(),
handleNodeSelect: vi.fn(),
handleNodeDelete: vi.fn(),
}
describe('NodeConfigPanel', () => {
beforeEach(() => {
jest.clearAllMocks()
vi.clearAllMocks()
mockWorkflowStore = {
nodes: [],
edges: [],
updateNode: jest.fn(),
updateNode: vi.fn(),
}
})
@ -161,23 +161,23 @@ import { render, screen, fireEvent, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import DocumentUploader from './document-uploader'
jest.mock('@/service/datasets', () => ({
uploadDocument: jest.fn(),
parseDocument: jest.fn(),
vi.mock('@/service/datasets', () => ({
uploadDocument: vi.fn(),
parseDocument: vi.fn(),
}))
import * as datasetService from '@/service/datasets'
const mockedService = datasetService as jest.Mocked<typeof datasetService>
const mockedService = vi.mocked(datasetService)
describe('DocumentUploader', () => {
beforeEach(() => {
jest.clearAllMocks()
vi.clearAllMocks()
})
describe('File Upload', () => {
it('should accept valid file types', async () => {
const user = userEvent.setup()
const onUpload = jest.fn()
const onUpload = vi.fn()
mockedService.uploadDocument.mockResolvedValue({ id: 'doc-1' })
render(<DocumentUploader onUpload={onUpload} />)
@ -326,14 +326,14 @@ describe('DocumentList', () => {
describe('Search & Filtering', () => {
it('should filter by search query', async () => {
const user = userEvent.setup()
jest.useFakeTimers()
vi.useFakeTimers()
render(<DocumentList datasetId="ds-1" />)
await user.type(screen.getByPlaceholderText(/search/i), 'test query')
// Debounce
jest.advanceTimersByTime(300)
vi.advanceTimersByTime(300)
await waitFor(() => {
expect(mockedService.getDocuments).toHaveBeenCalledWith(
@ -342,7 +342,7 @@ describe('DocumentList', () => {
)
})
jest.useRealTimers()
vi.useRealTimers()
})
})
})
@ -367,13 +367,13 @@ import { render, screen, fireEvent, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import AppConfigForm from './app-config-form'
jest.mock('@/service/apps', () => ({
updateAppConfig: jest.fn(),
getAppConfig: jest.fn(),
vi.mock('@/service/apps', () => ({
updateAppConfig: vi.fn(),
getAppConfig: vi.fn(),
}))
import * as appService from '@/service/apps'
const mockedService = appService as jest.Mocked<typeof appService>
const mockedService = vi.mocked(appService)
describe('AppConfigForm', () => {
const defaultConfig = {
@ -384,7 +384,7 @@ describe('AppConfigForm', () => {
}
beforeEach(() => {
jest.clearAllMocks()
vi.clearAllMocks()
mockedService.getAppConfig.mockResolvedValue(defaultConfig)
})

View File

@ -19,8 +19,8 @@
```typescript
// ❌ WRONG: Don't mock base components
jest.mock('@/app/components/base/loading', () => () => <div>Loading</div>)
jest.mock('@/app/components/base/button', () => ({ children }: any) => <button>{children}</button>)
vi.mock('@/app/components/base/loading', () => () => <div>Loading</div>)
vi.mock('@/app/components/base/button', () => ({ children }: any) => <button>{children}</button>)
// ✅ CORRECT: Import and use real base components
import Loading from '@/app/components/base/loading'
@ -41,20 +41,23 @@ Only mock these categories:
| Location | Purpose |
|----------|---------|
| `web/__mocks__/` | Reusable mocks shared across multiple test files |
| Test file | Test-specific mocks, inline with `jest.mock()` |
| `web/vitest.setup.ts` | Global mocks shared by all tests (for example `react-i18next`, `next/image`) |
| `web/__mocks__/` | Reusable mock factories shared across multiple test files |
| Test file | Test-specific mocks, inline with `vi.mock()` |
Modules are not mocked automatically. Use `vi.mock` in test files, or add global mocks in `web/vitest.setup.ts`.
## Essential Mocks
### 1. i18n (Auto-loaded via Shared Mock)
### 1. i18n (Auto-loaded via Global Mock)
A shared mock is available at `web/__mocks__/react-i18next.ts` and is auto-loaded by Jest.
A global mock is defined in `web/vitest.setup.ts` and is auto-loaded by Vitest setup.
**No explicit mock needed** for most tests - it returns translation keys as-is.
For tests requiring custom translations, override the mock:
```typescript
jest.mock('react-i18next', () => ({
vi.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => {
const translations: Record<string, string> = {
@ -69,15 +72,15 @@ jest.mock('react-i18next', () => ({
### 2. Next.js Router
```typescript
const mockPush = jest.fn()
const mockReplace = jest.fn()
const mockPush = vi.fn()
const mockReplace = vi.fn()
jest.mock('next/navigation', () => ({
vi.mock('next/navigation', () => ({
useRouter: () => ({
push: mockPush,
replace: mockReplace,
back: jest.fn(),
prefetch: jest.fn(),
back: vi.fn(),
prefetch: vi.fn(),
}),
usePathname: () => '/current-path',
useSearchParams: () => new URLSearchParams('?key=value'),
@ -85,7 +88,7 @@ jest.mock('next/navigation', () => ({
describe('Component', () => {
beforeEach(() => {
jest.clearAllMocks()
vi.clearAllMocks()
})
it('should navigate on click', () => {
@ -102,7 +105,7 @@ describe('Component', () => {
// ⚠️ Important: Use shared state for components that depend on each other
let mockPortalOpenState = false
jest.mock('@/app/components/base/portal-to-follow-elem', () => ({
vi.mock('@/app/components/base/portal-to-follow-elem', () => ({
PortalToFollowElem: ({ children, open, ...props }: any) => {
mockPortalOpenState = open || false // Update shared state
return <div data-testid="portal" data-open={open}>{children}</div>
@ -119,7 +122,7 @@ jest.mock('@/app/components/base/portal-to-follow-elem', () => ({
describe('Component', () => {
beforeEach(() => {
jest.clearAllMocks()
vi.clearAllMocks()
mockPortalOpenState = false // ✅ Reset shared state
})
})
@ -130,13 +133,13 @@ describe('Component', () => {
```typescript
import * as api from '@/service/api'
jest.mock('@/service/api')
vi.mock('@/service/api')
const mockedApi = api as jest.Mocked<typeof api>
const mockedApi = vi.mocked(api)
describe('Component', () => {
beforeEach(() => {
jest.clearAllMocks()
vi.clearAllMocks()
// Setup default mock implementation
mockedApi.fetchData.mockResolvedValue({ data: [] })
@ -239,32 +242,9 @@ describe('Component with Context', () => {
})
```
### 7. SWR / React Query
### 7. React Query
```typescript
// SWR
jest.mock('swr', () => ({
__esModule: true,
default: jest.fn(),
}))
import useSWR from 'swr'
const mockedUseSWR = useSWR as jest.Mock
describe('Component with SWR', () => {
it('should show loading state', () => {
mockedUseSWR.mockReturnValue({
data: undefined,
error: undefined,
isLoading: true,
})
render(<Component />)
expect(screen.getByText(/loading/i)).toBeInTheDocument()
})
})
// React Query
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
const createTestQueryClient = () => new QueryClient({

View File

@ -35,7 +35,7 @@ When testing a **single component, hook, or utility**:
2. Run `pnpm analyze-component <path>` (if available)
3. Check complexity score and features detected
4. Write the test file
5. Run test: `pnpm test -- <file>.spec.tsx`
5. Run test: `pnpm test <file>.spec.tsx`
6. Fix any failures
7. Verify coverage meets goals (100% function, >95% branch)
```
@ -80,7 +80,7 @@ Process files in this recommended order:
```
┌─────────────────────────────────────────────┐
│ 1. Write test file │
│ 2. Run: pnpm test -- <file>.spec.tsx │
│ 2. Run: pnpm test <file>.spec.tsx
│ 3. If FAIL → Fix immediately, re-run │
│ 4. If PASS → Mark complete in todo list │
│ 5. ONLY THEN proceed to next file │
@ -95,10 +95,10 @@ After all individual tests pass:
```bash
# Run all tests in the directory together
pnpm test -- path/to/directory/
pnpm test path/to/directory/
# Check coverage
pnpm test -- --coverage path/to/directory/
pnpm test:coverage path/to/directory/
```
## Component Complexity Guidelines
@ -201,9 +201,9 @@ Run pnpm test ← Multiple failures, hard to debug
```
# GOOD: Incremental with verification
Write component-a.spec.tsx
Run pnpm test -- component-a.spec.tsx ✅
Run pnpm test component-a.spec.tsx ✅
Write component-b.spec.tsx
Run pnpm test -- component-b.spec.tsx ✅
Run pnpm test component-b.spec.tsx ✅
...continue...
```

1
.codex/skills Symbolic link
View File

@ -0,0 +1 @@
../.claude/skills

View File

@ -6,6 +6,9 @@
"context": "..",
"dockerfile": "Dockerfile"
},
"mounts": [
"source=dify-dev-tmp,target=/tmp,type=volume"
],
"features": {
"ghcr.io/devcontainers/features/node:1": {
"nodeGypDependencies": true,
@ -34,19 +37,13 @@
},
"postStartCommand": "./.devcontainer/post_start_command.sh",
"postCreateCommand": "./.devcontainer/post_create_command.sh"
// Features to add to the dev container. More info: https://containers.dev/features.
// "features": {},
// Use 'forwardPorts' to make a list of ports inside the container available locally.
// "forwardPorts": [],
// Use 'postCreateCommand' to run commands after the container is created.
// "postCreateCommand": "python --version",
// Configure tool-specific properties.
// "customizations": {},
// Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root.
// "remoteUser": "root"
}
}

View File

@ -1,12 +1,13 @@
#!/bin/bash
WORKSPACE_ROOT=$(pwd)
export COREPACK_ENABLE_DOWNLOAD_PROMPT=0
corepack enable
cd web && pnpm install
pipx install uv
echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor\"" >> ~/.bashrc
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention\"" >> ~/.bashrc
echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev\"" >> ~/.bashrc
echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc

313
.github/CODEOWNERS vendored
View File

@ -6,229 +6,244 @@
* @crazywoola @laipz8200 @Yeuoly
# CODEOWNERS file
/.github/CODEOWNERS @laipz8200 @crazywoola
# Docs
/docs/ @crazywoola
# Backend (default owner, more specific rules below will override)
api/ @QuantumGhost
/api/ @QuantumGhost
# Backend - MCP
api/core/mcp/ @Nov1c444
api/core/entities/mcp_provider.py @Nov1c444
api/services/tools/mcp_tools_manage_service.py @Nov1c444
api/controllers/mcp/ @Nov1c444
api/controllers/console/app/mcp_server.py @Nov1c444
api/tests/**/*mcp* @Nov1c444
/api/core/mcp/ @Nov1c444
/api/core/entities/mcp_provider.py @Nov1c444
/api/services/tools/mcp_tools_manage_service.py @Nov1c444
/api/controllers/mcp/ @Nov1c444
/api/controllers/console/app/mcp_server.py @Nov1c444
/api/tests/**/*mcp* @Nov1c444
# Backend - Workflow - Engine (Core graph execution engine)
api/core/workflow/graph_engine/ @laipz8200 @QuantumGhost
api/core/workflow/runtime/ @laipz8200 @QuantumGhost
api/core/workflow/graph/ @laipz8200 @QuantumGhost
api/core/workflow/graph_events/ @laipz8200 @QuantumGhost
api/core/workflow/node_events/ @laipz8200 @QuantumGhost
api/core/model_runtime/ @laipz8200 @QuantumGhost
/api/core/workflow/graph_engine/ @laipz8200 @QuantumGhost
/api/core/workflow/runtime/ @laipz8200 @QuantumGhost
/api/core/workflow/graph/ @laipz8200 @QuantumGhost
/api/core/workflow/graph_events/ @laipz8200 @QuantumGhost
/api/core/workflow/node_events/ @laipz8200 @QuantumGhost
/api/core/model_runtime/ @laipz8200 @QuantumGhost
# Backend - Workflow - Nodes (Agent, Iteration, Loop, LLM)
api/core/workflow/nodes/agent/ @Nov1c444
api/core/workflow/nodes/iteration/ @Nov1c444
api/core/workflow/nodes/loop/ @Nov1c444
api/core/workflow/nodes/llm/ @Nov1c444
/api/core/workflow/nodes/agent/ @Nov1c444
/api/core/workflow/nodes/iteration/ @Nov1c444
/api/core/workflow/nodes/loop/ @Nov1c444
/api/core/workflow/nodes/llm/ @Nov1c444
# Backend - RAG (Retrieval Augmented Generation)
api/core/rag/ @JohnJyong
api/services/rag_pipeline/ @JohnJyong
api/services/dataset_service.py @JohnJyong
api/services/knowledge_service.py @JohnJyong
api/services/external_knowledge_service.py @JohnJyong
api/services/hit_testing_service.py @JohnJyong
api/services/metadata_service.py @JohnJyong
api/services/vector_service.py @JohnJyong
api/services/entities/knowledge_entities/ @JohnJyong
api/services/entities/external_knowledge_entities/ @JohnJyong
api/controllers/console/datasets/ @JohnJyong
api/controllers/service_api/dataset/ @JohnJyong
api/models/dataset.py @JohnJyong
api/tasks/rag_pipeline/ @JohnJyong
api/tasks/add_document_to_index_task.py @JohnJyong
api/tasks/batch_clean_document_task.py @JohnJyong
api/tasks/clean_document_task.py @JohnJyong
api/tasks/clean_notion_document_task.py @JohnJyong
api/tasks/document_indexing_task.py @JohnJyong
api/tasks/document_indexing_sync_task.py @JohnJyong
api/tasks/document_indexing_update_task.py @JohnJyong
api/tasks/duplicate_document_indexing_task.py @JohnJyong
api/tasks/recover_document_indexing_task.py @JohnJyong
api/tasks/remove_document_from_index_task.py @JohnJyong
api/tasks/retry_document_indexing_task.py @JohnJyong
api/tasks/sync_website_document_indexing_task.py @JohnJyong
api/tasks/batch_create_segment_to_index_task.py @JohnJyong
api/tasks/create_segment_to_index_task.py @JohnJyong
api/tasks/delete_segment_from_index_task.py @JohnJyong
api/tasks/disable_segment_from_index_task.py @JohnJyong
api/tasks/disable_segments_from_index_task.py @JohnJyong
api/tasks/enable_segment_to_index_task.py @JohnJyong
api/tasks/enable_segments_to_index_task.py @JohnJyong
api/tasks/clean_dataset_task.py @JohnJyong
api/tasks/deal_dataset_index_update_task.py @JohnJyong
api/tasks/deal_dataset_vector_index_task.py @JohnJyong
/api/core/rag/ @JohnJyong
/api/services/rag_pipeline/ @JohnJyong
/api/services/dataset_service.py @JohnJyong
/api/services/knowledge_service.py @JohnJyong
/api/services/external_knowledge_service.py @JohnJyong
/api/services/hit_testing_service.py @JohnJyong
/api/services/metadata_service.py @JohnJyong
/api/services/vector_service.py @JohnJyong
/api/services/entities/knowledge_entities/ @JohnJyong
/api/services/entities/external_knowledge_entities/ @JohnJyong
/api/controllers/console/datasets/ @JohnJyong
/api/controllers/service_api/dataset/ @JohnJyong
/api/models/dataset.py @JohnJyong
/api/tasks/rag_pipeline/ @JohnJyong
/api/tasks/add_document_to_index_task.py @JohnJyong
/api/tasks/batch_clean_document_task.py @JohnJyong
/api/tasks/clean_document_task.py @JohnJyong
/api/tasks/clean_notion_document_task.py @JohnJyong
/api/tasks/document_indexing_task.py @JohnJyong
/api/tasks/document_indexing_sync_task.py @JohnJyong
/api/tasks/document_indexing_update_task.py @JohnJyong
/api/tasks/duplicate_document_indexing_task.py @JohnJyong
/api/tasks/recover_document_indexing_task.py @JohnJyong
/api/tasks/remove_document_from_index_task.py @JohnJyong
/api/tasks/retry_document_indexing_task.py @JohnJyong
/api/tasks/sync_website_document_indexing_task.py @JohnJyong
/api/tasks/batch_create_segment_to_index_task.py @JohnJyong
/api/tasks/create_segment_to_index_task.py @JohnJyong
/api/tasks/delete_segment_from_index_task.py @JohnJyong
/api/tasks/disable_segment_from_index_task.py @JohnJyong
/api/tasks/disable_segments_from_index_task.py @JohnJyong
/api/tasks/enable_segment_to_index_task.py @JohnJyong
/api/tasks/enable_segments_to_index_task.py @JohnJyong
/api/tasks/clean_dataset_task.py @JohnJyong
/api/tasks/deal_dataset_index_update_task.py @JohnJyong
/api/tasks/deal_dataset_vector_index_task.py @JohnJyong
# Backend - Plugins
api/core/plugin/ @Mairuis @Yeuoly @Stream29
api/services/plugin/ @Mairuis @Yeuoly @Stream29
api/controllers/console/workspace/plugin.py @Mairuis @Yeuoly @Stream29
api/controllers/inner_api/plugin/ @Mairuis @Yeuoly @Stream29
api/tasks/process_tenant_plugin_autoupgrade_check_task.py @Mairuis @Yeuoly @Stream29
/api/core/plugin/ @Mairuis @Yeuoly @Stream29
/api/services/plugin/ @Mairuis @Yeuoly @Stream29
/api/controllers/console/workspace/plugin.py @Mairuis @Yeuoly @Stream29
/api/controllers/inner_api/plugin/ @Mairuis @Yeuoly @Stream29
/api/tasks/process_tenant_plugin_autoupgrade_check_task.py @Mairuis @Yeuoly @Stream29
# Backend - Trigger/Schedule/Webhook
api/controllers/trigger/ @Mairuis @Yeuoly
api/controllers/console/app/workflow_trigger.py @Mairuis @Yeuoly
api/controllers/console/workspace/trigger_providers.py @Mairuis @Yeuoly
api/core/trigger/ @Mairuis @Yeuoly
api/core/app/layers/trigger_post_layer.py @Mairuis @Yeuoly
api/services/trigger/ @Mairuis @Yeuoly
api/models/trigger.py @Mairuis @Yeuoly
api/fields/workflow_trigger_fields.py @Mairuis @Yeuoly
api/repositories/workflow_trigger_log_repository.py @Mairuis @Yeuoly
api/repositories/sqlalchemy_workflow_trigger_log_repository.py @Mairuis @Yeuoly
api/libs/schedule_utils.py @Mairuis @Yeuoly
api/services/workflow/scheduler.py @Mairuis @Yeuoly
api/schedule/trigger_provider_refresh_task.py @Mairuis @Yeuoly
api/schedule/workflow_schedule_task.py @Mairuis @Yeuoly
api/tasks/trigger_processing_tasks.py @Mairuis @Yeuoly
api/tasks/trigger_subscription_refresh_tasks.py @Mairuis @Yeuoly
api/tasks/workflow_schedule_tasks.py @Mairuis @Yeuoly
api/tasks/workflow_cfs_scheduler/ @Mairuis @Yeuoly
api/events/event_handlers/sync_plugin_trigger_when_app_created.py @Mairuis @Yeuoly
api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py @Mairuis @Yeuoly
api/events/event_handlers/sync_workflow_schedule_when_app_published.py @Mairuis @Yeuoly
api/events/event_handlers/sync_webhook_when_app_created.py @Mairuis @Yeuoly
/api/controllers/trigger/ @Mairuis @Yeuoly
/api/controllers/console/app/workflow_trigger.py @Mairuis @Yeuoly
/api/controllers/console/workspace/trigger_providers.py @Mairuis @Yeuoly
/api/core/trigger/ @Mairuis @Yeuoly
/api/core/app/layers/trigger_post_layer.py @Mairuis @Yeuoly
/api/services/trigger/ @Mairuis @Yeuoly
/api/models/trigger.py @Mairuis @Yeuoly
/api/fields/workflow_trigger_fields.py @Mairuis @Yeuoly
/api/repositories/workflow_trigger_log_repository.py @Mairuis @Yeuoly
/api/repositories/sqlalchemy_workflow_trigger_log_repository.py @Mairuis @Yeuoly
/api/libs/schedule_utils.py @Mairuis @Yeuoly
/api/services/workflow/scheduler.py @Mairuis @Yeuoly
/api/schedule/trigger_provider_refresh_task.py @Mairuis @Yeuoly
/api/schedule/workflow_schedule_task.py @Mairuis @Yeuoly
/api/tasks/trigger_processing_tasks.py @Mairuis @Yeuoly
/api/tasks/trigger_subscription_refresh_tasks.py @Mairuis @Yeuoly
/api/tasks/workflow_schedule_tasks.py @Mairuis @Yeuoly
/api/tasks/workflow_cfs_scheduler/ @Mairuis @Yeuoly
/api/events/event_handlers/sync_plugin_trigger_when_app_created.py @Mairuis @Yeuoly
/api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py @Mairuis @Yeuoly
/api/events/event_handlers/sync_workflow_schedule_when_app_published.py @Mairuis @Yeuoly
/api/events/event_handlers/sync_webhook_when_app_created.py @Mairuis @Yeuoly
# Backend - Async Workflow
api/services/async_workflow_service.py @Mairuis @Yeuoly
api/tasks/async_workflow_tasks.py @Mairuis @Yeuoly
/api/services/async_workflow_service.py @Mairuis @Yeuoly
/api/tasks/async_workflow_tasks.py @Mairuis @Yeuoly
# Backend - Billing
api/services/billing_service.py @hj24 @zyssyz123
api/controllers/console/billing/ @hj24 @zyssyz123
/api/services/billing_service.py @hj24 @zyssyz123
/api/controllers/console/billing/ @hj24 @zyssyz123
# Backend - Enterprise
api/configs/enterprise/ @GarfieldDai @GareArc
api/services/enterprise/ @GarfieldDai @GareArc
api/services/feature_service.py @GarfieldDai @GareArc
api/controllers/console/feature.py @GarfieldDai @GareArc
api/controllers/web/feature.py @GarfieldDai @GareArc
/api/configs/enterprise/ @GarfieldDai @GareArc
/api/services/enterprise/ @GarfieldDai @GareArc
/api/services/feature_service.py @GarfieldDai @GareArc
/api/controllers/console/feature.py @GarfieldDai @GareArc
/api/controllers/web/feature.py @GarfieldDai @GareArc
# Backend - Database Migrations
api/migrations/ @snakevash @laipz8200
/api/migrations/ @snakevash @laipz8200 @MRZHUH
# Backend - Vector DB Middleware
/api/configs/middleware/vdb/* @JohnJyong
# Frontend
web/ @iamjoel
/web/ @iamjoel
# Frontend - Web Tests
/.github/workflows/web-tests.yml @iamjoel
# Frontend - App - Orchestration
web/app/components/workflow/ @iamjoel @zxhlyh
web/app/components/workflow-app/ @iamjoel @zxhlyh
web/app/components/app/configuration/ @iamjoel @zxhlyh
web/app/components/app/app-publisher/ @iamjoel @zxhlyh
/web/app/components/workflow/ @iamjoel @zxhlyh
/web/app/components/workflow-app/ @iamjoel @zxhlyh
/web/app/components/app/configuration/ @iamjoel @zxhlyh
/web/app/components/app/app-publisher/ @iamjoel @zxhlyh
# Frontend - WebApp - Chat
web/app/components/base/chat/ @iamjoel @zxhlyh
/web/app/components/base/chat/ @iamjoel @zxhlyh
# Frontend - WebApp - Completion
web/app/components/share/text-generation/ @iamjoel @zxhlyh
/web/app/components/share/text-generation/ @iamjoel @zxhlyh
# Frontend - App - List and Creation
web/app/components/apps/ @JzoNgKVO @iamjoel
web/app/components/app/create-app-dialog/ @JzoNgKVO @iamjoel
web/app/components/app/create-app-modal/ @JzoNgKVO @iamjoel
web/app/components/app/create-from-dsl-modal/ @JzoNgKVO @iamjoel
/web/app/components/apps/ @JzoNgKVO @iamjoel
/web/app/components/app/create-app-dialog/ @JzoNgKVO @iamjoel
/web/app/components/app/create-app-modal/ @JzoNgKVO @iamjoel
/web/app/components/app/create-from-dsl-modal/ @JzoNgKVO @iamjoel
# Frontend - App - API Documentation
web/app/components/develop/ @JzoNgKVO @iamjoel
/web/app/components/develop/ @JzoNgKVO @iamjoel
# Frontend - App - Logs and Annotations
web/app/components/app/workflow-log/ @JzoNgKVO @iamjoel
web/app/components/app/log/ @JzoNgKVO @iamjoel
web/app/components/app/log-annotation/ @JzoNgKVO @iamjoel
web/app/components/app/annotation/ @JzoNgKVO @iamjoel
/web/app/components/app/workflow-log/ @JzoNgKVO @iamjoel
/web/app/components/app/log/ @JzoNgKVO @iamjoel
/web/app/components/app/log-annotation/ @JzoNgKVO @iamjoel
/web/app/components/app/annotation/ @JzoNgKVO @iamjoel
# Frontend - App - Monitoring
web/app/(commonLayout)/app/(appDetailLayout)/\[appId\]/overview/ @JzoNgKVO @iamjoel
web/app/components/app/overview/ @JzoNgKVO @iamjoel
/web/app/(commonLayout)/app/(appDetailLayout)/\[appId\]/overview/ @JzoNgKVO @iamjoel
/web/app/components/app/overview/ @JzoNgKVO @iamjoel
# Frontend - App - Settings
web/app/components/app-sidebar/ @JzoNgKVO @iamjoel
/web/app/components/app-sidebar/ @JzoNgKVO @iamjoel
# Frontend - RAG - Hit Testing
web/app/components/datasets/hit-testing/ @JzoNgKVO @iamjoel
/web/app/components/datasets/hit-testing/ @JzoNgKVO @iamjoel
# Frontend - RAG - List and Creation
web/app/components/datasets/list/ @iamjoel @WTW0313
web/app/components/datasets/create/ @iamjoel @WTW0313
web/app/components/datasets/create-from-pipeline/ @iamjoel @WTW0313
web/app/components/datasets/external-knowledge-base/ @iamjoel @WTW0313
/web/app/components/datasets/list/ @iamjoel @WTW0313
/web/app/components/datasets/create/ @iamjoel @WTW0313
/web/app/components/datasets/create-from-pipeline/ @iamjoel @WTW0313
/web/app/components/datasets/external-knowledge-base/ @iamjoel @WTW0313
# Frontend - RAG - Orchestration (general rule first, specific rules below override)
web/app/components/rag-pipeline/ @iamjoel @WTW0313
web/app/components/rag-pipeline/components/rag-pipeline-main.tsx @iamjoel @zxhlyh
web/app/components/rag-pipeline/store/ @iamjoel @zxhlyh
/web/app/components/rag-pipeline/ @iamjoel @WTW0313
/web/app/components/rag-pipeline/components/rag-pipeline-main.tsx @iamjoel @zxhlyh
/web/app/components/rag-pipeline/store/ @iamjoel @zxhlyh
# Frontend - RAG - Documents List
web/app/components/datasets/documents/list.tsx @iamjoel @WTW0313
web/app/components/datasets/documents/create-from-pipeline/ @iamjoel @WTW0313
/web/app/components/datasets/documents/list.tsx @iamjoel @WTW0313
/web/app/components/datasets/documents/create-from-pipeline/ @iamjoel @WTW0313
# Frontend - RAG - Segments List
web/app/components/datasets/documents/detail/ @iamjoel @WTW0313
/web/app/components/datasets/documents/detail/ @iamjoel @WTW0313
# Frontend - RAG - Settings
web/app/components/datasets/settings/ @iamjoel @WTW0313
/web/app/components/datasets/settings/ @iamjoel @WTW0313
# Frontend - Ecosystem - Plugins
web/app/components/plugins/ @iamjoel @zhsama
/web/app/components/plugins/ @iamjoel @zhsama
# Frontend - Ecosystem - Tools
web/app/components/tools/ @iamjoel @Yessenia-d
/web/app/components/tools/ @iamjoel @Yessenia-d
# Frontend - Ecosystem - MarketPlace
web/app/components/plugins/marketplace/ @iamjoel @Yessenia-d
/web/app/components/plugins/marketplace/ @iamjoel @Yessenia-d
# Frontend - Login and Registration
web/app/signin/ @douxc @iamjoel
web/app/signup/ @douxc @iamjoel
web/app/reset-password/ @douxc @iamjoel
web/app/install/ @douxc @iamjoel
web/app/init/ @douxc @iamjoel
web/app/forgot-password/ @douxc @iamjoel
web/app/account/ @douxc @iamjoel
/web/app/signin/ @douxc @iamjoel
/web/app/signup/ @douxc @iamjoel
/web/app/reset-password/ @douxc @iamjoel
/web/app/install/ @douxc @iamjoel
/web/app/init/ @douxc @iamjoel
/web/app/forgot-password/ @douxc @iamjoel
/web/app/account/ @douxc @iamjoel
# Frontend - Service Authentication
web/service/base.ts @douxc @iamjoel
/web/service/base.ts @douxc @iamjoel
# Frontend - WebApp Authentication and Access Control
web/app/(shareLayout)/components/ @douxc @iamjoel
web/app/(shareLayout)/webapp-signin/ @douxc @iamjoel
web/app/(shareLayout)/webapp-reset-password/ @douxc @iamjoel
web/app/components/app/app-access-control/ @douxc @iamjoel
/web/app/(shareLayout)/components/ @douxc @iamjoel
/web/app/(shareLayout)/webapp-signin/ @douxc @iamjoel
/web/app/(shareLayout)/webapp-reset-password/ @douxc @iamjoel
/web/app/components/app/app-access-control/ @douxc @iamjoel
# Frontend - Explore Page
web/app/components/explore/ @CodingOnStar @iamjoel
/web/app/components/explore/ @CodingOnStar @iamjoel
# Frontend - Personal Settings
web/app/components/header/account-setting/ @CodingOnStar @iamjoel
web/app/components/header/account-dropdown/ @CodingOnStar @iamjoel
/web/app/components/header/account-setting/ @CodingOnStar @iamjoel
/web/app/components/header/account-dropdown/ @CodingOnStar @iamjoel
# Frontend - Analytics
web/app/components/base/ga/ @CodingOnStar @iamjoel
/web/app/components/base/ga/ @CodingOnStar @iamjoel
# Frontend - Base Components
web/app/components/base/ @iamjoel @zxhlyh
/web/app/components/base/ @iamjoel @zxhlyh
# Frontend - Utils and Hooks
web/utils/classnames.ts @iamjoel @zxhlyh
web/utils/time.ts @iamjoel @zxhlyh
web/utils/format.ts @iamjoel @zxhlyh
web/utils/clipboard.ts @iamjoel @zxhlyh
web/hooks/use-document-title.ts @iamjoel @zxhlyh
/web/utils/classnames.ts @iamjoel @zxhlyh
/web/utils/time.ts @iamjoel @zxhlyh
/web/utils/format.ts @iamjoel @zxhlyh
/web/utils/clipboard.ts @iamjoel @zxhlyh
/web/hooks/use-document-title.ts @iamjoel @zxhlyh
# Frontend - Billing and Education
web/app/components/billing/ @iamjoel @zxhlyh
web/app/education-apply/ @iamjoel @zxhlyh
/web/app/components/billing/ @iamjoel @zxhlyh
/web/app/education-apply/ @iamjoel @zxhlyh
# Frontend - Workspace
web/app/components/header/account-dropdown/workplace-selector/ @iamjoel @zxhlyh
/web/app/components/header/account-dropdown/workplace-selector/ @iamjoel @zxhlyh
# Docker
/docker/* @laipz8200

View File

@ -22,12 +22,12 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@v6
uses: astral-sh/setup-uv@v7
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
@ -57,7 +57,7 @@ jobs:
run: sh .github/workflows/expose_service_ports.sh
- name: Set up Sandbox
uses: hoverkraft-tech/compose-action@v2.0.2
uses: hoverkraft-tech/compose-action@v2
with:
compose-file: |
docker/docker-compose.middleware.yaml

View File

@ -12,12 +12,28 @@ jobs:
if: github.repository == 'langgenius/dify'
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6
- name: Check Docker Compose inputs
id: docker-compose-changes
uses: tj-actions/changed-files@v46
with:
files: |
docker/generate_docker_compose
docker/.env.example
docker/docker-compose-template.yaml
docker/docker-compose.yaml
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- uses: astral-sh/setup-uv@v6
- uses: astral-sh/setup-uv@v7
- name: Generate Docker Compose
if: steps.docker-compose-changes.outputs.any_changed == 'true'
run: |
cd docker
./generate_docker_compose
- run: |
cd api
@ -66,27 +82,6 @@ jobs:
# mdformat breaks YAML front matter in markdown files. Add --exclude for directories containing YAML front matter.
- name: mdformat
run: |
uvx --python 3.13 mdformat . --exclude ".claude/skills/**"
- name: Install pnpm
uses: pnpm/action-setup@v4
with:
package_json_file: web/package.json
run_install: false
- name: Setup NodeJS
uses: actions/setup-node@v4
with:
node-version: 22
cache: pnpm
cache-dependency-path: ./web/package.json
- name: Web dependencies
working-directory: ./web
run: pnpm install --frozen-lockfile
- name: oxlint
working-directory: ./web
run: pnpm exec oxlint --config .oxlintrc.json --fix .
uvx --python 3.13 mdformat . --exclude ".claude/skills/**/SKILL.md"
- uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27

View File

@ -90,7 +90,7 @@ jobs:
touch "/tmp/digests/${sanitized_digest}"
- name: Upload digest
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: digests-${{ matrix.context }}-${{ env.PLATFORM_PAIR }}
path: /tmp/digests/*

View File

@ -13,13 +13,13 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
fetch-depth: 0
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@v6
uses: astral-sh/setup-uv@v7
with:
enable-cache: true
python-version: "3.12"
@ -63,13 +63,13 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
fetch-depth: 0
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@v6
uses: astral-sh/setup-uv@v7
with:
enable-cache: true
python-version: "3.12"

View File

@ -27,7 +27,7 @@ jobs:
vdb-changed: ${{ steps.changes.outputs.vdb }}
migration-changed: ${{ steps.changes.outputs.migration }}
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6
- uses: dorny/paths-filter@v3
id: changes
with:
@ -38,6 +38,7 @@ jobs:
- '.github/workflows/api-tests.yml'
web:
- 'web/**'
- '.github/workflows/web-tests.yml'
vdb:
- 'api/core/rag/datasource/**'
- 'docker/**'

View File

@ -19,13 +19,13 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
persist-credentials: false
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v46
uses: tj-actions/changed-files@v47
with:
files: |
api/**
@ -33,7 +33,7 @@ jobs:
- name: Setup UV and Python
if: steps.changed-files.outputs.any_changed == 'true'
uses: astral-sh/setup-uv@v6
uses: astral-sh/setup-uv@v7
with:
enable-cache: false
python-version: "3.12"
@ -68,15 +68,17 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
persist-credentials: false
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v46
uses: tj-actions/changed-files@v47
with:
files: web/**
files: |
web/**
.github/workflows/style.yml
- name: Install pnpm
uses: pnpm/action-setup@v4
@ -85,12 +87,12 @@ jobs:
run_install: false
- name: Setup NodeJS
uses: actions/setup-node@v4
uses: actions/setup-node@v6
if: steps.changed-files.outputs.any_changed == 'true'
with:
node-version: 22
cache: pnpm
cache-dependency-path: ./web/package.json
cache-dependency-path: ./web/pnpm-lock.yaml
- name: Web dependencies
if: steps.changed-files.outputs.any_changed == 'true'
@ -108,50 +110,20 @@ jobs:
working-directory: ./web
run: pnpm run type-check:tsgo
docker-compose-template:
name: Docker Compose Template
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
persist-credentials: false
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v46
with:
files: |
docker/generate_docker_compose
docker/.env.example
docker/docker-compose-template.yaml
docker/docker-compose.yaml
- name: Generate Docker Compose
if: steps.changed-files.outputs.any_changed == 'true'
run: |
cd docker
./generate_docker_compose
- name: Check for changes
if: steps.changed-files.outputs.any_changed == 'true'
run: git diff --exit-code
superlinter:
name: SuperLinter
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
fetch-depth: 0
persist-credentials: false
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v46
uses: tj-actions/changed-files@v47
with:
files: |
**.sh

View File

@ -25,12 +25,12 @@ jobs:
working-directory: sdks/nodejs-client
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6
with:
persist-credentials: false
- name: Use Node.js ${{ matrix.node-version }}
uses: actions/setup-node@v4
uses: actions/setup-node@v6
with:
node-version: ${{ matrix.node-version }}
cache: ''

View File

@ -1,10 +1,10 @@
name: Check i18n Files and Create PR
name: Translate i18n Files Based on English
on:
push:
branches: [main]
paths:
- 'web/i18n/en-US/*.ts'
- 'web/i18n/en-US/*.json'
permissions:
contents: write
@ -18,7 +18,7 @@ jobs:
run:
working-directory: web
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6
with:
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}
@ -28,13 +28,13 @@ jobs:
run: |
git fetch origin "${{ github.event.before }}" || true
git fetch origin "${{ github.sha }}" || true
changed_files=$(git diff --name-only "${{ github.event.before }}" "${{ github.sha }}" -- 'i18n/en-US/*.ts')
changed_files=$(git diff --name-only "${{ github.event.before }}" "${{ github.sha }}" -- 'i18n/en-US/*.json')
echo "Changed files: $changed_files"
if [ -n "$changed_files" ]; then
echo "FILES_CHANGED=true" >> $GITHUB_ENV
file_args=""
for file in $changed_files; do
filename=$(basename "$file" .ts)
filename=$(basename "$file" .json)
file_args="$file_args --file $filename"
done
echo "FILE_ARGS=$file_args" >> $GITHUB_ENV
@ -51,11 +51,11 @@ jobs:
- name: Set up Node.js
if: env.FILES_CHANGED == 'true'
uses: actions/setup-node@v4
uses: actions/setup-node@v6
with:
node-version: 'lts/*'
cache: pnpm
cache-dependency-path: ./web/package.json
cache-dependency-path: ./web/pnpm-lock.yaml
- name: Install dependencies
if: env.FILES_CHANGED == 'true'
@ -65,12 +65,7 @@ jobs:
- name: Generate i18n translations
if: env.FILES_CHANGED == 'true'
working-directory: ./web
run: pnpm run auto-gen-i18n ${{ env.FILE_ARGS }}
- name: Generate i18n type definitions
if: env.FILES_CHANGED == 'true'
working-directory: ./web
run: pnpm run gen:i18n-types
run: pnpm run i18n:gen ${{ env.FILE_ARGS }}
- name: Create Pull Request
if: env.FILES_CHANGED == 'true'
@ -78,14 +73,13 @@ jobs:
with:
token: ${{ secrets.GITHUB_TOKEN }}
commit-message: 'chore(i18n): update translations based on en-US changes'
title: 'chore(i18n): translate i18n files and update type definitions'
title: 'chore(i18n): translate i18n files based on en-US changes'
body: |
This PR was automatically created to update i18n files and TypeScript type definitions based on changes in en-US locale.
This PR was automatically created to update i18n translation files based on changes in en-US locale.
**Triggered by:** ${{ github.sha }}
**Changes included:**
- Updated translation files for all locales
- Regenerated TypeScript type definitions for type safety
branch: chore/automated-i18n-updates-${{ github.sha }}
delete-branch: true

View File

@ -19,19 +19,19 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
persist-credentials: false
- name: Free Disk Space
uses: endersonmenezes/free-disk-space@v2
uses: endersonmenezes/free-disk-space@v3
with:
remove_dotnet: true
remove_haskell: true
remove_tool_cache: true
- name: Setup UV and Python
uses: astral-sh/setup-uv@v6
uses: astral-sh/setup-uv@v7
with:
enable-cache: true
python-version: ${{ matrix.python-version }}

View File

@ -13,46 +13,356 @@ jobs:
runs-on: ubuntu-latest
defaults:
run:
shell: bash
working-directory: ./web
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
persist-credentials: false
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v46
with:
files: web/**
- name: Install pnpm
if: steps.changed-files.outputs.any_changed == 'true'
uses: pnpm/action-setup@v4
with:
package_json_file: web/package.json
run_install: false
- name: Setup Node.js
uses: actions/setup-node@v4
if: steps.changed-files.outputs.any_changed == 'true'
uses: actions/setup-node@v6
with:
node-version: 22
cache: pnpm
cache-dependency-path: ./web/package.json
cache-dependency-path: ./web/pnpm-lock.yaml
- name: Install dependencies
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm install --frozen-lockfile
- name: Check i18n types synchronization
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm run check:i18n-types
- name: Run tests
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm test
run: pnpm test:coverage
- name: Coverage Summary
if: always()
id: coverage-summary
run: |
set -eo pipefail
COVERAGE_FILE="coverage/coverage-final.json"
COVERAGE_SUMMARY_FILE="coverage/coverage-summary.json"
if [ ! -f "$COVERAGE_FILE" ] && [ ! -f "$COVERAGE_SUMMARY_FILE" ]; then
echo "has_coverage=false" >> "$GITHUB_OUTPUT"
echo "### 🚨 Test Coverage Report :test_tube:" >> "$GITHUB_STEP_SUMMARY"
echo "Coverage data not found. Ensure Vitest runs with coverage enabled." >> "$GITHUB_STEP_SUMMARY"
exit 0
fi
echo "has_coverage=true" >> "$GITHUB_OUTPUT"
node <<'NODE' >> "$GITHUB_STEP_SUMMARY"
const fs = require('fs');
const path = require('path');
let libCoverage = null;
try {
libCoverage = require('istanbul-lib-coverage');
} catch (error) {
libCoverage = null;
}
const summaryPath = path.join('coverage', 'coverage-summary.json');
const finalPath = path.join('coverage', 'coverage-final.json');
const hasSummary = fs.existsSync(summaryPath);
const hasFinal = fs.existsSync(finalPath);
if (!hasSummary && !hasFinal) {
console.log('### Test Coverage Summary :test_tube:');
console.log('');
console.log('No coverage data found.');
process.exit(0);
}
const summary = hasSummary
? JSON.parse(fs.readFileSync(summaryPath, 'utf8'))
: null;
const coverage = hasFinal
? JSON.parse(fs.readFileSync(finalPath, 'utf8'))
: null;
const getLineCoverageFromStatements = (statementMap, statementHits) => {
const lineHits = {};
if (!statementMap || !statementHits) {
return lineHits;
}
Object.entries(statementMap).forEach(([key, statement]) => {
const line = statement?.start?.line;
if (!line) {
return;
}
const hits = statementHits[key] ?? 0;
const previous = lineHits[line];
lineHits[line] = previous === undefined ? hits : Math.max(previous, hits);
});
return lineHits;
};
const getFileCoverage = (entry) => (
libCoverage ? libCoverage.createFileCoverage(entry) : null
);
const getLineHits = (entry, fileCoverage) => {
const lineHits = entry.l ?? {};
if (Object.keys(lineHits).length > 0) {
return lineHits;
}
if (fileCoverage) {
return fileCoverage.getLineCoverage();
}
return getLineCoverageFromStatements(entry.statementMap ?? {}, entry.s ?? {});
};
const getUncoveredLines = (entry, fileCoverage, lineHits) => {
if (lineHits && Object.keys(lineHits).length > 0) {
return Object.entries(lineHits)
.filter(([, count]) => count === 0)
.map(([line]) => Number(line))
.sort((a, b) => a - b);
}
if (fileCoverage) {
return fileCoverage.getUncoveredLines();
}
return [];
};
const totals = {
lines: { covered: 0, total: 0 },
statements: { covered: 0, total: 0 },
branches: { covered: 0, total: 0 },
functions: { covered: 0, total: 0 },
};
const fileSummaries = [];
if (summary) {
const totalEntry = summary.total ?? {};
['lines', 'statements', 'branches', 'functions'].forEach((key) => {
if (totalEntry[key]) {
totals[key].covered = totalEntry[key].covered ?? 0;
totals[key].total = totalEntry[key].total ?? 0;
}
});
Object.entries(summary)
.filter(([file]) => file !== 'total')
.forEach(([file, data]) => {
fileSummaries.push({
file,
pct: data.lines?.pct ?? data.statements?.pct ?? 0,
lines: {
covered: data.lines?.covered ?? 0,
total: data.lines?.total ?? 0,
},
});
});
} else if (coverage) {
Object.entries(coverage).forEach(([file, entry]) => {
const fileCoverage = getFileCoverage(entry);
const lineHits = getLineHits(entry, fileCoverage);
const statementHits = entry.s ?? {};
const branchHits = entry.b ?? {};
const functionHits = entry.f ?? {};
const lineTotal = Object.keys(lineHits).length;
const lineCovered = Object.values(lineHits).filter((n) => n > 0).length;
const statementTotal = Object.keys(statementHits).length;
const statementCovered = Object.values(statementHits).filter((n) => n > 0).length;
const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0);
const branchCovered = Object.values(branchHits).reduce(
(acc, branches) => acc + branches.filter((n) => n > 0).length,
0,
);
const functionTotal = Object.keys(functionHits).length;
const functionCovered = Object.values(functionHits).filter((n) => n > 0).length;
totals.lines.total += lineTotal;
totals.lines.covered += lineCovered;
totals.statements.total += statementTotal;
totals.statements.covered += statementCovered;
totals.branches.total += branchTotal;
totals.branches.covered += branchCovered;
totals.functions.total += functionTotal;
totals.functions.covered += functionCovered;
const pct = (covered, tot) => (tot > 0 ? (covered / tot) * 100 : 0);
fileSummaries.push({
file,
pct: pct(lineCovered || statementCovered, lineTotal || statementTotal),
lines: {
covered: lineCovered || statementCovered,
total: lineTotal || statementTotal,
},
});
});
}
const pct = (covered, tot) => (tot > 0 ? ((covered / tot) * 100).toFixed(2) : '0.00');
console.log('### Test Coverage Summary :test_tube:');
console.log('');
console.log('| Metric | Coverage | Covered / Total |');
console.log('|--------|----------|-----------------|');
console.log(`| Lines | ${pct(totals.lines.covered, totals.lines.total)}% | ${totals.lines.covered} / ${totals.lines.total} |`);
console.log(`| Statements | ${pct(totals.statements.covered, totals.statements.total)}% | ${totals.statements.covered} / ${totals.statements.total} |`);
console.log(`| Branches | ${pct(totals.branches.covered, totals.branches.total)}% | ${totals.branches.covered} / ${totals.branches.total} |`);
console.log(`| Functions | ${pct(totals.functions.covered, totals.functions.total)}% | ${totals.functions.covered} / ${totals.functions.total} |`);
console.log('');
console.log('<details><summary>File coverage (lowest lines first)</summary>');
console.log('');
console.log('```');
fileSummaries
.sort((a, b) => (a.pct - b.pct) || (b.lines.total - a.lines.total))
.slice(0, 25)
.forEach(({ file, pct, lines }) => {
console.log(`${pct.toFixed(2)}%\t${lines.covered}/${lines.total}\t${file}`);
});
console.log('```');
console.log('</details>');
if (coverage) {
const pctValue = (covered, tot) => {
if (tot === 0) {
return '0';
}
return ((covered / tot) * 100)
.toFixed(2)
.replace(/\.?0+$/, '');
};
const formatLineRanges = (lines) => {
if (lines.length === 0) {
return '';
}
const ranges = [];
let start = lines[0];
let end = lines[0];
for (let i = 1; i < lines.length; i += 1) {
const current = lines[i];
if (current === end + 1) {
end = current;
continue;
}
ranges.push(start === end ? `${start}` : `${start}-${end}`);
start = current;
end = current;
}
ranges.push(start === end ? `${start}` : `${start}-${end}`);
return ranges.join(',');
};
const tableTotals = {
statements: { covered: 0, total: 0 },
branches: { covered: 0, total: 0 },
functions: { covered: 0, total: 0 },
lines: { covered: 0, total: 0 },
};
const tableRows = Object.entries(coverage)
.map(([file, entry]) => {
const fileCoverage = getFileCoverage(entry);
const lineHits = getLineHits(entry, fileCoverage);
const statementHits = entry.s ?? {};
const branchHits = entry.b ?? {};
const functionHits = entry.f ?? {};
const lineTotal = Object.keys(lineHits).length;
const lineCovered = Object.values(lineHits).filter((n) => n > 0).length;
const statementTotal = Object.keys(statementHits).length;
const statementCovered = Object.values(statementHits).filter((n) => n > 0).length;
const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0);
const branchCovered = Object.values(branchHits).reduce(
(acc, branches) => acc + branches.filter((n) => n > 0).length,
0,
);
const functionTotal = Object.keys(functionHits).length;
const functionCovered = Object.values(functionHits).filter((n) => n > 0).length;
tableTotals.lines.total += lineTotal;
tableTotals.lines.covered += lineCovered;
tableTotals.statements.total += statementTotal;
tableTotals.statements.covered += statementCovered;
tableTotals.branches.total += branchTotal;
tableTotals.branches.covered += branchCovered;
tableTotals.functions.total += functionTotal;
tableTotals.functions.covered += functionCovered;
const uncoveredLines = getUncoveredLines(entry, fileCoverage, lineHits);
const filePath = entry.path ?? file;
const relativePath = path.isAbsolute(filePath)
? path.relative(process.cwd(), filePath)
: filePath;
return {
file: relativePath || file,
statements: pctValue(statementCovered, statementTotal),
branches: pctValue(branchCovered, branchTotal),
functions: pctValue(functionCovered, functionTotal),
lines: pctValue(lineCovered, lineTotal),
uncovered: formatLineRanges(uncoveredLines),
};
})
.sort((a, b) => a.file.localeCompare(b.file));
const columns = [
{ key: 'file', header: 'File', align: 'left' },
{ key: 'statements', header: '% Stmts', align: 'right' },
{ key: 'branches', header: '% Branch', align: 'right' },
{ key: 'functions', header: '% Funcs', align: 'right' },
{ key: 'lines', header: '% Lines', align: 'right' },
{ key: 'uncovered', header: 'Uncovered Line #s', align: 'left' },
];
const allFilesRow = {
file: 'All files',
statements: pctValue(tableTotals.statements.covered, tableTotals.statements.total),
branches: pctValue(tableTotals.branches.covered, tableTotals.branches.total),
functions: pctValue(tableTotals.functions.covered, tableTotals.functions.total),
lines: pctValue(tableTotals.lines.covered, tableTotals.lines.total),
uncovered: '',
};
const rowsForOutput = [allFilesRow, ...tableRows];
const formatRow = (row) => `| ${columns
.map(({ key }) => String(row[key] ?? ''))
.join(' | ')} |`;
const headerRow = `| ${columns.map(({ header }) => header).join(' | ')} |`;
const dividerRow = `| ${columns
.map(({ align }) => (align === 'right' ? '---:' : ':---'))
.join(' | ')} |`;
console.log('');
console.log('<details><summary>Vitest coverage table</summary>');
console.log('');
console.log(headerRow);
console.log(dividerRow);
rowsForOutput.forEach((row) => console.log(formatRow(row)));
console.log('</details>');
}
NODE
- name: Upload Coverage Artifact
if: steps.coverage-summary.outputs.has_coverage == 'true'
uses: actions/upload-artifact@v6
with:
name: web-coverage-report
path: web/coverage
retention-days: 30
if-no-files-found: error

12
.gitignore vendored
View File

@ -139,7 +139,6 @@ pyrightconfig.json
.idea/'
.DS_Store
web/.vscode/settings.json
# Intellij IDEA Files
.idea/*
@ -196,6 +195,7 @@ docker/nginx/ssl/*
!docker/nginx/ssl/.gitkeep
docker/middleware.env
docker/docker-compose.override.yaml
docker/env-backup/*
sdks/python-client/build
sdks/python-client/dist
@ -205,7 +205,6 @@ sdks/python-client/dify_client.egg-info
!.vscode/launch.json.template
!.vscode/README.md
api/.vscode
web/.vscode
# vscode Code History Extension
.history
@ -220,15 +219,6 @@ plugins.jsonl
# mise
mise.toml
# Next.js build output
.next/
# PWA generated files
web/public/sw.js
web/public/sw.js.map
web/public/workbox-*.js
web/public/workbox-*.js.map
web/public/fallback-*.js
# AI Assistant
.roo/

View File

@ -1,34 +0,0 @@
{
"mcpServers": {
"context7": {
"type": "http",
"url": "https://mcp.context7.com/mcp"
},
"sequential-thinking": {
"type": "stdio",
"command": "npx",
"args": ["-y", "@modelcontextprotocol/server-sequential-thinking"],
"env": {}
},
"github": {
"type": "stdio",
"command": "npx",
"args": ["-y", "@modelcontextprotocol/server-github"],
"env": {
"GITHUB_PERSONAL_ACCESS_TOKEN": "${GITHUB_PERSONAL_ACCESS_TOKEN}"
}
},
"fetch": {
"type": "stdio",
"command": "uvx",
"args": ["mcp-server-fetch"],
"env": {}
},
"playwright": {
"type": "stdio",
"command": "npx",
"args": ["-y", "@playwright/mcp@latest"],
"env": {}
}
}
}

View File

@ -37,7 +37,7 @@
"-c",
"1",
"-Q",
"dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor",
"dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention",
"--loglevel",
"INFO"
],

View File

@ -116,6 +116,7 @@ ALIYUN_OSS_AUTH_VERSION=v1
ALIYUN_OSS_REGION=your-region
# Don't start with '/'. OSS doesn't support leading slash in object names.
ALIYUN_OSS_PATH=your-path
ALIYUN_CLOUDBOX_ID=your-cloudbox-id
# Google Storage configuration
GOOGLE_STORAGE_BUCKET_NAME=your-bucket-name
@ -127,12 +128,14 @@ TENCENT_COS_SECRET_KEY=your-secret-key
TENCENT_COS_SECRET_ID=your-secret-id
TENCENT_COS_REGION=your-region
TENCENT_COS_SCHEME=your-scheme
TENCENT_COS_CUSTOM_DOMAIN=your-custom-domain
# Huawei OBS Storage Configuration
HUAWEI_OBS_BUCKET_NAME=your-bucket-name
HUAWEI_OBS_SECRET_KEY=your-secret-key
HUAWEI_OBS_ACCESS_KEY=your-access-key
HUAWEI_OBS_SERVER=your-server-url
HUAWEI_OBS_PATH_STYLE=false
# Baidu OBS Storage Configuration
BAIDU_OBS_BUCKET_NAME=your-bucket-name
@ -690,3 +693,7 @@ ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE=5
ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR=20
# Maximum number of concurrent annotation import tasks per tenant
ANNOTATION_IMPORT_MAX_CONCURRENT=5
# Sandbox expired records clean configuration
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30

View File

@ -84,7 +84,7 @@
1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
```bash
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention
```
Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service:

View File

@ -218,7 +218,7 @@ class PluginConfig(BaseSettings):
PLUGIN_DAEMON_TIMEOUT: PositiveFloat | None = Field(
description="Timeout in seconds for requests to the plugin daemon (set to None to disable)",
default=300.0,
default=600.0,
)
INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key")
@ -1270,6 +1270,21 @@ class TenantIsolatedTaskQueueConfig(BaseSettings):
)
class SandboxExpiredRecordsCleanConfig(BaseSettings):
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD: NonNegativeInt = Field(
description="Graceful period in days for sandbox records clean after subscription expiration",
default=21,
)
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE: PositiveInt = Field(
description="Maximum number of records to process in each batch",
default=1000,
)
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS: PositiveInt = Field(
description="Retention days for sandbox expired workflow_run records and message records",
default=30,
)
class FeatureConfig(
# place the configs in alphabet order
AppExecutionConfig,
@ -1295,6 +1310,7 @@ class FeatureConfig(
PositionConfig,
RagEtlConfig,
RepositoryConfig,
SandboxExpiredRecordsCleanConfig,
SecurityConfig,
TenantIsolatedTaskQueueConfig,
ToolConfig,

View File

@ -41,3 +41,8 @@ class AliyunOSSStorageConfig(BaseSettings):
description="Base path within the bucket to store objects (e.g., 'my-app-data/')",
default=None,
)
ALIYUN_CLOUDBOX_ID: str | None = Field(
description="Cloudbox id for aliyun cloudbox service",
default=None,
)

View File

@ -26,3 +26,8 @@ class HuaweiCloudOBSStorageConfig(BaseSettings):
description="Endpoint URL for Huawei Cloud OBS (e.g., 'https://obs.cn-north-4.myhuaweicloud.com')",
default=None,
)
HUAWEI_OBS_PATH_STYLE: bool = Field(
description="Flag to indicate whether to use path-style URLs for OBS requests",
default=False,
)

View File

@ -31,3 +31,8 @@ class TencentCloudCOSStorageConfig(BaseSettings):
description="Protocol scheme for COS requests: 'https' (recommended) or 'http'",
default=None,
)
TENCENT_COS_CUSTOM_DOMAIN: str | None = Field(
description="Tencent Cloud COS custom domain setting",
default=None,
)

View File

@ -0,0 +1,57 @@
import os
from email.message import Message
from urllib.parse import quote
from flask import Response
HTML_MIME_TYPES = frozenset({"text/html", "application/xhtml+xml"})
HTML_EXTENSIONS = frozenset({"html", "htm"})
def _normalize_mime_type(mime_type: str | None) -> str:
if not mime_type:
return ""
message = Message()
message["Content-Type"] = mime_type
return message.get_content_type().strip().lower()
def _is_html_extension(extension: str | None) -> bool:
if not extension:
return False
return extension.lstrip(".").lower() in HTML_EXTENSIONS
def is_html_content(mime_type: str | None, filename: str | None, extension: str | None = None) -> bool:
normalized_mime_type = _normalize_mime_type(mime_type)
if normalized_mime_type in HTML_MIME_TYPES:
return True
if _is_html_extension(extension):
return True
if filename:
return _is_html_extension(os.path.splitext(filename)[1])
return False
def enforce_download_for_html(
response: Response,
*,
mime_type: str | None,
filename: str | None,
extension: str | None = None,
) -> bool:
if not is_html_content(mime_type, filename, extension):
return False
if filename:
encoded_filename = quote(filename)
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
else:
response.headers["Content-Disposition"] = "attachment"
response.headers["Content-Type"] = "application/octet-stream"
response.headers["X-Content-Type-Options"] = "nosniff"
return True

View File

@ -7,9 +7,9 @@ from controllers.console import console_ns
from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.helper import EmailStr, extract_remote_ip, timezone
from libs.helper import EmailStr, timezone
from models import AccountStatus
from services.account_service import AccountService, RegisterService
from services.account_service import RegisterService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@ -93,7 +93,6 @@ class ActivateApi(Resource):
"ActivationResponse",
{
"result": fields.String(description="Operation result"),
"data": fields.Raw(description="Login token data"),
},
),
)
@ -117,6 +116,4 @@ class ActivateApi(Resource):
account.initialized_at = naive_utc_now()
db.session.commit()
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
return {"result": "success", "data": token_pair.model_dump()}
return {"result": "success"}

View File

@ -1,8 +1,9 @@
import base64
from typing import Literal
from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field
from werkzeug.exceptions import BadRequest
from controllers.console import console_ns
@ -15,22 +16,8 @@ DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class SubscriptionQuery(BaseModel):
plan: str = Field(..., description="Subscription plan")
interval: str = Field(..., description="Billing interval")
@field_validator("plan")
@classmethod
def validate_plan(cls, value: str) -> str:
if value not in [CloudPlan.PROFESSIONAL, CloudPlan.TEAM]:
raise ValueError("Invalid plan")
return value
@field_validator("interval")
@classmethod
def validate_interval(cls, value: str) -> str:
if value not in {"month", "year"}:
raise ValueError("Invalid interval")
return value
plan: Literal[CloudPlan.PROFESSIONAL, CloudPlan.TEAM] = Field(..., description="Subscription plan")
interval: Literal["month", "year"] = Field(..., description="Billing interval")
class PartnerTenantsPayload(BaseModel):

View File

@ -146,7 +146,7 @@ class DatasetUpdatePayload(BaseModel):
embedding_model: str | None = None
embedding_model_provider: str | None = None
retrieval_model: dict[str, Any] | None = None
partial_member_list: list[str] | None = None
partial_member_list: list[dict[str, str]] | None = None
external_retrieval_model: dict[str, Any] | None = None
external_knowledge_id: str | None = None
external_knowledge_api_id: str | None = None

View File

@ -572,7 +572,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate(
{
"credential_id": data_source_info["credential_id"],
"credential_id": data_source_info.get("credential_id"),
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],

View File

@ -40,7 +40,7 @@ from .. import console_ns
logger = logging.getLogger(__name__)
class CompletionMessagePayload(BaseModel):
class CompletionMessageExplorePayload(BaseModel):
inputs: dict[str, Any]
query: str = ""
files: list[dict[str, Any]] | None = None
@ -71,7 +71,7 @@ class ChatMessagePayload(BaseModel):
raise ValueError("must be a valid UUID") from exc
register_schema_models(console_ns, CompletionMessagePayload, ChatMessagePayload)
register_schema_models(console_ns, CompletionMessageExplorePayload, ChatMessagePayload)
# define completion api for user
@ -80,13 +80,13 @@ register_schema_models(console_ns, CompletionMessagePayload, ChatMessagePayload)
endpoint="installed_app_completion",
)
class CompletionApi(InstalledAppResource):
@console_ns.expect(console_ns.models[CompletionMessagePayload.__name__])
@console_ns.expect(console_ns.models[CompletionMessageExplorePayload.__name__])
def post(self, installed_app):
app_model = installed_app.app
if app_model.mode != AppMode.COMPLETION:
raise NotCompletionAppError()
payload = CompletionMessagePayload.model_validate(console_ns.payload or {})
payload = CompletionMessageExplorePayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
streaming = payload.response_mode == "streaming"

View File

@ -1,5 +1,4 @@
from typing import Any
from uuid import UUID
from flask import request
from flask_restx import marshal_with
@ -13,6 +12,7 @@ from controllers.console.explore.wraps import InstalledAppResource
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import UUIDStrOrEmpty
from libs.login import current_user
from models import Account
from models.model import AppMode
@ -24,7 +24,7 @@ from .. import console_ns
class ConversationListQuery(BaseModel):
last_id: UUID | None = None
last_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)
pinned: bool | None = None

View File

@ -2,7 +2,8 @@ import logging
from typing import Any
from flask import request
from flask_restx import Resource, inputs, marshal_with, reqparse
from flask_restx import Resource, marshal_with
from pydantic import BaseModel
from sqlalchemy import and_, select
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
@ -18,6 +19,15 @@ from services.account_service import TenantService
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
class InstalledAppCreatePayload(BaseModel):
app_id: str
class InstalledAppUpdatePayload(BaseModel):
is_pinned: bool | None = None
logger = logging.getLogger(__name__)
@ -105,26 +115,25 @@ class InstalledAppsListApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("apps")
def post(self):
parser = reqparse.RequestParser().add_argument("app_id", type=str, required=True, help="Invalid app_id")
args = parser.parse_args()
payload = InstalledAppCreatePayload.model_validate(console_ns.payload or {})
recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == args["app_id"]).first()
recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == payload.app_id).first()
if recommended_app is None:
raise NotFound("App not found")
raise NotFound("Recommended app not found")
_, current_tenant_id = current_account_with_tenant()
app = db.session.query(App).where(App.id == args["app_id"]).first()
app = db.session.query(App).where(App.id == payload.app_id).first()
if app is None:
raise NotFound("App not found")
raise NotFound("App entity not found")
if not app.is_public:
raise Forbidden("You can't install a non-public app")
installed_app = (
db.session.query(InstalledApp)
.where(and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id))
.where(and_(InstalledApp.app_id == payload.app_id, InstalledApp.tenant_id == current_tenant_id))
.first()
)
@ -133,7 +142,7 @@ class InstalledAppsListApi(Resource):
recommended_app.install_count += 1
new_installed_app = InstalledApp(
app_id=args["app_id"],
app_id=payload.app_id,
tenant_id=current_tenant_id,
app_owner_tenant_id=app.tenant_id,
is_pinned=False,
@ -163,12 +172,11 @@ class InstalledAppApi(InstalledAppResource):
return {"result": "success", "message": "App uninstalled successfully"}, 204
def patch(self, installed_app):
parser = reqparse.RequestParser().add_argument("is_pinned", type=inputs.boolean)
args = parser.parse_args()
payload = InstalledAppUpdatePayload.model_validate(console_ns.payload or {})
commit_args = False
if "is_pinned" in args:
installed_app.is_pinned = args["is_pinned"]
if payload.is_pinned is not None:
installed_app.is_pinned = payload.is_pinned
commit_args = True
if commit_args:

View File

@ -1,6 +1,5 @@
import logging
from typing import Literal
from uuid import UUID
from flask import request
from flask_restx import marshal_with
@ -26,6 +25,7 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
from core.model_runtime.errors.invoke import InvokeError
from fields.message_fields import message_infinite_scroll_pagination_fields
from libs import helper
from libs.helper import UUIDStrOrEmpty
from libs.login import current_account_with_tenant
from models.model import AppMode
from services.app_generate_service import AppGenerateService
@ -44,8 +44,8 @@ logger = logging.getLogger(__name__)
class MessageListQuery(BaseModel):
conversation_id: UUID
first_id: UUID | None = None
conversation_id: UUIDStrOrEmpty
first_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)

View File

@ -1,5 +1,3 @@
from uuid import UUID
from flask import request
from flask_restx import fields, marshal_with
from pydantic import BaseModel, Field
@ -10,19 +8,19 @@ from controllers.console import console_ns
from controllers.console.explore.error import NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource
from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField
from libs.helper import TimestampField, UUIDStrOrEmpty
from libs.login import current_account_with_tenant
from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService
class SavedMessageListQuery(BaseModel):
last_id: UUID | None = None
last_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)
class SavedMessageCreatePayload(BaseModel):
message_id: UUID
message_id: UUIDStrOrEmpty
register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload)

View File

@ -1,14 +1,32 @@
from flask_restx import Resource, fields, marshal_with, reqparse
from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from constants import HIDDEN_VALUE
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from fields.api_based_extension_fields import api_based_extension_fields
from libs.login import current_account_with_tenant, login_required
from models.api_based_extension import APIBasedExtension
from services.api_based_extension_service import APIBasedExtensionService
from services.code_based_extension_service import CodeBasedExtensionService
from ..common.schema import register_schema_models
from . import console_ns
from .wraps import account_initialization_required, setup_required
class CodeBasedExtensionQuery(BaseModel):
module: str
class APIBasedExtensionPayload(BaseModel):
name: str = Field(description="Extension name")
api_endpoint: str = Field(description="API endpoint URL")
api_key: str = Field(description="API key for authentication")
register_schema_models(console_ns, APIBasedExtensionPayload)
api_based_extension_model = console_ns.model("ApiBasedExtensionModel", api_based_extension_fields)
api_based_extension_list_model = fields.List(fields.Nested(api_based_extension_model))
@ -18,11 +36,7 @@ api_based_extension_list_model = fields.List(fields.Nested(api_based_extension_m
class CodeBasedExtensionAPI(Resource):
@console_ns.doc("get_code_based_extension")
@console_ns.doc(description="Get code-based extension data by module name")
@console_ns.expect(
console_ns.parser().add_argument(
"module", type=str, required=True, location="args", help="Extension module name"
)
)
@console_ns.doc(params={"module": "Extension module name"})
@console_ns.response(
200,
"Success",
@ -35,10 +49,9 @@ class CodeBasedExtensionAPI(Resource):
@login_required
@account_initialization_required
def get(self):
parser = reqparse.RequestParser().add_argument("module", type=str, required=True, location="args")
args = parser.parse_args()
query = CodeBasedExtensionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
return {"module": args["module"], "data": CodeBasedExtensionService.get_code_based_extension(args["module"])}
return {"module": query.module, "data": CodeBasedExtensionService.get_code_based_extension(query.module)}
@console_ns.route("/api-based-extension")
@ -56,30 +69,21 @@ class APIBasedExtensionAPI(Resource):
@console_ns.doc("create_api_based_extension")
@console_ns.doc(description="Create a new API-based extension")
@console_ns.expect(
console_ns.model(
"CreateAPIBasedExtensionRequest",
{
"name": fields.String(required=True, description="Extension name"),
"api_endpoint": fields.String(required=True, description="API endpoint URL"),
"api_key": fields.String(required=True, description="API key for authentication"),
},
)
)
@console_ns.expect(console_ns.models[APIBasedExtensionPayload.__name__])
@console_ns.response(201, "Extension created successfully", api_based_extension_model)
@setup_required
@login_required
@account_initialization_required
@marshal_with(api_based_extension_model)
def post(self):
args = console_ns.payload
payload = APIBasedExtensionPayload.model_validate(console_ns.payload or {})
_, current_tenant_id = current_account_with_tenant()
extension_data = APIBasedExtension(
tenant_id=current_tenant_id,
name=args["name"],
api_endpoint=args["api_endpoint"],
api_key=args["api_key"],
name=payload.name,
api_endpoint=payload.api_endpoint,
api_key=payload.api_key,
)
return APIBasedExtensionService.save(extension_data)
@ -104,16 +108,7 @@ class APIBasedExtensionDetailAPI(Resource):
@console_ns.doc("update_api_based_extension")
@console_ns.doc(description="Update API-based extension")
@console_ns.doc(params={"id": "Extension ID"})
@console_ns.expect(
console_ns.model(
"UpdateAPIBasedExtensionRequest",
{
"name": fields.String(required=True, description="Extension name"),
"api_endpoint": fields.String(required=True, description="API endpoint URL"),
"api_key": fields.String(required=True, description="API key for authentication"),
},
)
)
@console_ns.expect(console_ns.models[APIBasedExtensionPayload.__name__])
@console_ns.response(200, "Extension updated successfully", api_based_extension_model)
@setup_required
@login_required
@ -125,13 +120,13 @@ class APIBasedExtensionDetailAPI(Resource):
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
args = console_ns.payload
payload = APIBasedExtensionPayload.model_validate(console_ns.payload or {})
extension_data_from_db.name = args["name"]
extension_data_from_db.api_endpoint = args["api_endpoint"]
extension_data_from_db.name = payload.name
extension_data_from_db.api_endpoint = payload.api_endpoint
if args["api_key"] != HIDDEN_VALUE:
extension_data_from_db.api_key = args["api_key"]
if payload.api_key != HIDDEN_VALUE:
extension_data_from_db.api_key = payload.api_key
return APIBasedExtensionService.save(extension_data_from_db)

View File

@ -1,31 +1,40 @@
from typing import Literal
from flask import request
from flask_restx import Resource, marshal_with, reqparse
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from fields.tag_fields import dataset_tag_fields
from libs.login import current_account_with_tenant, login_required
from models.model import Tag
from services.tag_service import TagService
def _validate_name(name):
if not name or len(name) < 1 or len(name) > 50:
raise ValueError("Name must be between 1 to 50 characters.")
return name
class TagBasePayload(BaseModel):
name: str = Field(description="Tag name", min_length=1, max_length=50)
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
parser_tags = (
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 50 characters.",
type=_validate_name,
)
.add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.")
class TagBindingPayload(BaseModel):
tag_ids: list[str] = Field(description="Tag IDs to bind")
target_id: str = Field(description="Target ID to bind tags to")
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
class TagBindingRemovePayload(BaseModel):
tag_id: str = Field(description="Tag ID to remove")
target_id: str = Field(description="Target ID to unbind tag from")
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
register_schema_models(
console_ns,
TagBasePayload,
TagBindingPayload,
TagBindingRemovePayload,
)
@ -43,7 +52,7 @@ class TagListApi(Resource):
return tags, 200
@console_ns.expect(parser_tags)
@console_ns.expect(console_ns.models[TagBasePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -53,22 +62,17 @@ class TagListApi(Resource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
args = parser_tags.parse_args()
tag = TagService.save_tags(args)
payload = TagBasePayload.model_validate(console_ns.payload or {})
tag = TagService.save_tags(payload.model_dump())
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
return response, 200
parser_tag_id = reqparse.RequestParser().add_argument(
"name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name
)
@console_ns.route("/tags/<uuid:tag_id>")
class TagUpdateDeleteApi(Resource):
@console_ns.expect(parser_tag_id)
@console_ns.expect(console_ns.models[TagBasePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -79,8 +83,8 @@ class TagUpdateDeleteApi(Resource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
args = parser_tag_id.parse_args()
tag = TagService.update_tags(args, tag_id)
payload = TagBasePayload.model_validate(console_ns.payload or {})
tag = TagService.update_tags(payload.model_dump(), tag_id)
binding_count = TagService.get_tag_binding_count(tag_id)
@ -100,17 +104,9 @@ class TagUpdateDeleteApi(Resource):
return 204
parser_create = (
reqparse.RequestParser()
.add_argument("tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required.")
.add_argument("target_id", type=str, nullable=False, required=True, location="json", help="Target ID is required.")
.add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.")
)
@console_ns.route("/tag-bindings/create")
class TagBindingCreateApi(Resource):
@console_ns.expect(parser_create)
@console_ns.expect(console_ns.models[TagBindingPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -120,23 +116,15 @@ class TagBindingCreateApi(Resource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
args = parser_create.parse_args()
TagService.save_tag_binding(args)
payload = TagBindingPayload.model_validate(console_ns.payload or {})
TagService.save_tag_binding(payload.model_dump())
return {"result": "success"}, 200
parser_remove = (
reqparse.RequestParser()
.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
.add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.")
)
@console_ns.route("/tag-bindings/remove")
class TagBindingDeleteApi(Resource):
@console_ns.expect(parser_remove)
@console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -146,7 +134,7 @@ class TagBindingDeleteApi(Resource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
args = parser_remove.parse_args()
TagService.delete_tag_binding(args)
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
TagService.delete_tag_binding(payload.model_dump())
return {"result": "success"}, 200

View File

@ -1,6 +1,8 @@
from flask_restx import Resource, reqparse
from flask_restx import Resource
from pydantic import BaseModel
from werkzeug.exceptions import Forbidden
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.entities.model_entities import ModelType
@ -10,10 +12,20 @@ from models import TenantAccountRole
from services.model_load_balancing_service import ModelLoadBalancingService
class LoadBalancingCredentialPayload(BaseModel):
model: str
model_type: ModelType
credentials: dict[str, object]
register_schema_models(console_ns, LoadBalancingCredentialPayload)
@console_ns.route(
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/credentials-validate"
)
class LoadBalancingCredentialsValidateApi(Resource):
@console_ns.expect(console_ns.models[LoadBalancingCredentialPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -24,20 +36,7 @@ class LoadBalancingCredentialsValidateApi(Resource):
tenant_id = current_tenant_id
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
args = parser.parse_args()
payload = LoadBalancingCredentialPayload.model_validate(console_ns.payload or {})
# validate model load balancing credentials
model_load_balancing_service = ModelLoadBalancingService()
@ -49,9 +48,9 @@ class LoadBalancingCredentialsValidateApi(Resource):
model_load_balancing_service.validate_load_balancing_credentials(
tenant_id=tenant_id,
provider=provider,
model=args["model"],
model_type=args["model_type"],
credentials=args["credentials"],
model=payload.model,
model_type=payload.model_type,
credentials=payload.credentials,
)
except CredentialsValidateFailedError as ex:
result = False
@ -69,6 +68,7 @@ class LoadBalancingCredentialsValidateApi(Resource):
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate"
)
class LoadBalancingConfigCredentialsValidateApi(Resource):
@console_ns.expect(console_ns.models[LoadBalancingCredentialPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -79,20 +79,7 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
tenant_id = current_tenant_id
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
args = parser.parse_args()
payload = LoadBalancingCredentialPayload.model_validate(console_ns.payload or {})
# validate model load balancing config credentials
model_load_balancing_service = ModelLoadBalancingService()
@ -104,9 +91,9 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
model_load_balancing_service.validate_load_balancing_credentials(
tenant_id=tenant_id,
provider=provider,
model=args["model"],
model_type=args["model_type"],
credentials=args["credentials"],
model=payload.model,
model_type=payload.model_type,
credentials=payload.credentials,
config_id=config_id,
)
except CredentialsValidateFailedError as ex:

View File

@ -1,5 +1,6 @@
import io
from typing import Literal
from collections.abc import Mapping
from typing import Any, Literal
from flask import request, send_file
from flask_restx import Resource
@ -141,6 +142,15 @@ class ParserDynamicOptions(BaseModel):
provider_type: Literal["tool", "trigger"]
class ParserDynamicOptionsWithCredentials(BaseModel):
plugin_id: str
provider: str
action: str
parameter: str
credential_id: str
credentials: Mapping[str, Any]
class PluginPermissionSettingsPayload(BaseModel):
install_permission: TenantPluginPermission.InstallPermission = TenantPluginPermission.InstallPermission.EVERYONE
debug_permission: TenantPluginPermission.DebugPermission = TenantPluginPermission.DebugPermission.EVERYONE
@ -183,6 +193,7 @@ reg(ParserGithubUpgrade)
reg(ParserUninstall)
reg(ParserPermissionChange)
reg(ParserDynamicOptions)
reg(ParserDynamicOptionsWithCredentials)
reg(ParserPreferencesChange)
reg(ParserExcludePlugin)
reg(ParserReadme)
@ -657,6 +668,37 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
return jsonable_encoder({"options": options})
@console_ns.route("/workspaces/current/plugin/parameters/dynamic-options-with-credentials")
class PluginFetchDynamicSelectOptionsWithCredentialsApi(Resource):
@console_ns.expect(console_ns.models[ParserDynamicOptionsWithCredentials.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
"""Fetch dynamic options using credentials directly (for edit mode)."""
current_user, tenant_id = current_account_with_tenant()
user_id = current_user.id
args = ParserDynamicOptionsWithCredentials.model_validate(console_ns.payload)
try:
options = PluginParameterService.get_dynamic_select_options_with_credentials(
tenant_id=tenant_id,
user_id=user_id,
plugin_id=args.plugin_id,
provider=args.provider,
action=args.action,
parameter=args.parameter,
credential_id=args.credential_id,
credentials=args.credentials,
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder({"options": options})
@console_ns.route("/workspaces/current/plugin/preferences/change")
class PluginChangePreferencesApi(Resource):
@console_ns.expect(console_ns.models[ParserPreferencesChange.__name__])

View File

@ -1,4 +1,5 @@
import io
import logging
from urllib.parse import urlparse
from flask import make_response, redirect, request, send_file
@ -17,6 +18,7 @@ from controllers.console.wraps import (
is_admin_or_owner_required,
setup_required,
)
from core.db.session_factory import session_factory
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
from core.mcp.auth.auth_flow import auth, handle_callback
from core.mcp.error import MCPAuthError, MCPError, MCPRefreshTokenError
@ -39,6 +41,8 @@ from services.tools.tools_manage_service import ToolCommonService
from services.tools.tools_transform_service import ToolTransformService
from services.tools.workflow_tools_manage_service import WorkflowToolManageService
logger = logging.getLogger(__name__)
def is_valid_url(url: str) -> bool:
if not url:
@ -944,8 +948,8 @@ class ToolProviderMCPApi(Resource):
configuration = MCPConfiguration.model_validate(args["configuration"])
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
# Create provider
with Session(db.engine) as session, session.begin():
# 1) Create provider in a short transaction (no network I/O inside)
with session_factory.create_session() as session, session.begin():
service = MCPToolManageService(session=session)
result = service.create_provider(
tenant_id=tenant_id,
@ -960,7 +964,29 @@ class ToolProviderMCPApi(Resource):
configuration=configuration,
authentication=authentication,
)
return jsonable_encoder(result)
# 2) Try to fetch tools immediately after creation so they appear without a second save.
# Perform network I/O outside any DB session to avoid holding locks.
try:
reconnect = MCPToolManageService.reconnect_with_url(
server_url=args["server_url"],
headers=args.get("headers") or {},
timeout=configuration.timeout,
sse_read_timeout=configuration.sse_read_timeout,
)
# Update just-created provider with authed/tools in a new short transaction
with session_factory.create_session() as session, session.begin():
service = MCPToolManageService(session=session)
db_provider = service.get_provider(provider_id=result.id, tenant_id=tenant_id)
db_provider.authed = reconnect.authed
db_provider.tools = reconnect.tools
result = ToolTransformService.mcp_provider_to_user_provider(db_provider, for_list=True)
except Exception:
# Best-effort: if initial fetch fails (e.g., auth required), return created provider as-is
logger.warning("Failed to fetch MCP tools after creation", exc_info=True)
return jsonable_encoder(result)
@console_ns.expect(parser_mcp_put)
@setup_required
@ -972,17 +998,23 @@ class ToolProviderMCPApi(Resource):
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
_, current_tenant_id = current_account_with_tenant()
# Step 1: Validate server URL change if needed (includes URL format validation and network operation)
validation_result = None
# Step 1: Get provider data for URL validation (short-lived session, no network I/O)
validation_data = None
with Session(db.engine) as session:
service = MCPToolManageService(session=session)
validation_result = service.validate_server_url_change(
tenant_id=current_tenant_id, provider_id=args["provider_id"], new_server_url=args["server_url"]
validation_data = service.get_provider_for_url_validation(
tenant_id=current_tenant_id, provider_id=args["provider_id"]
)
# No need to check for errors here, exceptions will be raised directly
# Step 2: Perform URL validation with network I/O OUTSIDE of any database session
# This prevents holding database locks during potentially slow network operations
validation_result = MCPToolManageService.validate_server_url_standalone(
tenant_id=current_tenant_id,
new_server_url=args["server_url"],
validation_data=validation_data,
)
# Step 2: Perform database update in a transaction
# Step 3: Perform database update in a transaction
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.update_provider(
@ -999,7 +1031,8 @@ class ToolProviderMCPApi(Resource):
authentication=authentication,
validation_result=validation_result,
)
return {"result": "success"}
return {"result": "success"}
@console_ns.expect(parser_mcp_delete)
@setup_required
@ -1012,7 +1045,8 @@ class ToolProviderMCPApi(Resource):
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.delete_provider(tenant_id=current_tenant_id, provider_id=args["provider_id"])
return {"result": "success"}
return {"result": "success"}
parser_auth = (

View File

@ -1,11 +1,15 @@
import logging
from collections.abc import Mapping
from typing import Any
from flask import make_response, redirect, request
from flask_restx import Resource, reqparse
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden
from configs import dify_config
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from controllers.web.error import NotFoundError
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin_daemon import CredentialType
@ -32,6 +36,32 @@ from ..wraps import (
logger = logging.getLogger(__name__)
class TriggerSubscriptionUpdateRequest(BaseModel):
"""Request payload for updating a trigger subscription"""
name: str | None = Field(default=None, description="The name for the subscription")
credentials: Mapping[str, Any] | None = Field(default=None, description="The credentials for the subscription")
parameters: Mapping[str, Any] | None = Field(default=None, description="The parameters for the subscription")
properties: Mapping[str, Any] | None = Field(default=None, description="The properties for the subscription")
class TriggerSubscriptionVerifyRequest(BaseModel):
"""Request payload for verifying subscription credentials."""
credentials: Mapping[str, Any] = Field(description="The credentials to verify")
console_ns.schema_model(
TriggerSubscriptionUpdateRequest.__name__,
TriggerSubscriptionUpdateRequest.model_json_schema(ref_template="#/definitions/{model}"),
)
console_ns.schema_model(
TriggerSubscriptionVerifyRequest.__name__,
TriggerSubscriptionVerifyRequest.model_json_schema(ref_template="#/definitions/{model}"),
)
@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/icon")
class TriggerProviderIconApi(Resource):
@setup_required
@ -155,16 +185,16 @@ parser_api = (
@console_ns.route(
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/verify/<path:subscription_builder_id>",
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/verify-and-update/<path:subscription_builder_id>",
)
class TriggerSubscriptionBuilderVerifyApi(Resource):
class TriggerSubscriptionBuilderVerifyAndUpdateApi(Resource):
@console_ns.expect(parser_api)
@setup_required
@login_required
@edit_permission_required
@account_initialization_required
def post(self, provider, subscription_builder_id):
"""Verify a subscription instance for a trigger provider"""
"""Verify and update a subscription instance for a trigger provider"""
user = current_user
assert user.current_tenant_id is not None
@ -289,6 +319,83 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
raise ValueError(str(e)) from e
@console_ns.route(
"/workspaces/current/trigger-provider/<path:subscription_id>/subscriptions/update",
)
class TriggerSubscriptionUpdateApi(Resource):
@console_ns.expect(console_ns.models[TriggerSubscriptionUpdateRequest.__name__])
@setup_required
@login_required
@edit_permission_required
@account_initialization_required
def post(self, subscription_id: str):
"""Update a subscription instance"""
user = current_user
assert user.current_tenant_id is not None
args = TriggerSubscriptionUpdateRequest.model_validate(console_ns.payload)
subscription = TriggerProviderService.get_subscription_by_id(
tenant_id=user.current_tenant_id,
subscription_id=subscription_id,
)
if not subscription:
raise NotFoundError(f"Subscription {subscription_id} not found")
provider_id = TriggerProviderID(subscription.provider_id)
try:
# rename only
if (
args.name is not None
and args.credentials is None
and args.parameters is None
and args.properties is None
):
TriggerProviderService.update_trigger_subscription(
tenant_id=user.current_tenant_id,
subscription_id=subscription_id,
name=args.name,
)
return 200
# rebuild for create automatically by the provider
match subscription.credential_type:
case CredentialType.UNAUTHORIZED:
TriggerProviderService.update_trigger_subscription(
tenant_id=user.current_tenant_id,
subscription_id=subscription_id,
name=args.name,
properties=args.properties,
)
return 200
case CredentialType.API_KEY | CredentialType.OAUTH2:
if args.credentials:
new_credentials: dict[str, Any] = {
key: value if value != HIDDEN_VALUE else subscription.credentials.get(key, UNKNOWN_VALUE)
for key, value in args.credentials.items()
}
else:
new_credentials = subscription.credentials
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=user.current_tenant_id,
name=args.name,
provider_id=provider_id,
subscription_id=subscription_id,
credentials=new_credentials,
parameters=args.parameters or subscription.parameters,
)
return 200
case _:
raise BadRequest("Invalid credential type")
except ValueError as e:
raise BadRequest(str(e))
except Exception as e:
logger.exception("Error updating subscription", exc_info=e)
raise
@console_ns.route(
"/workspaces/current/trigger-provider/<path:subscription_id>/subscriptions/delete",
)
@ -576,3 +683,38 @@ class TriggerOAuthClientManageApi(Resource):
except Exception as e:
logger.exception("Error removing OAuth client", exc_info=e)
raise
@console_ns.route(
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/verify/<path:subscription_id>",
)
class TriggerSubscriptionVerifyApi(Resource):
@console_ns.expect(console_ns.models[TriggerSubscriptionVerifyRequest.__name__])
@setup_required
@login_required
@edit_permission_required
@account_initialization_required
def post(self, provider, subscription_id):
"""Verify credentials for an existing subscription (edit mode only)"""
user = current_user
assert user.current_tenant_id is not None
verify_request: TriggerSubscriptionVerifyRequest = TriggerSubscriptionVerifyRequest.model_validate(
console_ns.payload
)
try:
result = TriggerProviderService.verify_subscription_credentials(
tenant_id=user.current_tenant_id,
user_id=user.id,
provider_id=TriggerProviderID(provider),
subscription_id=subscription_id,
credentials=verify_request.credentials,
)
return result
except ValueError as e:
logger.warning("Credential verification failed", exc_info=e)
raise BadRequest(str(e)) from e
except Exception as e:
logger.exception("Error verifying subscription credentials", exc_info=e)
raise BadRequest(str(e)) from e

View File

@ -7,6 +7,7 @@ from werkzeug.exceptions import NotFound
import services
from controllers.common.errors import UnsupportedFileTypeError
from controllers.common.file_response import enforce_download_for_html
from controllers.files import files_ns
from extensions.ext_database import db
from services.account_service import TenantService
@ -138,6 +139,13 @@ class FilePreviewApi(Resource):
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
response.headers["Content-Type"] = "application/octet-stream"
enforce_download_for_html(
response,
mime_type=upload_file.mime_type,
filename=upload_file.name,
extension=upload_file.extension,
)
return response

View File

@ -6,6 +6,7 @@ from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, NotFound
from controllers.common.errors import UnsupportedFileTypeError
from controllers.common.file_response import enforce_download_for_html
from controllers.files import files_ns
from core.tools.signature import verify_tool_file_signature
from core.tools.tool_file_manager import ToolFileManager
@ -78,4 +79,11 @@ class ToolFileApi(Resource):
encoded_filename = quote(tool_file.name)
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
enforce_download_for_html(
response,
mime_type=tool_file.mimetype,
filename=tool_file.name,
extension=extension,
)
return response

View File

@ -4,7 +4,7 @@ from uuid import UUID
from flask import request
from flask_restx import Resource
from flask_restx._http import HTTPStatus
from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel, Field, field_validator, model_validator
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, NotFound
@ -51,6 +51,32 @@ class ConversationRenamePayload(BaseModel):
class ConversationVariablesQuery(BaseModel):
last_id: UUID | None = Field(default=None, description="Last variable ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return")
variable_name: str | None = Field(
default=None, description="Filter variables by name", min_length=1, max_length=255
)
@field_validator("variable_name", mode="before")
@classmethod
def validate_variable_name(cls, v: str | None) -> str | None:
"""
Validate variable_name to prevent injection attacks.
"""
if v is None:
return v
# Only allow safe characters: alphanumeric, underscore, hyphen, period
if not v.replace("-", "").replace("_", "").replace(".", "").isalnum():
raise ValueError(
"Variable name can only contain letters, numbers, hyphens (-), underscores (_), and periods (.)"
)
# Prevent SQL injection patterns
dangerous_patterns = ["'", '"', ";", "--", "/*", "*/", "xp_", "sp_"]
for pattern in dangerous_patterns:
if pattern in v.lower():
raise ValueError(f"Variable name contains invalid characters: {pattern}")
return v
class ConversationVariableUpdatePayload(BaseModel):
@ -199,7 +225,7 @@ class ConversationVariablesApi(Resource):
try:
return ConversationService.get_conversational_variable(
app_model, conversation_id, end_user, query_args.limit, last_id
app_model, conversation_id, end_user, query_args.limit, last_id, query_args.variable_name
)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")

View File

@ -5,6 +5,7 @@ from flask import Response, request
from flask_restx import Resource
from pydantic import BaseModel, Field
from controllers.common.file_response import enforce_download_for_html
from controllers.common.schema import register_schema_model
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import (
@ -183,6 +184,13 @@ class FilePreviewApi(Resource):
# Override content-type for downloads to force download
response.headers["Content-Type"] = "application/octet-stream"
enforce_download_for_html(
response,
mime_type=upload_file.mime_type,
filename=upload_file.name,
extension=upload_file.extension,
)
# Add caching headers for performance
response.headers["Cache-Control"] = "public, max-age=3600" # Cache for 1 hour

View File

@ -13,7 +13,6 @@ from controllers.service_api.dataset.error import DatasetInUseError, DatasetName
from controllers.service_api.wraps import (
DatasetApiResource,
cloud_edition_billing_rate_limit_check,
validate_dataset_token,
)
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
@ -49,7 +48,7 @@ class DatasetUpdatePayload(BaseModel):
embedding_model: str | None = None
embedding_model_provider: str | None = None
retrieval_model: RetrievalModel | None = None
partial_member_list: list[str] | None = None
partial_member_list: list[dict[str, str]] | None = None
external_retrieval_model: dict[str, Any] | None = None
external_knowledge_id: str | None = None
external_knowledge_api_id: str | None = None
@ -460,9 +459,8 @@ class DatasetTagsApi(DatasetApiResource):
401: "Unauthorized - invalid API token",
}
)
@validate_dataset_token
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
def get(self, _, dataset_id):
def get(self, _):
"""Get all knowledge type tags."""
assert isinstance(current_user, Account)
cid = current_user.current_tenant_id
@ -482,8 +480,7 @@ class DatasetTagsApi(DatasetApiResource):
}
)
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
@validate_dataset_token
def post(self, _, dataset_id):
def post(self, _):
"""Add a knowledge type tag."""
assert isinstance(current_user, Account)
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
@ -506,8 +503,7 @@ class DatasetTagsApi(DatasetApiResource):
}
)
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
@validate_dataset_token
def patch(self, _, dataset_id):
def patch(self, _):
assert isinstance(current_user, Account)
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
@ -533,9 +529,8 @@ class DatasetTagsApi(DatasetApiResource):
403: "Forbidden - insufficient permissions",
}
)
@validate_dataset_token
@edit_permission_required
def delete(self, _, dataset_id):
def delete(self, _):
"""Delete a knowledge type tag."""
payload = TagDeletePayload.model_validate(service_api_ns.payload or {})
TagService.delete_tag(payload.tag_id)
@ -555,8 +550,7 @@ class DatasetTagBindingApi(DatasetApiResource):
403: "Forbidden - insufficient permissions",
}
)
@validate_dataset_token
def post(self, _, dataset_id):
def post(self, _):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
assert isinstance(current_user, Account)
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
@ -580,8 +574,7 @@ class DatasetTagUnbindingApi(DatasetApiResource):
403: "Forbidden - insufficient permissions",
}
)
@validate_dataset_token
def post(self, _, dataset_id):
def post(self, _):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
assert isinstance(current_user, Account)
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
@ -604,7 +597,6 @@ class DatasetTagsBindingStatusApi(DatasetApiResource):
401: "Unauthorized - invalid API token",
}
)
@validate_dataset_token
def get(self, _, *args, **kwargs):
"""Get all knowledge type tags."""
dataset_id = kwargs.get("dataset_id")

View File

@ -1,14 +1,13 @@
import logging
from flask import request
from flask_restx import Resource, marshal_with, reqparse
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, ConfigDict, Field
from werkzeug.exceptions import Unauthorized
from constants import HEADER_NAME_APP_CODE
from controllers.common import fields
from controllers.web import web_ns
from controllers.web.error import AppUnavailableError
from controllers.web.wraps import WebApiResource
from controllers.common.schema import register_schema_models
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from libs.passport import PassportService
from libs.token import extract_webapp_passport
@ -18,9 +17,23 @@ from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
from services.webapp_auth_service import WebAppAuthService
from . import web_ns
from .error import AppUnavailableError
from .wraps import WebApiResource
logger = logging.getLogger(__name__)
class AppAccessModeQuery(BaseModel):
model_config = ConfigDict(populate_by_name=True)
app_id: str | None = Field(default=None, alias="appId", description="Application ID")
app_code: str | None = Field(default=None, alias="appCode", description="Application code")
register_schema_models(web_ns, AppAccessModeQuery)
@web_ns.route("/parameters")
class AppParameterApi(WebApiResource):
"""Resource for app variables."""
@ -96,21 +109,16 @@ class AppAccessMode(Resource):
}
)
def get(self):
parser = (
reqparse.RequestParser()
.add_argument("appId", type=str, required=False, location="args")
.add_argument("appCode", type=str, required=False, location="args")
)
args = parser.parse_args()
raw_args = request.args.to_dict()
args = AppAccessModeQuery.model_validate(raw_args)
features = FeatureService.get_system_features()
if not features.webapp_auth.enabled:
return {"accessMode": "public"}
app_id = args.get("appId")
if args.get("appCode"):
app_code = args["appCode"]
app_id = AppService.get_app_id_by_code(app_code)
app_id = args.app_id
if args.app_code:
app_id = AppService.get_app_id_by_code(args.app_code)
if not app_id:
raise ValueError("appId or appCode must be provided")

View File

@ -1,7 +1,8 @@
import logging
from flask import request
from flask_restx import fields, marshal_with, reqparse
from flask_restx import fields, marshal_with
from pydantic import BaseModel, field_validator
from werkzeug.exceptions import InternalServerError
import services
@ -20,6 +21,7 @@ from controllers.web.error import (
from controllers.web.wraps import WebApiResource
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value
from models.model import App
from services.audio_service import AudioService
from services.errors.audio import (
@ -29,6 +31,25 @@ from services.errors.audio import (
UnsupportedAudioTypeServiceError,
)
from ..common.schema import register_schema_models
class TextToAudioPayload(BaseModel):
message_id: str | None = None
voice: str | None = None
text: str | None = None
streaming: bool | None = None
@field_validator("message_id")
@classmethod
def validate_message_id(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
register_schema_models(web_ns, TextToAudioPayload)
logger = logging.getLogger(__name__)
@ -88,6 +109,7 @@ class AudioApi(WebApiResource):
@web_ns.route("/text-to-audio")
class TextApi(WebApiResource):
@web_ns.expect(web_ns.models[TextToAudioPayload.__name__])
@web_ns.doc("Text to Audio")
@web_ns.doc(description="Convert text to audio using text-to-speech service.")
@web_ns.doc(
@ -102,18 +124,11 @@ class TextApi(WebApiResource):
def post(self, app_model: App, end_user):
"""Convert text to audio"""
try:
parser = (
reqparse.RequestParser()
.add_argument("message_id", type=str, required=False, location="json")
.add_argument("voice", type=str, location="json")
.add_argument("text", type=str, location="json")
.add_argument("streaming", type=bool, location="json")
)
args = parser.parse_args()
payload = TextToAudioPayload.model_validate(web_ns.payload or {})
message_id = args.get("message_id", None)
text = args.get("text", None)
voice = args.get("voice", None)
message_id = payload.message_id
text = payload.text
voice = payload.voice
response = AudioService.transcript_tts(
app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id
)

View File

@ -1,9 +1,11 @@
import logging
from typing import Any, Literal
from flask_restx import reqparse
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import InternalServerError, NotFound
import services
from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.error import (
AppUnavailableError,
@ -34,25 +36,44 @@ from services.errors.llm import InvokeRateLimitError
logger = logging.getLogger(__name__)
class CompletionMessagePayload(BaseModel):
inputs: dict[str, Any] = Field(description="Input variables for the completion")
query: str = Field(default="", description="Query text for completion")
files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed")
response_mode: Literal["blocking", "streaming"] | None = Field(
default=None, description="Response mode: blocking or streaming"
)
retriever_from: str = Field(default="web_app", description="Source of retriever")
class ChatMessagePayload(BaseModel):
inputs: dict[str, Any] = Field(description="Input variables for the chat")
query: str = Field(description="User query/message")
files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed")
response_mode: Literal["blocking", "streaming"] | None = Field(
default=None, description="Response mode: blocking or streaming"
)
conversation_id: str | None = Field(default=None, description="Conversation ID")
parent_message_id: str | None = Field(default=None, description="Parent message ID")
retriever_from: str = Field(default="web_app", description="Source of retriever")
@field_validator("conversation_id", "parent_message_id")
@classmethod
def validate_uuid(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
register_schema_models(web_ns, CompletionMessagePayload, ChatMessagePayload)
# define completion api for user
@web_ns.route("/completion-messages")
class CompletionApi(WebApiResource):
@web_ns.doc("Create Completion Message")
@web_ns.doc(description="Create a completion message for text generation applications.")
@web_ns.doc(
params={
"inputs": {"description": "Input variables for the completion", "type": "object", "required": True},
"query": {"description": "Query text for completion", "type": "string", "required": False},
"files": {"description": "Files to be processed", "type": "array", "required": False},
"response_mode": {
"description": "Response mode: blocking or streaming",
"type": "string",
"enum": ["blocking", "streaming"],
"required": False,
},
"retriever_from": {"description": "Source of retriever", "type": "string", "required": False},
}
)
@web_ns.expect(web_ns.models[CompletionMessagePayload.__name__])
@web_ns.doc(
responses={
200: "Success",
@ -67,18 +88,10 @@ class CompletionApi(WebApiResource):
if app_model.mode != AppMode.COMPLETION:
raise NotCompletionAppError()
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, location="json")
.add_argument("query", type=str, location="json", default="")
.add_argument("files", type=list, required=False, location="json")
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
.add_argument("retriever_from", type=str, required=False, default="web_app", location="json")
)
payload = CompletionMessagePayload.model_validate(web_ns.payload or {})
args = payload.model_dump(exclude_none=True)
args = parser.parse_args()
streaming = args["response_mode"] == "streaming"
streaming = payload.response_mode == "streaming"
args["auto_generate_name"] = False
try:
@ -142,22 +155,7 @@ class CompletionStopApi(WebApiResource):
class ChatApi(WebApiResource):
@web_ns.doc("Create Chat Message")
@web_ns.doc(description="Create a chat message for conversational applications.")
@web_ns.doc(
params={
"inputs": {"description": "Input variables for the chat", "type": "object", "required": True},
"query": {"description": "User query/message", "type": "string", "required": True},
"files": {"description": "Files to be processed", "type": "array", "required": False},
"response_mode": {
"description": "Response mode: blocking or streaming",
"type": "string",
"enum": ["blocking", "streaming"],
"required": False,
},
"conversation_id": {"description": "Conversation UUID", "type": "string", "required": False},
"parent_message_id": {"description": "Parent message UUID", "type": "string", "required": False},
"retriever_from": {"description": "Source of retriever", "type": "string", "required": False},
}
)
@web_ns.expect(web_ns.models[ChatMessagePayload.__name__])
@web_ns.doc(
responses={
200: "Success",
@ -173,20 +171,10 @@ class ChatApi(WebApiResource):
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, location="json")
.add_argument("query", type=str, required=True, location="json")
.add_argument("files", type=list, required=False, location="json")
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
.add_argument("conversation_id", type=uuid_value, location="json")
.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
.add_argument("retriever_from", type=str, required=False, default="web_app", location="json")
)
payload = ChatMessagePayload.model_validate(web_ns.payload or {})
args = payload.model_dump(exclude_none=True)
args = parser.parse_args()
streaming = args["response_mode"] == "streaming"
streaming = payload.response_mode == "streaming"
args["auto_generate_name"] = False
try:

View File

@ -2,10 +2,12 @@ import base64
import secrets
from flask import request
from flask_restx import Resource, reqparse
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_models
from controllers.console.auth.error import (
AuthenticationFailedError,
EmailCodeError,
@ -18,14 +20,40 @@ from controllers.console.error import EmailSendIpLimitError
from controllers.console.wraps import email_password_login_enabled, only_edition_enterprise, setup_required
from controllers.web import web_ns
from extensions.ext_database import db
from libs.helper import email, extract_remote_ip
from libs.helper import EmailStr, extract_remote_ip
from libs.password import hash_password, valid_password
from models import Account
from services.account_service import AccountService
class ForgotPasswordSendPayload(BaseModel):
email: EmailStr
language: str | None = None
class ForgotPasswordCheckPayload(BaseModel):
email: EmailStr
code: str
token: str = Field(min_length=1)
class ForgotPasswordResetPayload(BaseModel):
token: str = Field(min_length=1)
new_password: str
password_confirm: str
@field_validator("new_password", "password_confirm")
@classmethod
def validate_password(cls, value: str) -> str:
return valid_password(value)
register_schema_models(web_ns, ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload)
@web_ns.route("/forgot-password")
class ForgotPasswordSendEmailApi(Resource):
@web_ns.expect(web_ns.models[ForgotPasswordSendPayload.__name__])
@only_edition_enterprise
@setup_required
@email_password_login_enabled
@ -40,35 +68,31 @@ class ForgotPasswordSendEmailApi(Resource):
}
)
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args()
payload = ForgotPasswordSendPayload.model_validate(web_ns.payload or {})
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
if args["language"] is not None and args["language"] == "zh-Hans":
if payload.language == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
account = session.execute(select(Account).filter_by(email=payload.email)).scalar_one_or_none()
token = None
if account is None:
raise AuthenticationFailedError()
else:
token = AccountService.send_reset_password_email(account=account, email=args["email"], language=language)
token = AccountService.send_reset_password_email(account=account, email=payload.email, language=language)
return {"result": "success", "data": token}
@web_ns.route("/forgot-password/validity")
class ForgotPasswordCheckApi(Resource):
@web_ns.expect(web_ns.models[ForgotPasswordCheckPayload.__name__])
@only_edition_enterprise
@setup_required
@email_password_login_enabled
@ -78,45 +102,40 @@ class ForgotPasswordCheckApi(Resource):
responses={200: "Token is valid", 400: "Bad request - invalid token format", 401: "Invalid or expired token"}
)
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("email", type=str, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
payload = ForgotPasswordCheckPayload.model_validate(web_ns.payload or {})
user_email = args["email"]
user_email = payload.email
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args["email"])
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(payload.email)
if is_forgot_password_error_rate_limit:
raise EmailPasswordResetLimitError()
token_data = AccountService.get_reset_password_data(args["token"])
token_data = AccountService.get_reset_password_data(payload.token)
if token_data is None:
raise InvalidTokenError()
if user_email != token_data.get("email"):
raise InvalidEmailError()
if args["code"] != token_data.get("code"):
AccountService.add_forgot_password_error_rate_limit(args["email"])
if payload.code != token_data.get("code"):
AccountService.add_forgot_password_error_rate_limit(payload.email)
raise EmailCodeError()
# Verified, revoke the first token
AccountService.revoke_reset_password_token(args["token"])
AccountService.revoke_reset_password_token(payload.token)
# Refresh token data by generating a new token
_, new_token = AccountService.generate_reset_password_token(
user_email, code=args["code"], additional_data={"phase": "reset"}
user_email, code=payload.code, additional_data={"phase": "reset"}
)
AccountService.reset_forgot_password_error_rate_limit(args["email"])
AccountService.reset_forgot_password_error_rate_limit(payload.email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
@web_ns.route("/forgot-password/resets")
class ForgotPasswordResetApi(Resource):
@web_ns.expect(web_ns.models[ForgotPasswordResetPayload.__name__])
@only_edition_enterprise
@setup_required
@email_password_login_enabled
@ -131,20 +150,14 @@ class ForgotPasswordResetApi(Resource):
}
)
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("token", type=str, required=True, nullable=False, location="json")
.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
)
args = parser.parse_args()
payload = ForgotPasswordResetPayload.model_validate(web_ns.payload or {})
# Validate passwords match
if args["new_password"] != args["password_confirm"]:
if payload.new_password != payload.password_confirm:
raise PasswordMismatchError()
# Validate token and get reset data
reset_data = AccountService.get_reset_password_data(args["token"])
reset_data = AccountService.get_reset_password_data(payload.token)
if not reset_data:
raise InvalidTokenError()
# Must use token in reset phase
@ -152,11 +165,11 @@ class ForgotPasswordResetApi(Resource):
raise InvalidTokenError()
# Revoke token to prevent reuse
AccountService.revoke_reset_password_token(args["token"])
AccountService.revoke_reset_password_token(payload.token)
# Generate secure salt and hash password
salt = secrets.token_bytes(16)
password_hashed = hash_password(args["new_password"], salt)
password_hashed = hash_password(payload.new_password, salt)
email = reset_data.get("email", "")
@ -170,7 +183,7 @@ class ForgotPasswordResetApi(Resource):
return {"result": "success"}
def _update_existing_account(self, account, password_hashed, salt, session):
def _update_existing_account(self, account: Account, password_hashed, salt, session):
# Update existing account credentials
account.password = base64.b64encode(password_hashed).decode()
account.password_salt = base64.b64encode(salt).decode()

View File

@ -1,9 +1,12 @@
import logging
from typing import Literal
from flask_restx import fields, marshal_with, reqparse
from flask_restx.inputs import int_range
from flask import request
from flask_restx import fields, marshal_with
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.error import (
AppMoreLikeThisDisabledError,
@ -38,6 +41,33 @@ from services.message_service import MessageService
logger = logging.getLogger(__name__)
class MessageListQuery(BaseModel):
conversation_id: str = Field(description="Conversation UUID")
first_id: str | None = Field(default=None, description="First message ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)")
@field_validator("conversation_id", "first_id")
@classmethod
def validate_uuid(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
class MessageFeedbackPayload(BaseModel):
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
content: str | None = Field(default=None, description="Feedback content")
class MessageMoreLikeThisQuery(BaseModel):
response_mode: Literal["blocking", "streaming"] = Field(
description="Response mode",
)
register_schema_models(web_ns, MessageListQuery, MessageFeedbackPayload, MessageMoreLikeThisQuery)
@web_ns.route("/messages")
class MessageListApi(WebApiResource):
message_fields = {
@ -69,7 +99,11 @@ class MessageListApi(WebApiResource):
@web_ns.doc(
params={
"conversation_id": {"description": "Conversation UUID", "type": "string", "required": True},
"first_id": {"description": "First message ID for pagination", "type": "string", "required": False},
"first_id": {
"description": "First message ID for pagination",
"type": "string",
"required": False,
},
"limit": {
"description": "Number of messages to return (1-100)",
"type": "integer",
@ -94,17 +128,12 @@ class MessageListApi(WebApiResource):
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = (
reqparse.RequestParser()
.add_argument("conversation_id", required=True, type=uuid_value, location="args")
.add_argument("first_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
)
args = parser.parse_args()
raw_args = request.args.to_dict()
query = MessageListQuery.model_validate(raw_args)
try:
return MessageService.pagination_by_first_id(
app_model, end_user, args["conversation_id"], args["first_id"], args["limit"]
app_model, end_user, query.conversation_id, query.first_id, query.limit
)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
@ -129,7 +158,7 @@ class MessageFeedbackApi(WebApiResource):
"enum": ["like", "dislike"],
"required": False,
},
"content": {"description": "Feedback content/comment", "type": "string", "required": False},
"content": {"description": "Feedback content", "type": "string", "required": False},
}
)
@web_ns.doc(
@ -146,20 +175,15 @@ class MessageFeedbackApi(WebApiResource):
def post(self, app_model, end_user, message_id):
message_id = str(message_id)
parser = (
reqparse.RequestParser()
.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
.add_argument("content", type=str, location="json", default=None)
)
args = parser.parse_args()
payload = MessageFeedbackPayload.model_validate(web_ns.payload or {})
try:
MessageService.create_feedback(
app_model=app_model,
message_id=message_id,
user=end_user,
rating=args.get("rating"),
content=args.get("content"),
rating=payload.rating,
content=payload.content,
)
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
@ -171,17 +195,7 @@ class MessageFeedbackApi(WebApiResource):
class MessageMoreLikeThisApi(WebApiResource):
@web_ns.doc("Generate More Like This")
@web_ns.doc(description="Generate a new completion similar to an existing message (completion apps only).")
@web_ns.doc(
params={
"message_id": {"description": "Message UUID", "type": "string", "required": True},
"response_mode": {
"description": "Response mode",
"type": "string",
"enum": ["blocking", "streaming"],
"required": True,
},
}
)
@web_ns.expect(web_ns.models[MessageMoreLikeThisQuery.__name__])
@web_ns.doc(
responses={
200: "Success",
@ -198,12 +212,10 @@ class MessageMoreLikeThisApi(WebApiResource):
message_id = str(message_id)
parser = reqparse.RequestParser().add_argument(
"response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args"
)
args = parser.parse_args()
raw_args = request.args.to_dict()
query = MessageMoreLikeThisQuery.model_validate(raw_args)
streaming = args["response_mode"] == "streaming"
streaming = query.response_mode == "streaming"
try:
response = AppGenerateService.generate_more_like_this(

View File

@ -1,7 +1,8 @@
import urllib.parse
import httpx
from flask_restx import marshal_with, reqparse
from flask_restx import marshal_with
from pydantic import BaseModel, Field, HttpUrl
import services
from controllers.common import helpers
@ -10,14 +11,23 @@ from controllers.common.errors import (
RemoteFileUploadError,
UnsupportedFileTypeError,
)
from controllers.web import web_ns
from controllers.web.wraps import WebApiResource
from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
from extensions.ext_database import db
from fields.file_fields import build_file_with_signed_url_model, build_remote_file_info_model
from services.file_service import FileService
from ..common.schema import register_schema_models
from . import web_ns
from .wraps import WebApiResource
class RemoteFileUploadPayload(BaseModel):
url: HttpUrl = Field(description="Remote file URL")
register_schema_models(web_ns, RemoteFileUploadPayload)
@web_ns.route("/remote-files/<path:url>")
class RemoteFileInfoApi(WebApiResource):
@ -97,10 +107,8 @@ class RemoteFileUploadApi(WebApiResource):
FileTooLargeError: File exceeds size limit
UnsupportedFileTypeError: File type not supported
"""
parser = reqparse.RequestParser().add_argument("url", type=str, required=True, help="URL is required")
args = parser.parse_args()
url = args["url"]
payload = RemoteFileUploadPayload.model_validate(web_ns.payload or {})
url = str(payload.url)
try:
resp = ssrf_proxy.head(url=url)

View File

@ -1,3 +1,4 @@
import json
from collections.abc import Sequence
from enum import StrEnum, auto
from typing import Any, Literal
@ -120,7 +121,7 @@ class VariableEntity(BaseModel):
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
json_schema: dict[str, Any] | None = Field(default=None)
json_schema: str | None = Field(default=None)
@field_validator("description", mode="before")
@classmethod
@ -134,11 +135,17 @@ class VariableEntity(BaseModel):
@field_validator("json_schema")
@classmethod
def validate_json_schema(cls, schema: dict[str, Any] | None) -> dict[str, Any] | None:
def validate_json_schema(cls, schema: str | None) -> str | None:
if schema is None:
return None
try:
Draft7Validator.check_schema(schema)
json_schema = json.loads(schema)
except json.JSONDecodeError:
raise ValueError(f"invalid json_schema value {schema}")
try:
Draft7Validator.check_schema(json_schema)
except SchemaError as e:
raise ValueError(f"Invalid JSON schema: {e.message}")
return schema

View File

@ -1,3 +1,4 @@
import json
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Union, final
@ -104,8 +105,9 @@ class BaseAppGenerator:
variable_entity.type in {VariableEntityType.FILE, VariableEntityType.FILE_LIST}
and not variable_entity.required
):
# Treat empty string (frontend default) or empty list as unset
if not value and isinstance(value, (str, list)):
# Treat empty string (frontend default) as unset
# For FILE_LIST, allow empty list [] to pass through
if isinstance(value, str) and not value:
return None
if variable_entity.type in {
@ -175,6 +177,13 @@ class BaseAppGenerator:
value = True
elif value == 0:
value = False
case VariableEntityType.JSON_OBJECT:
if not isinstance(value, str):
raise ValueError(f"{variable_entity.variable} in input form must be a string")
try:
json.loads(value)
except json.JSONDecodeError:
raise ValueError(f"{variable_entity.variable} in input form must be a valid JSON object")
case _:
raise AssertionError("this statement should be unreachable.")

View File

@ -90,6 +90,7 @@ class AppQueueManager:
"""
self._clear_task_belong_cache()
self._q.put(None)
self._graph_runtime_state = None # Release reference to allow GC to reclaim memory
def _clear_task_belong_cache(self) -> None:
"""

View File

@ -345,9 +345,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
self._task_state.llm_result.message.content = current_content
if isinstance(event, QueueLLMChunkEvent):
event_type = self._message_cycle_manager.get_message_event_type(message_id=self._message_id)
yield self._message_cycle_manager.message_to_stream_response(
answer=cast(str, delta_text),
message_id=self._message_id,
event_type=event_type,
)
else:
yield self._agent_message_to_stream_response(

View File

@ -5,7 +5,7 @@ from threading import Thread
from typing import Union
from flask import Flask, current_app
from sqlalchemy import select
from sqlalchemy import exists, select
from sqlalchemy.orm import Session
from configs import dify_config
@ -54,6 +54,20 @@ class MessageCycleManager:
):
self._application_generate_entity = application_generate_entity
self._task_state = task_state
self._message_has_file: set[str] = set()
def get_message_event_type(self, message_id: str) -> StreamEvent:
if message_id in self._message_has_file:
return StreamEvent.MESSAGE_FILE
with Session(db.engine, expire_on_commit=False) as session:
has_file = session.query(exists().where(MessageFile.message_id == message_id)).scalar()
if has_file:
self._message_has_file.add(message_id)
return StreamEvent.MESSAGE_FILE
return StreamEvent.MESSAGE
def generate_conversation_name(self, *, conversation_id: str, query: str) -> Thread | None:
"""
@ -224,6 +238,7 @@ class MessageCycleManager:
tool_arguments: str | None = None,
tool_files: list[str] | None = None,
tool_error: str | None = None,
event_type: StreamEvent | None = None,
) -> MessageStreamResponse:
"""
Message to stream response.
@ -238,22 +253,18 @@ class MessageCycleManager:
:param tool_error: error message if tool failed
:return:
"""
with Session(db.engine, expire_on_commit=False) as session:
message_file = session.scalar(select(MessageFile).where(MessageFile.id == message_id))
event_type = StreamEvent.MESSAGE_FILE if message_file else StreamEvent.MESSAGE
return MessageStreamResponse(
task_id=self._application_generate_entity.task_id,
id=message_id,
answer=answer,
from_variable_selector=from_variable_selector,
event=event_type,
chunk_type=chunk_type,
tool_call_id=tool_call_id,
tool_name=tool_name,
tool_arguments=tool_arguments,
tool_files=tool_files,
tool_error=tool_error,
event=event_type or StreamEvent.MESSAGE,
)
def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:

View File

@ -1,9 +1,14 @@
from collections.abc import Mapping
from textwrap import dedent
from typing import Any
from core.helper.code_executor.template_transformer import TemplateTransformer
class Jinja2TemplateTransformer(TemplateTransformer):
# Use separate placeholder for base64-encoded template to avoid confusion
_template_b64_placeholder: str = "{{template_b64}}"
@classmethod
def transform_response(cls, response: str):
"""
@ -13,18 +18,35 @@ class Jinja2TemplateTransformer(TemplateTransformer):
"""
return {"result": cls.extract_result_str_from_response(response)}
@classmethod
def assemble_runner_script(cls, code: str, inputs: Mapping[str, Any]) -> str:
"""
Override base class to use base64 encoding for template code.
This prevents issues with special characters (quotes, newlines) in templates
breaking the generated Python script. Fixes #26818.
"""
script = cls.get_runner_script()
# Encode template as base64 to safely embed any content including quotes
code_b64 = cls.serialize_code(code)
script = script.replace(cls._template_b64_placeholder, code_b64)
inputs_str = cls.serialize_inputs(inputs)
script = script.replace(cls._inputs_placeholder, inputs_str)
return script
@classmethod
def get_runner_script(cls) -> str:
runner_script = dedent(f"""
# declare main function
def main(**inputs):
import jinja2
template = jinja2.Template('''{cls._code_placeholder}''')
return template.render(**inputs)
import jinja2
import json
from base64 import b64decode
# declare main function
def main(**inputs):
# Decode base64-encoded template to handle special characters safely
template_code = b64decode('{cls._template_b64_placeholder}').decode('utf-8')
template = jinja2.Template(template_code)
return template.render(**inputs)
# decode and prepare input dict
inputs_obj = json.loads(b64decode('{cls._inputs_placeholder}').decode('utf-8'))

View File

@ -13,6 +13,15 @@ class TemplateTransformer(ABC):
_inputs_placeholder: str = "{{inputs}}"
_result_tag: str = "<<RESULT>>"
@classmethod
def serialize_code(cls, code: str) -> str:
"""
Serialize template code to base64 to safely embed in generated script.
This prevents issues with special characters like quotes breaking the script.
"""
code_bytes = code.encode("utf-8")
return b64encode(code_bytes).decode("utf-8")
@classmethod
def transform_caller(cls, code: str, inputs: Mapping[str, Any]) -> tuple[str, str]:
"""

View File

@ -72,6 +72,22 @@ def _get_ssrf_client(ssl_verify_enabled: bool) -> httpx.Client:
)
def _get_user_provided_host_header(headers: dict | None) -> str | None:
"""
Extract the user-provided Host header from the headers dict.
This is needed because when using a forward proxy, httpx may override the Host header.
We preserve the user's explicit Host header to support virtual hosting and other use cases.
"""
if not headers:
return None
# Case-insensitive lookup for Host header
for key, value in headers.items():
if key.lower() == "host":
return value
return None
def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
if "allow_redirects" in kwargs:
allow_redirects = kwargs.pop("allow_redirects")
@ -90,10 +106,24 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
verify_option = kwargs.pop("ssl_verify", dify_config.HTTP_REQUEST_NODE_SSL_VERIFY)
client = _get_ssrf_client(verify_option)
# Preserve user-provided Host header
# When using a forward proxy, httpx may override the Host header based on the URL.
# We extract and preserve any explicitly set Host header to support virtual hosting.
headers = kwargs.get("headers", {})
user_provided_host = _get_user_provided_host_header(headers)
retries = 0
while retries <= max_retries:
try:
# Build the request manually to preserve the Host header
# httpx may override the Host header when using a proxy, so we use
# the request API to explicitly set headers before sending
headers = {k: v for k, v in headers.items() if k.lower() != "host"}
if user_provided_host is not None:
headers["host"] = user_provided_host
kwargs["headers"] = headers
response = client.request(method=method, url=url, **kwargs)
# Check for SSRF protection by Squid proxy
if response.status_code in (401, 403):
# Check if this is a Squid SSRF rejection

View File

@ -1,56 +0,0 @@
import json
import logging
from typing import Any
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
from extensions.ext_redis import redis_client, redis_fallback
logger = logging.getLogger(__name__)
class ToolProviderListCache:
"""Cache for tool provider lists"""
CACHE_TTL = 300 # 5 minutes
@staticmethod
def _generate_cache_key(tenant_id: str, typ: ToolProviderTypeApiLiteral = None) -> str:
"""Generate cache key for tool providers list"""
type_filter = typ or "all"
return f"tool_providers:tenant_id:{tenant_id}:type:{type_filter}"
@staticmethod
@redis_fallback(default_return=None)
def get_cached_providers(tenant_id: str, typ: ToolProviderTypeApiLiteral = None) -> list[dict[str, Any]] | None:
"""Get cached tool providers"""
cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
cached_data = redis_client.get(cache_key)
if cached_data:
try:
return json.loads(cached_data.decode("utf-8"))
except (json.JSONDecodeError, UnicodeDecodeError):
logger.warning("Failed to decode cached tool providers data")
return None
return None
@staticmethod
@redis_fallback()
def set_cached_providers(tenant_id: str, typ: ToolProviderTypeApiLiteral, providers: list[dict[str, Any]]):
"""Cache tool providers"""
cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
redis_client.setex(cache_key, ToolProviderListCache.CACHE_TTL, json.dumps(providers))
@staticmethod
@redis_fallback()
def invalidate_cache(tenant_id: str, typ: ToolProviderTypeApiLiteral = None):
"""Invalidate cache for tool providers"""
if typ:
# Invalidate specific type cache
cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
redis_client.delete(cache_key)
else:
# Invalidate all caches for this tenant
pattern = f"tool_providers:tenant_id:{tenant_id}:*"
keys = list(redis_client.scan_iter(pattern))
if keys:
redis_client.delete(*keys)

View File

@ -396,7 +396,7 @@ class IndexingRunner:
datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate(
{
"credential_id": data_source_info["credential_id"],
"credential_id": data_source_info.get("credential_id"),
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],

View File

@ -47,7 +47,11 @@ def build_protected_resource_metadata_discovery_urls(
"""
Build a list of URLs to try for Protected Resource Metadata discovery.
Per SEP-985, supports fallback when discovery fails at one URL.
Per RFC 9728 Section 5.1, supports fallback when discovery fails at one URL.
Priority order:
1. URL from WWW-Authenticate header (if provided)
2. Well-known URI with path: https://example.com/.well-known/oauth-protected-resource/public/mcp
3. Well-known URI at root: https://example.com/.well-known/oauth-protected-resource
"""
urls = []
@ -58,9 +62,18 @@ def build_protected_resource_metadata_discovery_urls(
# Fallback: construct from server URL
parsed = urlparse(server_url)
base_url = f"{parsed.scheme}://{parsed.netloc}"
fallback_url = urljoin(base_url, "/.well-known/oauth-protected-resource")
if fallback_url not in urls:
urls.append(fallback_url)
path = parsed.path.rstrip("/")
# Priority 2: With path insertion (e.g., /.well-known/oauth-protected-resource/public/mcp)
if path:
path_url = f"{base_url}/.well-known/oauth-protected-resource{path}"
if path_url not in urls:
urls.append(path_url)
# Priority 3: At root (e.g., /.well-known/oauth-protected-resource)
root_url = f"{base_url}/.well-known/oauth-protected-resource"
if root_url not in urls:
urls.append(root_url)
return urls
@ -71,30 +84,34 @@ def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: st
Supports both OAuth 2.0 (RFC 8414) and OpenID Connect discovery.
Per RFC 8414 section 3:
- If issuer has no path: https://example.com/.well-known/oauth-authorization-server
- If issuer has path: https://example.com/.well-known/oauth-authorization-server{path}
Example:
- issuer: https://example.com/oauth
- metadata: https://example.com/.well-known/oauth-authorization-server/oauth
Per RFC 8414 section 3.1 and section 5, try all possible endpoints:
- OAuth 2.0 with path insertion: https://example.com/.well-known/oauth-authorization-server/tenant1
- OpenID Connect with path insertion: https://example.com/.well-known/openid-configuration/tenant1
- OpenID Connect path appending: https://example.com/tenant1/.well-known/openid-configuration
- OAuth 2.0 at root: https://example.com/.well-known/oauth-authorization-server
- OpenID Connect at root: https://example.com/.well-known/openid-configuration
"""
urls = []
base_url = auth_server_url or server_url
parsed = urlparse(base_url)
base = f"{parsed.scheme}://{parsed.netloc}"
path = parsed.path.rstrip("/") # Remove trailing slash
path = parsed.path.rstrip("/")
# OAuth 2.0 Authorization Server Metadata at root (MCP-03-26)
urls.append(f"{base}/.well-known/oauth-authorization-server")
# Try OpenID Connect discovery first (more common)
urls.append(urljoin(base + "/", ".well-known/openid-configuration"))
# OpenID Connect Discovery at root
urls.append(f"{base}/.well-known/openid-configuration")
# OAuth 2.0 Authorization Server Metadata (RFC 8414)
# Include the path component if present in the issuer URL
if path:
urls.append(urljoin(base, f".well-known/oauth-authorization-server{path}"))
else:
urls.append(urljoin(base, ".well-known/oauth-authorization-server"))
# OpenID Connect Discovery with path insertion
urls.append(f"{base}/.well-known/openid-configuration{path}")
# OpenID Connect Discovery path appending
urls.append(f"{base}{path}/.well-known/openid-configuration")
# OAuth 2.0 Authorization Server Metadata with path insertion
urls.append(f"{base}/.well-known/oauth-authorization-server{path}")
return urls

View File

@ -61,6 +61,7 @@ class SSETransport:
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
self.endpoint_url: str | None = None
self.event_source: EventSource | None = None
def _validate_endpoint_url(self, endpoint_url: str) -> bool:
"""Validate that the endpoint URL matches the connection origin.
@ -237,6 +238,9 @@ class SSETransport:
write_queue: WriteQueue = queue.Queue()
status_queue: StatusQueue = queue.Queue()
# Store event_source for graceful shutdown
self.event_source = event_source
# Start SSE reader thread
executor.submit(self.sse_reader, event_source, read_queue, status_queue)
@ -296,6 +300,13 @@ def sse_client(
logger.exception("Error connecting to SSE endpoint")
raise
finally:
# Close the SSE connection to unblock the reader thread
if transport.event_source is not None:
try:
transport.event_source.response.close()
except RuntimeError:
pass
# Clean up queues
if read_queue:
read_queue.put(None)

View File

@ -8,6 +8,7 @@ and session management.
import logging
import queue
import threading
from collections.abc import Callable, Generator
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
@ -103,6 +104,9 @@ class StreamableHTTPTransport:
CONTENT_TYPE: JSON,
**self.headers,
}
self.stop_event = threading.Event()
self._active_responses: list[httpx.Response] = []
self._lock = threading.Lock()
def _update_headers_with_session(self, base_headers: dict[str, str]) -> dict[str, str]:
"""Update headers with session ID if available."""
@ -111,6 +115,30 @@ class StreamableHTTPTransport:
headers[MCP_SESSION_ID] = self.session_id
return headers
def _register_response(self, response: httpx.Response):
"""Register a response for cleanup on shutdown."""
with self._lock:
self._active_responses.append(response)
def _unregister_response(self, response: httpx.Response):
"""Unregister a response after it's closed."""
with self._lock:
try:
self._active_responses.remove(response)
except ValueError as e:
logger.debug("Ignoring error during response unregister: %s", e)
def close_active_responses(self):
"""Close all active SSE connections to unblock threads."""
with self._lock:
responses_to_close = list(self._active_responses)
self._active_responses.clear()
for response in responses_to_close:
try:
response.close()
except RuntimeError as e:
logger.debug("Ignoring error during active response close: %s", e)
def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
"""Check if the message is an initialization request."""
return isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize"
@ -195,11 +223,21 @@ class StreamableHTTPTransport:
event_source.response.raise_for_status()
logger.debug("GET SSE connection established")
for sse in event_source.iter_sse():
self._handle_sse_event(sse, server_to_client_queue)
# Register response for cleanup
self._register_response(event_source.response)
try:
for sse in event_source.iter_sse():
if self.stop_event.is_set():
logger.debug("GET stream received stop signal")
break
self._handle_sse_event(sse, server_to_client_queue)
finally:
self._unregister_response(event_source.response)
except Exception as exc:
logger.debug("GET stream error (non-fatal): %s", exc)
if not self.stop_event.is_set():
logger.debug("GET stream error (non-fatal): %s", exc)
def _handle_resumption_request(self, ctx: RequestContext):
"""Handle a resumption request using GET with SSE."""
@ -224,15 +262,24 @@ class StreamableHTTPTransport:
event_source.response.raise_for_status()
logger.debug("Resumption GET SSE connection established")
for sse in event_source.iter_sse():
is_complete = self._handle_sse_event(
sse,
ctx.server_to_client_queue,
original_request_id,
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
)
if is_complete:
break
# Register response for cleanup
self._register_response(event_source.response)
try:
for sse in event_source.iter_sse():
if self.stop_event.is_set():
logger.debug("Resumption stream received stop signal")
break
is_complete = self._handle_sse_event(
sse,
ctx.server_to_client_queue,
original_request_id,
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
)
if is_complete:
break
finally:
self._unregister_response(event_source.response)
def _handle_post_request(self, ctx: RequestContext):
"""Handle a POST request with response processing."""
@ -266,17 +313,20 @@ class StreamableHTTPTransport:
if is_initialization:
self._maybe_extract_session_id_from_response(response)
content_type = cast(str, response.headers.get(CONTENT_TYPE, "").lower())
# Per https://modelcontextprotocol.io/specification/2025-06-18/basic#notifications:
# The server MUST NOT send a response to notifications.
if isinstance(message.root, JSONRPCRequest):
content_type = cast(str, response.headers.get(CONTENT_TYPE, "").lower())
if content_type.startswith(JSON):
self._handle_json_response(response, ctx.server_to_client_queue)
elif content_type.startswith(SSE):
self._handle_sse_response(response, ctx)
else:
self._handle_unexpected_content_type(
content_type,
ctx.server_to_client_queue,
)
if content_type.startswith(JSON):
self._handle_json_response(response, ctx.server_to_client_queue)
elif content_type.startswith(SSE):
self._handle_sse_response(response, ctx)
else:
self._handle_unexpected_content_type(
content_type,
ctx.server_to_client_queue,
)
def _handle_json_response(
self,
@ -295,17 +345,27 @@ class StreamableHTTPTransport:
def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext):
"""Handle SSE response from the server."""
try:
# Register response for cleanup
self._register_response(response)
event_source = EventSource(response)
for sse in event_source.iter_sse():
is_complete = self._handle_sse_event(
sse,
ctx.server_to_client_queue,
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
)
if is_complete:
break
try:
for sse in event_source.iter_sse():
if self.stop_event.is_set():
logger.debug("SSE response stream received stop signal")
break
is_complete = self._handle_sse_event(
sse,
ctx.server_to_client_queue,
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
)
if is_complete:
break
finally:
self._unregister_response(response)
except Exception as e:
ctx.server_to_client_queue.put(e)
if not self.stop_event.is_set():
ctx.server_to_client_queue.put(e)
def _handle_unexpected_content_type(
self,
@ -345,6 +405,11 @@ class StreamableHTTPTransport:
"""
while True:
try:
# Check if we should stop
if self.stop_event.is_set():
logger.debug("Post writer received stop signal")
break
# Read message from client queue with timeout to check stop_event periodically
session_message = client_to_server_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
if session_message is None:
@ -381,7 +446,8 @@ class StreamableHTTPTransport:
except queue.Empty:
continue
except Exception as exc:
server_to_client_queue.put(exc)
if not self.stop_event.is_set():
server_to_client_queue.put(exc)
def terminate_session(self, client: httpx.Client):
"""Terminate the session by sending a DELETE request."""
@ -465,6 +531,12 @@ def streamablehttp_client(
transport.get_session_id,
)
finally:
# Set stop event to signal all threads to stop
transport.stop_event.set()
# Close all active SSE connections to unblock threads
transport.close_active_responses()
if transport.session_id and terminate_on_close:
transport.terminate_session(client)

View File

@ -59,7 +59,7 @@ class MCPClient:
try:
logger.debug("Not supported method %s found in URL path, trying default 'mcp' method.", method_name)
self.connect_server(sse_client, "sse")
except MCPConnectionError:
except (MCPConnectionError, ValueError):
logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.")
self.connect_server(streamablehttp_client, "mcp")

View File

@ -18,34 +18,20 @@ This module provides the interface for invoking and authenticating various model
- Model provider display
![image-20231210143654461](./docs/en_US/images/index/image-20231210143654461.png)
Displays a list of all supported providers, including provider names, icons, supported model types list, predefined model list, configuration method, and credentials form rules, etc. For detailed rule design, see: [Schema](./docs/en_US/schema.md).
Displays a list of all supported providers, including provider names, icons, supported model types list, predefined model list, configuration method, and credentials form rules, etc.
- Selectable model list display
![image-20231210144229650](./docs/en_US/images/index/image-20231210144229650.png)
After configuring provider/model credentials, the dropdown (application orchestration interface/default model) allows viewing of the available LLM list. Greyed out items represent predefined model lists from providers without configured credentials, facilitating user review of supported models.
In addition, this list also returns configurable parameter information and rules for LLM, as shown below:
![image-20231210144814617](./docs/en_US/images/index/image-20231210144814617.png)
These parameters are all defined in the backend, allowing different settings for various parameters supported by different models, as detailed in: [Schema](./docs/en_US/schema.md#ParameterRule).
In addition, this list also returns configurable parameter information and rules for LLM. These parameters are all defined in the backend, allowing different settings for various parameters supported by different models.
- Provider/model credential authentication
![image-20231210151548521](./docs/en_US/images/index/image-20231210151548521.png)
![image-20231210151628992](./docs/en_US/images/index/image-20231210151628992.png)
The provider list returns configuration information for the credentials form, which can be authenticated through Runtime's interface. The first image above is a provider credential DEMO, and the second is a model credential DEMO.
The provider list returns configuration information for the credentials form, which can be authenticated through Runtime's interface.
## Structure
![](./docs/en_US/images/index/image-20231210165243632.png)
Model Runtime is divided into three layers:
- The outermost layer is the factory method
@ -60,9 +46,6 @@ Model Runtime is divided into three layers:
It offers direct invocation of various model types, predefined model configuration information, getting predefined/remote model lists, model credential authentication methods. Different models provide additional special methods, like LLM's pre-computed tokens method, cost information obtaining method, etc., **allowing horizontal expansion** for different models under the same provider (within supported model types).
## Next Steps
## Documentation
- Add new provider configuration: [Link](./docs/en_US/provider_scale_out.md)
- Add new models for existing providers: [Link](./docs/en_US/provider_scale_out.md#AddModel)
- View YAML configuration rules: [Link](./docs/en_US/schema.md)
- Implement interface methods: [Link](./docs/en_US/interfaces.md)
For detailed documentation on how to add new providers or models, please refer to the [Dify documentation](https://docs.dify.ai/).

View File

@ -18,34 +18,20 @@
- 模型供应商展示
![image-20231210143654461](./docs/zh_Hans/images/index/image-20231210143654461.png)
展示所有已支持的供应商列表,除了返回供应商名称、图标之外,还提供了支持的模型类型列表,预定义模型列表、配置方式以及配置凭据的表单规则等等,规则设计详见:[Schema](./docs/zh_Hans/schema.md)。
展示所有已支持的供应商列表,除了返回供应商名称、图标之外,还提供了支持的模型类型列表,预定义模型列表、配置方式以及配置凭据的表单规则等等。
- 可选择的模型列表展示
![image-20231210144229650](./docs/zh_Hans/images/index/image-20231210144229650.png)
配置供应商/模型凭据后,可在此下拉(应用编排界面/默认模型)查看可用的 LLM 列表,其中灰色的为未配置凭据供应商的预定义模型列表,方便用户查看已支持的模型。
配置供应商/模型凭据后,可在此下拉(应用编排界面/默认模型)查看可用的 LLM 列表,其中灰色的为未配置凭据供应商的预定义模型列表,方便用户查看已支持的模型。
除此之外,该列表还返回了 LLM 可配置的参数信息和规则,如下图:
![image-20231210144814617](./docs/zh_Hans/images/index/image-20231210144814617.png)
这里的参数均为后端定义,相比之前只有 5 种固定参数,这里可为不同模型设置所支持的各种参数,详见:[Schema](./docs/zh_Hans/schema.md#ParameterRule)。
除此之外,该列表还返回了 LLM 可配置的参数信息和规则。这里的参数均为后端定义,相比之前只有 5 种固定参数,这里可为不同模型设置所支持的各种参数。
- 供应商/模型凭据鉴权
![image-20231210151548521](./docs/zh_Hans/images/index/image-20231210151548521.png)
![image-20231210151628992](./docs/zh_Hans/images/index/image-20231210151628992.png)
供应商列表返回了凭据表单的配置信息,可通过 Runtime 提供的接口对凭据进行鉴权,上图 1 为供应商凭据 DEMO上图 2 为模型凭据 DEMO。
供应商列表返回了凭据表单的配置信息,可通过 Runtime 提供的接口对凭据进行鉴权。
## 结构
![](./docs/zh_Hans/images/index/image-20231210165243632.png)
Model Runtime 分三层:
- 最外层为工厂方法
@ -59,8 +45,7 @@ Model Runtime 分三层:
对于供应商/模型凭据,有两种情况
- 如 OpenAI 这类中心化供应商,需要定义如**api_key**这类的鉴权凭据
- 如[**Xinference**](https://github.com/xorbitsai/inference)这类本地部署的供应商,需要定义如**server_url**这类的地址凭据,有时候还需要定义**model_uid**之类的模型类型凭据,就像下面这样,当在供应商层定义了这些凭据后,就可以在前端页面上直接展示,无需修改前端逻辑。
![Alt text](docs/zh_Hans/images/index/image.png)
- 如[**Xinference**](https://github.com/xorbitsai/inference)这类本地部署的供应商,需要定义如**server_url**这类的地址凭据,有时候还需要定义**model_uid**之类的模型类型凭据。当在供应商层定义了这些凭据后,就可以在前端页面上直接展示,无需修改前端逻辑。
当配置好凭据后,就可以通过 DifyRuntime 的外部接口直接获取到对应供应商所需要的**Schema**(凭据表单规则),从而在可以在不修改前端逻辑的情况下,提供新的供应商/模型的支持。
@ -74,20 +59,6 @@ Model Runtime 分三层:
- 模型凭据 (**在供应商层定义**):这是一类不经常变动,一般在配置好后就不会再变动的参数,如 **api_key**、**server_url** 等。在 DifyRuntime 中,他们的参数名一般为**credentials: dict[str, any]**Provider 层的 credentials 会直接被传递到这一层,不需要再单独定义。
## 下一步
## 文档
### [增加新的供应商配置 👈🏻](./docs/zh_Hans/provider_scale_out.md)
当添加后,这里将会出现一个新的供应商
![Alt text](docs/zh_Hans/images/index/image-1.png)
### [为已存在的供应商新增模型 👈🏻](./docs/zh_Hans/provider_scale_out.md#%E5%A2%9E%E5%8A%A0%E6%A8%A1%E5%9E%8B)
当添加后,对应供应商的模型列表中将会出现一个新的预定义模型供用户选择,如 GPT-3.5 GPT-4 ChatGLM3-6b 等,而对于支持自定义模型的供应商,则不需要新增模型。
![Alt text](docs/zh_Hans/images/index/image-2.png)
### [接口的具体实现 👈🏻](./docs/zh_Hans/interfaces.md)
你可以在这里找到你想要查看的接口的具体实现,以及接口的参数和返回值的具体含义。
有关如何添加新供应商或模型的详细文档,请参阅 [Dify 文档](https://docs.dify.ai/)。

View File

@ -54,7 +54,7 @@ def generate_dotted_order(run_id: str, start_time: Union[str, datetime], parent_
generate dotted_order for langsmith
"""
start_time = datetime.fromisoformat(start_time) if isinstance(start_time, str) else start_time
timestamp = start_time.strftime("%Y%m%dT%H%M%S%f")[:-3] + "Z"
timestamp = start_time.strftime("%Y%m%dT%H%M%S%f") + "Z"
current_segment = f"{timestamp}{run_id}"
if parent_dotted_order is None:

View File

@ -76,7 +76,7 @@ class PluginParameter(BaseModel):
auto_generate: PluginParameterAutoGenerate | None = None
template: PluginParameterTemplate | None = None
required: bool = False
default: Union[float, int, str, bool] | None = None
default: Union[float, int, str, bool, list, dict] | None = None
min: Union[float, int] | None = None
max: Union[float, int] | None = None
precision: int | None = None

View File

@ -39,7 +39,7 @@ from core.trigger.errors import (
plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL))
_plugin_daemon_timeout_config = cast(
float | httpx.Timeout | None,
getattr(dify_config, "PLUGIN_DAEMON_TIMEOUT", 300.0),
getattr(dify_config, "PLUGIN_DAEMON_TIMEOUT", 600.0),
)
plugin_daemon_request_timeout: httpx.Timeout | None
if _plugin_daemon_timeout_config is None:

View File

@ -90,13 +90,17 @@ class Jieba(BaseKeyword):
sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table or {}, query, k)
documents = []
segment_query_stmt = db.session.query(DocumentSegment).where(
DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id.in_(sorted_chunk_indices)
)
if document_ids_filter:
segment_query_stmt = segment_query_stmt.where(DocumentSegment.document_id.in_(document_ids_filter))
segments = db.session.execute(segment_query_stmt).scalars().all()
segment_map = {segment.index_node_id: segment for segment in segments}
for chunk_index in sorted_chunk_indices:
segment_query = db.session.query(DocumentSegment).where(
DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index
)
if document_ids_filter:
segment_query = segment_query.where(DocumentSegment.document_id.in_(document_ids_filter))
segment = segment_query.first()
segment = segment_map.get(chunk_index)
if segment:
documents.append(

View File

@ -1,4 +1,5 @@
import concurrent.futures
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Any
@ -7,12 +8,13 @@ from sqlalchemy import select
from sqlalchemy.orm import Session, load_only
from configs import dify_config
from core.db.session_factory import session_factory
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.embedding.retrieval import RetrievalSegments
from core.rag.embedding.retrieval import RetrievalChildChunk, RetrievalSegments
from core.rag.entities.metadata_entities import MetadataCondition
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
@ -35,6 +37,8 @@ default_retrieval_model = {
"score_threshold_enabled": False,
}
logger = logging.getLogger(__name__)
class RetrievalService:
# Cache precompiled regular expressions to avoid repeated compilation
@ -105,7 +109,12 @@ class RetrievalService:
)
)
concurrent.futures.wait(futures, timeout=3600, return_when=concurrent.futures.ALL_COMPLETED)
if futures:
for future in concurrent.futures.as_completed(futures, timeout=3600):
if exceptions:
for f in futures:
f.cancel()
break
if exceptions:
raise ValueError(";\n".join(exceptions))
@ -138,37 +147,47 @@ class RetrievalService:
@classmethod
def _deduplicate_documents(cls, documents: list[Document]) -> list[Document]:
"""Deduplicate documents based on doc_id to avoid duplicate chunks in hybrid search."""
"""Deduplicate documents in O(n) while preserving first-seen order.
Rules:
- For provider == "dify" and metadata["doc_id"] exists: keep the doc with the highest
metadata["score"] among duplicates; if a later duplicate has no score, ignore it.
- For non-dify documents (or dify without doc_id): deduplicate by content key
(provider, page_content), keeping the first occurrence.
"""
if not documents:
return documents
unique_documents = []
seen_doc_ids = set()
# Map of dedup key -> chosen Document
chosen: dict[tuple, Document] = {}
# Preserve the order of first appearance of each dedup key
order: list[tuple] = []
for document in documents:
# For dify provider documents, use doc_id for deduplication
if document.provider == "dify" and document.metadata is not None and "doc_id" in document.metadata:
doc_id = document.metadata["doc_id"]
if doc_id not in seen_doc_ids:
seen_doc_ids.add(doc_id)
unique_documents.append(document)
# If duplicate, keep the one with higher score
elif "score" in document.metadata:
# Find existing document with same doc_id and compare scores
for i, existing_doc in enumerate(unique_documents):
if (
existing_doc.metadata
and existing_doc.metadata.get("doc_id") == doc_id
and existing_doc.metadata.get("score", 0) < document.metadata.get("score", 0)
):
unique_documents[i] = document
break
for doc in documents:
is_dify = doc.provider == "dify"
doc_id = (doc.metadata or {}).get("doc_id") if is_dify else None
if is_dify and doc_id:
key = ("dify", doc_id)
if key not in chosen:
chosen[key] = doc
order.append(key)
else:
# Only replace if the new one has a score and it's strictly higher
if "score" in doc.metadata:
new_score = float(doc.metadata.get("score", 0.0))
old_score = float(chosen[key].metadata.get("score", 0.0)) if chosen[key].metadata else 0.0
if new_score > old_score:
chosen[key] = doc
else:
# For non-dify documents, use content-based deduplication
if document not in unique_documents:
unique_documents.append(document)
# Content-based dedup for non-dify or dify without doc_id
content_key = (doc.provider or "dify", doc.page_content)
if content_key not in chosen:
chosen[content_key] = doc
order.append(content_key)
# If duplicate content appears, we keep the first occurrence (no score comparison)
return unique_documents
return [chosen[k] for k in order]
@classmethod
def _get_dataset(cls, dataset_id: str) -> Dataset | None:
@ -199,6 +218,7 @@ class RetrievalService:
)
all_documents.extend(documents)
except Exception as e:
logger.error(e, exc_info=True)
exceptions.append(str(e))
@classmethod
@ -292,6 +312,7 @@ class RetrievalService:
else:
all_documents.extend(documents)
except Exception as e:
logger.error(e, exc_info=True)
exceptions.append(str(e))
@classmethod
@ -340,6 +361,7 @@ class RetrievalService:
else:
all_documents.extend(documents)
except Exception as e:
logger.error(e, exc_info=True)
exceptions.append(str(e))
@staticmethod
@ -370,171 +392,176 @@ class RetrievalService:
records = []
include_segment_ids = set()
segment_child_map = {}
segment_file_map = {}
with Session(bind=db.engine, expire_on_commit=False) as session:
# Process documents
for document in documents:
segment_id = None
attachment_info = None
child_chunk = None
document_id = document.metadata.get("document_id")
if document_id not in dataset_documents:
continue
dataset_document = dataset_documents[document_id]
if not dataset_document:
continue
valid_dataset_documents = {}
image_doc_ids: list[Any] = []
child_index_node_ids = []
index_node_ids = []
doc_to_document_map = {}
for document in documents:
document_id = document.metadata.get("document_id")
if document_id not in dataset_documents:
continue
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
# Handle parent-child documents
if document.metadata.get("doc_type") == DocType.IMAGE:
attachment_info_dict = cls.get_segment_attachment_info(
dataset_document.dataset_id,
dataset_document.tenant_id,
document.metadata.get("doc_id") or "",
session,
)
if attachment_info_dict:
attachment_info = attachment_info_dict["attachment_info"]
segment_id = attachment_info_dict["segment_id"]
else:
child_index_node_id = document.metadata.get("doc_id")
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id)
child_chunk = session.scalar(child_chunk_stmt)
dataset_document = dataset_documents[document_id]
if not dataset_document:
continue
valid_dataset_documents[document_id] = dataset_document
if not child_chunk:
continue
segment_id = child_chunk.segment_id
if not segment_id:
continue
segment = (
session.query(DocumentSegment)
.where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.id == segment_id,
)
.first()
)
if not segment:
continue
if segment.id not in include_segment_ids:
include_segment_ids.add(segment.id)
if child_chunk:
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0),
}
map_detail = {
"max_score": document.metadata.get("score", 0.0),
"child_chunks": [child_chunk_detail],
}
segment_child_map[segment.id] = map_detail
record = {
"segment": segment,
}
if attachment_info:
segment_file_map[segment.id] = [attachment_info]
records.append(record)
else:
if child_chunk:
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0),
}
if segment.id in segment_child_map:
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
segment_child_map[segment.id]["max_score"] = max(
segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
)
else:
segment_child_map[segment.id] = {
"max_score": document.metadata.get("score", 0.0),
"child_chunks": [child_chunk_detail],
}
if attachment_info:
if segment.id in segment_file_map:
segment_file_map[segment.id].append(attachment_info)
else:
segment_file_map[segment.id] = [attachment_info]
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
doc_id = document.metadata.get("doc_id") or ""
doc_to_document_map[doc_id] = document
if document.metadata.get("doc_type") == DocType.IMAGE:
image_doc_ids.append(doc_id)
else:
# Handle normal documents
segment = None
if document.metadata.get("doc_type") == DocType.IMAGE:
attachment_info_dict = cls.get_segment_attachment_info(
dataset_document.dataset_id,
dataset_document.tenant_id,
document.metadata.get("doc_id") or "",
session,
)
if attachment_info_dict:
attachment_info = attachment_info_dict["attachment_info"]
segment_id = attachment_info_dict["segment_id"]
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.id == segment_id,
)
segment = session.scalar(document_segment_stmt)
if segment:
segment_file_map[segment.id] = [attachment_info]
else:
index_node_id = document.metadata.get("doc_id")
if not index_node_id:
continue
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
)
segment = session.scalar(document_segment_stmt)
child_index_node_ids.append(doc_id)
else:
doc_id = document.metadata.get("doc_id") or ""
doc_to_document_map[doc_id] = document
if document.metadata.get("doc_type") == DocType.IMAGE:
image_doc_ids.append(doc_id)
else:
index_node_ids.append(doc_id)
if not segment:
continue
if segment.id not in include_segment_ids:
include_segment_ids.add(segment.id)
record = {
"segment": segment,
"score": document.metadata.get("score"), # type: ignore
image_doc_ids = [i for i in image_doc_ids if i]
child_index_node_ids = [i for i in child_index_node_ids if i]
index_node_ids = [i for i in index_node_ids if i]
segment_ids: list[str] = []
index_node_segments: list[DocumentSegment] = []
segments: list[DocumentSegment] = []
attachment_map: dict[str, list[dict[str, Any]]] = {}
child_chunk_map: dict[str, list[ChildChunk]] = {}
doc_segment_map: dict[str, list[str]] = {}
with session_factory.create_session() as session:
attachments = cls.get_segment_attachment_infos(image_doc_ids, session)
for attachment in attachments:
segment_ids.append(attachment["segment_id"])
if attachment["segment_id"] in attachment_map:
attachment_map[attachment["segment_id"]].append(attachment["attachment_info"])
else:
attachment_map[attachment["segment_id"]] = [attachment["attachment_info"]]
if attachment["segment_id"] in doc_segment_map:
doc_segment_map[attachment["segment_id"]].append(attachment["attachment_id"])
else:
doc_segment_map[attachment["segment_id"]] = [attachment["attachment_id"]]
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id.in_(child_index_node_ids))
child_index_nodes = session.execute(child_chunk_stmt).scalars().all()
for i in child_index_nodes:
segment_ids.append(i.segment_id)
if i.segment_id in child_chunk_map:
child_chunk_map[i.segment_id].append(i)
else:
child_chunk_map[i.segment_id] = [i]
if i.segment_id in doc_segment_map:
doc_segment_map[i.segment_id].append(i.index_node_id)
else:
doc_segment_map[i.segment_id] = [i.index_node_id]
if index_node_ids:
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id.in_(index_node_ids),
)
index_node_segments = session.execute(document_segment_stmt).scalars().all() # type: ignore
for index_node_segment in index_node_segments:
doc_segment_map[index_node_segment.id] = [index_node_segment.index_node_id]
if segment_ids:
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.id.in_(segment_ids),
)
segments = session.execute(document_segment_stmt).scalars().all() # type: ignore
if index_node_segments:
segments.extend(index_node_segments)
for segment in segments:
child_chunks: list[ChildChunk] = child_chunk_map.get(segment.id, [])
attachment_infos: list[dict[str, Any]] = attachment_map.get(segment.id, [])
ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get(segment.document_id)
if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
if segment.id not in include_segment_ids:
include_segment_ids.add(segment.id)
if child_chunks or attachment_infos:
child_chunk_details = []
max_score = 0.0
for child_chunk in child_chunks:
document = doc_to_document_map[child_chunk.index_node_id]
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0) if document else 0.0,
}
child_chunk_details.append(child_chunk_detail)
max_score = max(max_score, document.metadata.get("score", 0.0) if document else 0.0)
for attachment_info in attachment_infos:
file_document = doc_to_document_map[attachment_info["id"]]
max_score = max(
max_score, file_document.metadata.get("score", 0.0) if file_document else 0.0
)
map_detail = {
"max_score": max_score,
"child_chunks": child_chunk_details,
}
if attachment_info:
segment_file_map[segment.id] = [attachment_info]
records.append(record)
else:
if attachment_info:
attachment_infos = segment_file_map.get(segment.id, [])
if attachment_info not in attachment_infos:
attachment_infos.append(attachment_info)
segment_file_map[segment.id] = attachment_infos
segment_child_map[segment.id] = map_detail
record: dict[str, Any] = {
"segment": segment,
}
records.append(record)
else:
if segment.id not in include_segment_ids:
include_segment_ids.add(segment.id)
max_score = 0.0
segment_document = doc_to_document_map.get(segment.index_node_id)
if segment_document:
max_score = max(max_score, segment_document.metadata.get("score", 0.0))
for attachment_info in attachment_infos:
file_doc = doc_to_document_map.get(attachment_info["id"])
if file_doc:
max_score = max(max_score, file_doc.metadata.get("score", 0.0))
record = {
"segment": segment,
"score": max_score,
}
records.append(record)
# Add child chunks information to records
for record in records:
if record["segment"].id in segment_child_map:
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore
record["score"] = segment_child_map[record["segment"].id]["max_score"]
if record["segment"].id in segment_file_map:
record["files"] = segment_file_map[record["segment"].id] # type: ignore[assignment]
record["score"] = segment_child_map[record["segment"].id]["max_score"] # type: ignore
if record["segment"].id in attachment_map:
record["files"] = attachment_map[record["segment"].id] # type: ignore[assignment]
result = []
result: list[RetrievalSegments] = []
for record in records:
# Extract segment
segment = record["segment"]
# Extract child_chunks, ensuring it's a list or None
child_chunks = record.get("child_chunks")
if not isinstance(child_chunks, list):
child_chunks = None
raw_child_chunks = record.get("child_chunks")
child_chunks_list: list[RetrievalChildChunk] | None = None
if isinstance(raw_child_chunks, list):
# Sort by score descending
sorted_chunks = sorted(raw_child_chunks, key=lambda x: x.get("score", 0.0), reverse=True)
child_chunks_list = [
RetrievalChildChunk(
id=chunk["id"],
content=chunk["content"],
score=chunk.get("score", 0.0),
position=chunk["position"],
)
for chunk in sorted_chunks
]
# Extract files, ensuring it's a list or None
files = record.get("files")
@ -551,11 +578,11 @@ class RetrievalService:
# Create RetrievalSegments object
retrieval_segment = RetrievalSegments(
segment=segment, child_chunks=child_chunks, score=score, files=files
segment=segment, child_chunks=child_chunks_list, score=score, files=files
)
result.append(retrieval_segment)
return result
return sorted(result, key=lambda x: x.score if x.score is not None else 0.0, reverse=True)
except Exception as e:
db.session.rollback()
raise e
@ -565,6 +592,8 @@ class RetrievalService:
flask_app: Flask,
retrieval_method: RetrievalMethod,
dataset: Dataset,
all_documents: list[Document],
exceptions: list[str],
query: str | None = None,
top_k: int = 4,
score_threshold: float | None = 0.0,
@ -573,8 +602,6 @@ class RetrievalService:
weights: dict | None = None,
document_ids_filter: list[str] | None = None,
attachment_id: str | None = None,
all_documents: list[Document] = [],
exceptions: list[str] = [],
):
if not query and not attachment_id:
return
@ -647,7 +674,14 @@ class RetrievalService:
document_ids_filter=document_ids_filter,
)
)
concurrent.futures.wait(futures, timeout=300, return_when=concurrent.futures.ALL_COMPLETED)
# Use as_completed for early error propagation - cancel remaining futures on first error
if futures:
for future in concurrent.futures.as_completed(futures, timeout=300):
if future.exception():
# Cancel remaining futures to avoid unnecessary waiting
for f in futures:
f.cancel()
break
if exceptions:
raise ValueError(";\n".join(exceptions))
@ -696,3 +730,37 @@ class RetrievalService:
}
return {"attachment_info": attachment_info, "segment_id": attachment_binding.segment_id}
return None
@classmethod
def get_segment_attachment_infos(cls, attachment_ids: list[str], session: Session) -> list[dict[str, Any]]:
attachment_infos = []
upload_files = session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all()
if upload_files:
upload_file_ids = [upload_file.id for upload_file in upload_files]
attachment_bindings = (
session.query(SegmentAttachmentBinding)
.where(SegmentAttachmentBinding.attachment_id.in_(upload_file_ids))
.all()
)
attachment_binding_map = {binding.attachment_id: binding for binding in attachment_bindings}
if attachment_bindings:
for upload_file in upload_files:
attachment_binding = attachment_binding_map.get(upload_file.id)
attachment_info = {
"id": upload_file.id,
"name": upload_file.name,
"extension": "." + upload_file.extension,
"mime_type": upload_file.mime_type,
"source_url": sign_upload_file(upload_file.id, upload_file.extension),
"size": upload_file.size,
}
if attachment_binding:
attachment_infos.append(
{
"attachment_id": attachment_binding.attachment_id,
"attachment_info": attachment_info,
"segment_id": attachment_binding.segment_id,
}
)
return attachment_infos

View File

@ -289,7 +289,8 @@ class OracleVector(BaseVector):
words = pseg.cut(query)
current_entity = ""
for word, pos in words:
if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名ns: 地名nt: 机构名
# `nr`: Person, `ns`: Location, `nt`: Organization
if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}:
current_entity += word
else:
if current_entity:

View File

@ -255,7 +255,10 @@ class PGVector(BaseVector):
return
with self._get_cursor() as cur:
cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
cur.execute("SELECT 1 FROM pg_extension WHERE extname = 'vector'")
if not cur.fetchone():
cur.execute("CREATE EXTENSION vector")
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
# PG hnsw index only support 2000 dimension or less
# ref: https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing

View File

@ -213,7 +213,7 @@ class VastbaseVector(BaseVector):
with self._get_cursor() as cur:
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
# Vastbase 支持的向量维度取值范围为 [1,16000]
# Vastbase supports vector dimensions in the range [1, 16,000]
if dimension <= 16000:
cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
redis_client.set(collection_exist_cache_key, 1, ex=3600)

View File

@ -25,7 +25,7 @@ class FirecrawlApp:
}
if params:
json_data.update(params)
response = self._post_request(f"{self.base_url}/v2/scrape", json_data, headers)
response = self._post_request(self._build_url("v2/scrape"), json_data, headers)
if response.status_code == 200:
response_data = response.json()
data = response_data["data"]
@ -42,7 +42,7 @@ class FirecrawlApp:
json_data = {"url": url}
if params:
json_data.update(params)
response = self._post_request(f"{self.base_url}/v2/crawl", json_data, headers)
response = self._post_request(self._build_url("v2/crawl"), json_data, headers)
if response.status_code == 200:
# There's also another two fields in the response: "success" (bool) and "url" (str)
job_id = response.json().get("id")
@ -58,7 +58,7 @@ class FirecrawlApp:
if params:
# Pass through provided params, including optional "sitemap": "only" | "include" | "skip"
json_data.update(params)
response = self._post_request(f"{self.base_url}/v2/map", json_data, headers)
response = self._post_request(self._build_url("v2/map"), json_data, headers)
if response.status_code == 200:
return cast(dict[str, Any], response.json())
elif response.status_code in {402, 409, 500, 429, 408}:
@ -69,7 +69,7 @@ class FirecrawlApp:
def check_crawl_status(self, job_id) -> dict[str, Any]:
headers = self._prepare_headers()
response = self._get_request(f"{self.base_url}/v2/crawl/{job_id}", headers)
response = self._get_request(self._build_url(f"v2/crawl/{job_id}"), headers)
if response.status_code == 200:
crawl_status_response = response.json()
if crawl_status_response.get("status") == "completed":
@ -120,6 +120,10 @@ class FirecrawlApp:
def _prepare_headers(self) -> dict[str, Any]:
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
def _build_url(self, path: str) -> str:
# ensure exactly one slash between base and path, regardless of user-provided base_url
return f"{self.base_url.rstrip('/')}/{path.lstrip('/')}"
def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5) -> httpx.Response:
for attempt in range(retries):
response = httpx.post(url, headers=headers, json=data)
@ -139,7 +143,11 @@ class FirecrawlApp:
return response
def _handle_error(self, response, action):
error_message = response.json().get("error", "Unknown error occurred")
try:
payload = response.json()
error_message = payload.get("error") or payload.get("message") or response.text or "Unknown error occurred"
except json.JSONDecodeError:
error_message = response.text or "Unknown error occurred"
raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}") # type: ignore[return]
def search(self, query: str, params: dict[str, Any] | None = None) -> dict[str, Any]:
@ -160,7 +168,7 @@ class FirecrawlApp:
}
if params:
json_data.update(params)
response = self._post_request(f"{self.base_url}/v2/search", json_data, headers)
response = self._post_request(self._build_url("v2/search"), json_data, headers)
if response.status_code == 200:
response_data = response.json()
if not response_data.get("success"):

View File

@ -48,13 +48,21 @@ class NotionExtractor(BaseExtractor):
if notion_access_token:
self._notion_access_token = notion_access_token
else:
self._notion_access_token = self._get_access_token(tenant_id, self._credential_id)
if not self._notion_access_token:
try:
self._notion_access_token = self._get_access_token(tenant_id, self._credential_id)
except Exception as e:
logger.warning(
(
"Failed to get Notion access token from datasource credentials: %s, "
"falling back to environment variable NOTION_INTEGRATION_TOKEN"
),
e,
)
integration_token = dify_config.NOTION_INTEGRATION_TOKEN
if integration_token is None:
raise ValueError(
"Must specify `integration_token` or set environment variable `NOTION_INTEGRATION_TOKEN`."
)
) from e
self._notion_access_token = integration_token

View File

@ -83,6 +83,7 @@ class WordExtractor(BaseExtractor):
def _extract_images_from_docx(self, doc):
image_count = 0
image_map = {}
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
for r_id, rel in doc.part.rels.items():
if "image" in rel.target_ref:
@ -121,8 +122,7 @@ class WordExtractor(BaseExtractor):
used_at=naive_utc_now(),
)
db.session.add(upload_file)
# Use r_id as key for external images since target_part is undefined
image_map[r_id] = f"![image]({dify_config.FILES_URL}/files/{upload_file.id}/file-preview)"
image_map[r_id] = f"![image]({base_url}/files/{upload_file.id}/file-preview)"
else:
image_ext = rel.target_ref.split(".")[-1]
if image_ext is None:
@ -150,10 +150,7 @@ class WordExtractor(BaseExtractor):
used_at=naive_utc_now(),
)
db.session.add(upload_file)
# Use target_part as key for internal images
image_map[rel.target_part] = (
f"![image]({dify_config.FILES_URL}/files/{upload_file.id}/file-preview)"
)
image_map[rel.target_part] = f"![image]({base_url}/files/{upload_file.id}/file-preview)"
db.session.commit()
return image_map

View File

@ -231,7 +231,7 @@ class BaseIndexProcessor(ABC):
if not filename:
parsed_url = urlparse(image_url)
# unquote 处理 URL 中的中文
# Decode percent-encoded characters in the URL path.
path = unquote(parsed_url.path)
filename = os.path.basename(path)

View File

@ -7,7 +7,7 @@ from collections.abc import Generator, Mapping
from typing import Any, Union, cast
from flask import Flask, current_app
from sqlalchemy import and_, or_, select
from sqlalchemy import and_, literal, or_, select
from sqlalchemy.orm import Session
from core.app.app_config.entities import (
@ -151,20 +151,14 @@ class DatasetRetrieval:
if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features:
planning_strategy = PlanningStrategy.ROUTER
available_datasets = []
for dataset_id in dataset_ids:
# get dataset from dataset id
dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
dataset = db.session.scalar(dataset_stmt)
# pass if dataset is not available
if not dataset:
dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id.in_(dataset_ids))
datasets: list[Dataset] = db.session.execute(dataset_stmt).scalars().all() # type: ignore
for dataset in datasets:
if dataset.available_document_count == 0 and dataset.provider != "external":
continue
# pass if dataset is not available
if dataset and dataset.available_document_count == 0 and dataset.provider != "external":
continue
available_datasets.append(dataset)
if inputs:
inputs = {key: str(value) for key, value in inputs.items()}
else:
@ -282,26 +276,35 @@ class DatasetRetrieval:
)
context_files.append(attachment_info)
if show_retrieve_source:
dataset_ids = [record.segment.dataset_id for record in records]
document_ids = [record.segment.document_id for record in records]
dataset_document_stmt = select(DatasetDocument).where(
DatasetDocument.id.in_(document_ids),
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
documents = db.session.execute(dataset_document_stmt).scalars().all() # type: ignore
dataset_stmt = select(Dataset).where(
Dataset.id.in_(dataset_ids),
)
datasets = db.session.execute(dataset_stmt).scalars().all() # type: ignore
dataset_map = {i.id: i for i in datasets}
document_map = {i.id: i for i in documents}
for record in records:
segment = record.segment
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
dataset_document_stmt = select(DatasetDocument).where(
DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
document = db.session.scalar(dataset_document_stmt)
if dataset and document:
dataset_item = dataset_map.get(segment.dataset_id)
document_item = document_map.get(segment.document_id)
if dataset_item and document_item:
source = RetrievalSourceMetadata(
dataset_id=dataset.id,
dataset_name=dataset.name,
document_id=document.id,
document_name=document.name,
data_source_type=document.data_source_type,
dataset_id=dataset_item.id,
dataset_name=dataset_item.name,
document_id=document_item.id,
document_name=document_item.name,
data_source_type=document_item.data_source_type,
segment_id=segment.id,
retriever_from=invoke_from.to_source(),
score=record.score or 0.0,
doc_metadata=document.doc_metadata,
doc_metadata=document_item.doc_metadata,
)
if invoke_from.to_source() == "dev":
@ -513,6 +516,9 @@ class DatasetRetrieval:
].embedding_model_provider
weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
with measure_time() as timer:
cancel_event = threading.Event()
thread_exceptions: list[Exception] = []
if query:
query_thread = threading.Thread(
target=self._multiple_retrieve_thread,
@ -531,6 +537,8 @@ class DatasetRetrieval:
"score_threshold": score_threshold,
"query": query,
"attachment_id": None,
"cancel_event": cancel_event,
"thread_exceptions": thread_exceptions,
},
)
all_threads.append(query_thread)
@ -554,12 +562,25 @@ class DatasetRetrieval:
"score_threshold": score_threshold,
"query": None,
"attachment_id": attachment_id,
"cancel_event": cancel_event,
"thread_exceptions": thread_exceptions,
},
)
all_threads.append(attachment_thread)
attachment_thread.start()
for thread in all_threads:
thread.join()
# Poll threads with short timeout to detect errors quickly (fail-fast)
while any(t.is_alive() for t in all_threads):
for thread in all_threads:
thread.join(timeout=0.1)
if thread_exceptions:
cancel_event.set()
break
if thread_exceptions:
break
if thread_exceptions:
raise thread_exceptions[0]
self._on_query(query, attachment_ids, dataset_ids, app_id, user_from, user_id)
if all_documents:
@ -1033,7 +1054,7 @@ class DatasetRetrieval:
if automatic_metadata_filters:
conditions = []
for sequence, filter in enumerate(automatic_metadata_filters):
self._process_metadata_filter_func(
self.process_metadata_filter_func(
sequence,
filter.get("condition"), # type: ignore
filter.get("metadata_name"), # type: ignore
@ -1069,7 +1090,7 @@ class DatasetRetrieval:
value=expected_value,
)
)
filters = self._process_metadata_filter_func(
filters = self.process_metadata_filter_func(
sequence,
condition.comparison_operator,
metadata_name,
@ -1165,8 +1186,9 @@ class DatasetRetrieval:
return None
return automatic_metadata_filters
def _process_metadata_filter_func(
self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
@classmethod
def process_metadata_filter_func(
cls, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
):
if value is None and condition not in ("empty", "not empty"):
return filters
@ -1215,6 +1237,20 @@ class DatasetRetrieval:
case "" | ">=":
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() >= value)
case "in" | "not in":
if isinstance(value, str):
value_list = [v.strip() for v in value.split(",") if v.strip()]
elif isinstance(value, (list, tuple)):
value_list = [str(v) for v in value if v is not None]
else:
value_list = [str(value)] if value is not None else []
if not value_list:
# `field in []` is False, `field not in []` is True
filters.append(literal(condition == "not in"))
else:
op = json_field.in_ if condition == "in" else json_field.notin_
filters.append(op(value_list))
case _:
pass
@ -1386,40 +1422,53 @@ class DatasetRetrieval:
score_threshold: float,
query: str | None,
attachment_id: str | None,
cancel_event: threading.Event | None = None,
thread_exceptions: list[Exception] | None = None,
):
with flask_app.app_context():
threads = []
all_documents_item: list[Document] = []
index_type = None
for dataset in available_datasets:
index_type = dataset.indexing_technique
document_ids_filter = None
if dataset.provider != "external":
if metadata_condition and not metadata_filter_document_ids:
continue
if metadata_filter_document_ids:
document_ids = metadata_filter_document_ids.get(dataset.id, [])
if document_ids:
document_ids_filter = document_ids
else:
try:
with flask_app.app_context():
threads = []
all_documents_item: list[Document] = []
index_type = None
for dataset in available_datasets:
# Check for cancellation signal
if cancel_event and cancel_event.is_set():
break
index_type = dataset.indexing_technique
document_ids_filter = None
if dataset.provider != "external":
if metadata_condition and not metadata_filter_document_ids:
continue
retrieval_thread = threading.Thread(
target=self._retriever,
kwargs={
"flask_app": flask_app,
"dataset_id": dataset.id,
"query": query,
"top_k": top_k,
"all_documents": all_documents_item,
"document_ids_filter": document_ids_filter,
"metadata_condition": metadata_condition,
"attachment_ids": [attachment_id] if attachment_id else None,
},
)
threads.append(retrieval_thread)
retrieval_thread.start()
for thread in threads:
thread.join()
if metadata_filter_document_ids:
document_ids = metadata_filter_document_ids.get(dataset.id, [])
if document_ids:
document_ids_filter = document_ids
else:
continue
retrieval_thread = threading.Thread(
target=self._retriever,
kwargs={
"flask_app": flask_app,
"dataset_id": dataset.id,
"query": query,
"top_k": top_k,
"all_documents": all_documents_item,
"document_ids_filter": document_ids_filter,
"metadata_condition": metadata_condition,
"attachment_ids": [attachment_id] if attachment_id else None,
},
)
threads.append(retrieval_thread)
retrieval_thread.start()
# Poll threads with short timeout to respond quickly to cancellation
while any(t.is_alive() for t in threads):
for thread in threads:
thread.join(timeout=0.1)
if cancel_event and cancel_event.is_set():
break
if cancel_event and cancel_event.is_set():
break
if reranking_enable:
# do rerank for searched documents
@ -1452,3 +1501,8 @@ class DatasetRetrieval:
all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item
if all_documents_item:
all_documents.extend(all_documents_item)
except Exception as e:
if cancel_event:
cancel_event.set()
if thread_exceptions is not None:
thread_exceptions.append(e)

View File

@ -2,6 +2,7 @@
from __future__ import annotations
import codecs
import re
from typing import Any
@ -52,7 +53,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
def __init__(self, fixed_separator: str = "\n\n", separators: list[str] | None = None, **kwargs: Any):
"""Create a new TextSplitter."""
super().__init__(**kwargs)
self._fixed_separator = fixed_separator
self._fixed_separator = codecs.decode(fixed_separator, "unicode_escape")
self._separators = separators or ["\n\n", "\n", "", ". ", " ", ""]
def split_text(self, text: str) -> list[str]:
@ -94,7 +95,8 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
splits = re.split(r" +", text)
else:
splits = text.split(separator)
splits = [item + separator if i < len(splits) else item for i, item in enumerate(splits)]
if self._keep_separator:
splits = [s + separator for s in splits[:-1]] + splits[-1:]
else:
splits = list(text)
if separator == "\n":
@ -103,7 +105,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
splits = [s for s in splits if (s not in {"", "\n"})]
_good_splits = []
_good_splits_lengths = [] # cache the lengths of the splits
_separator = separator if self._keep_separator else ""
_separator = "" if self._keep_separator else separator
s_lens = self._length_function(splits)
if separator != "":
for s, s_len in zip(splits, s_lens):

View File

@ -153,11 +153,11 @@ class ToolInvokeMessage(BaseModel):
@classmethod
def transform_variable_value(cls, values):
"""
Only basic types and lists are allowed.
Only basic types, lists, and None are allowed.
"""
value = values.get("variable_value")
if not isinstance(value, dict | list | str | int | float | bool):
raise ValueError("Only basic types and lists are allowed.")
if value is not None and not isinstance(value, dict | list | str | int | float | bool):
raise ValueError("Only basic types, lists, and None are allowed.")
# if stream is true, the value must be a string
if values.get("stream"):

View File

@ -6,7 +6,15 @@ from typing import Any
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPConnectionError
from core.mcp.types import AudioContent, CallToolResult, ImageContent, TextContent
from core.mcp.types import (
AudioContent,
BlobResourceContents,
CallToolResult,
EmbeddedResource,
ImageContent,
TextContent,
TextResourceContents,
)
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
@ -53,10 +61,19 @@ class MCPTool(Tool):
for content in result.content:
if isinstance(content, TextContent):
yield from self._process_text_content(content)
elif isinstance(content, ImageContent):
yield self._process_image_content(content)
elif isinstance(content, AudioContent):
yield self._process_audio_content(content)
elif isinstance(content, ImageContent | AudioContent):
yield self.create_blob_message(
blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType}
)
elif isinstance(content, EmbeddedResource):
resource = content.resource
if isinstance(resource, TextResourceContents):
yield self.create_text_message(resource.text)
elif isinstance(resource, BlobResourceContents):
mime_type = resource.mimeType or "application/octet-stream"
yield self.create_blob_message(blob=base64.b64decode(resource.blob), meta={"mime_type": mime_type})
else:
raise ToolInvokeError(f"Unsupported embedded resource type: {type(resource)}")
else:
logger.warning("Unsupported content type=%s", type(content))
@ -101,14 +118,6 @@ class MCPTool(Tool):
for item in json_list:
yield self.create_json_message(item)
def _process_image_content(self, content: ImageContent) -> ToolInvokeMessage:
"""Process image content and return a blob message."""
return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType})
def _process_audio_content(self, content: AudioContent) -> ToolInvokeMessage:
"""Process audio content and return a blob message."""
return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType})
def fork_tool_runtime(self, runtime: ToolRuntime) -> "MCPTool":
return MCPTool(
entity=self.entity,

View File

@ -5,6 +5,7 @@ from sqlalchemy.orm import Session
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.db.session_factory import session_factory
from core.plugin.entities.parameters import PluginParameterOption
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
@ -47,33 +48,30 @@ class WorkflowToolProviderController(ToolProviderController):
@classmethod
def from_db(cls, db_provider: WorkflowToolProvider) -> "WorkflowToolProviderController":
with Session(db.engine, expire_on_commit=False) as session, session.begin():
provider = session.get(WorkflowToolProvider, db_provider.id) if db_provider.id else None
if not provider:
raise ValueError("workflow provider not found")
app = session.get(App, provider.app_id)
with session_factory.create_session() as session, session.begin():
app = session.get(App, db_provider.app_id)
if not app:
raise ValueError("app not found")
user = session.get(Account, provider.user_id) if provider.user_id else None
user = session.get(Account, db_provider.user_id) if db_provider.user_id else None
controller = WorkflowToolProviderController(
entity=ToolProviderEntity(
identity=ToolProviderIdentity(
author=user.name if user else "",
name=provider.label,
label=I18nObject(en_US=provider.label, zh_Hans=provider.label),
description=I18nObject(en_US=provider.description, zh_Hans=provider.description),
icon=provider.icon,
name=db_provider.label,
label=I18nObject(en_US=db_provider.label, zh_Hans=db_provider.label),
description=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description),
icon=db_provider.icon,
),
credentials_schema=[],
plugin_id=None,
),
provider_id=provider.id or "",
provider_id="",
)
controller.tools = [
controller._get_db_provider_tool(provider, app, session=session, user=user),
controller._get_db_provider_tool(db_provider, app, session=session, user=user),
]
return controller

View File

@ -67,12 +67,16 @@ def create_trigger_provider_encrypter_for_subscription(
def delete_cache_for_subscription(tenant_id: str, provider_id: str, subscription_id: str):
cache = TriggerProviderCredentialsCache(
TriggerProviderCredentialsCache(
tenant_id=tenant_id,
provider_id=provider_id,
credential_id=subscription_id,
)
cache.delete()
).delete()
TriggerProviderPropertiesCache(
tenant_id=tenant_id,
provider_id=provider_id,
subscription_id=subscription_id,
).delete()
def create_trigger_provider_encrypter_for_properties(

Some files were not shown because too many files have changed in this diff Show More