Merge branch 'main' into feat/agent-node-v2

This commit is contained in:
Novice 2025-12-30 10:20:42 +08:00
commit ccabdbc83b
No known key found for this signature in database
GPG Key ID: EE3F68E3105DAAAB
232 changed files with 18692 additions and 2696 deletions

8
.claude/settings.json Normal file
View File

@ -0,0 +1,8 @@
{
"enabledPlugins": {
"feature-dev@claude-plugins-official": true,
"context7@claude-plugins-official": true,
"typescript-lsp@claude-plugins-official": true,
"pyright-lsp@claude-plugins-official": true
}
}

View File

@ -1,19 +0,0 @@
{
"permissions": {
"allow": [],
"deny": []
},
"env": {
"__comment": "Environment variables for MCP servers. Override in .claude/settings.local.json with actual values.",
"GITHUB_PERSONAL_ACCESS_TOKEN": "ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
},
"enabledMcpjsonServers": [
"context7",
"sequential-thinking",
"github",
"fetch",
"playwright",
"ide"
],
"enableAllProjectMcpServers": true
}

View File

@ -0,0 +1,483 @@
---
name: component-refactoring
description: Refactor high-complexity React components in Dify frontend. Use when `pnpm analyze-component --json` shows complexity > 50 or lineCount > 300, when the user asks for code splitting, hook extraction, or complexity reduction, or when `pnpm analyze-component` warns to refactor before testing; avoid for simple/well-structured components, third-party wrappers, or when the user explicitly wants testing without refactoring.
---
# Dify Component Refactoring Skill
Refactor high-complexity React components in the Dify frontend codebase with the patterns and workflow below.
> **Complexity Threshold**: Components with complexity > 50 (measured by `pnpm analyze-component`) should be refactored before testing.
## Quick Reference
### Commands (run from `web/`)
Use paths relative to `web/` (e.g., `app/components/...`).
Use `refactor-component` for refactoring prompts and `analyze-component` for testing prompts and metrics.
```bash
cd web
# Generate refactoring prompt
pnpm refactor-component <path>
# Output refactoring analysis as JSON
pnpm refactor-component <path> --json
# Generate testing prompt (after refactoring)
pnpm analyze-component <path>
# Output testing analysis as JSON
pnpm analyze-component <path> --json
```
### Complexity Analysis
```bash
# Analyze component complexity
pnpm analyze-component <path> --json
# Key metrics to check:
# - complexity: normalized score 0-100 (target < 50)
# - maxComplexity: highest single function complexity
# - lineCount: total lines (target < 300)
```
### Complexity Score Interpretation
| Score | Level | Action |
|-------|-------|--------|
| 0-25 | 🟢 Simple | Ready for testing |
| 26-50 | 🟡 Medium | Consider minor refactoring |
| 51-75 | 🟠 Complex | **Refactor before testing** |
| 76-100 | 🔴 Very Complex | **Must refactor** |
## Core Refactoring Patterns
### Pattern 1: Extract Custom Hooks
**When**: Component has complex state management, multiple `useState`/`useEffect`, or business logic mixed with UI.
**Dify Convention**: Place hooks in a `hooks/` subdirectory or alongside the component as `use-<feature>.ts`.
```typescript
// ❌ Before: Complex state logic in component
const Configuration: FC = () => {
const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
const [datasetConfigs, setDatasetConfigs] = useState<DatasetConfigs>(...)
const [completionParams, setCompletionParams] = useState<FormValue>({})
// 50+ lines of state management logic...
return <div>...</div>
}
// ✅ After: Extract to custom hook
// hooks/use-model-config.ts
export const useModelConfig = (appId: string) => {
const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
const [completionParams, setCompletionParams] = useState<FormValue>({})
// Related state management logic here
return { modelConfig, setModelConfig, completionParams, setCompletionParams }
}
// Component becomes cleaner
const Configuration: FC = () => {
const { modelConfig, setModelConfig } = useModelConfig(appId)
return <div>...</div>
}
```
**Dify Examples**:
- `web/app/components/app/configuration/hooks/use-advanced-prompt-config.ts`
- `web/app/components/app/configuration/debug/hooks.tsx`
- `web/app/components/workflow/hooks/use-workflow.ts`
### Pattern 2: Extract Sub-Components
**When**: Single component has multiple UI sections, conditional rendering blocks, or repeated patterns.
**Dify Convention**: Place sub-components in subdirectories or as separate files in the same directory.
```typescript
// ❌ Before: Monolithic JSX with multiple sections
const AppInfo = () => {
return (
<div>
{/* 100 lines of header UI */}
{/* 100 lines of operations UI */}
{/* 100 lines of modals */}
</div>
)
}
// ✅ After: Split into focused components
// app-info/
// ├── index.tsx (orchestration only)
// ├── app-header.tsx (header UI)
// ├── app-operations.tsx (operations UI)
// └── app-modals.tsx (modal management)
const AppInfo = () => {
const { showModal, setShowModal } = useAppInfoModals()
return (
<div>
<AppHeader appDetail={appDetail} />
<AppOperations onAction={handleAction} />
<AppModals show={showModal} onClose={() => setShowModal(null)} />
</div>
)
}
```
**Dify Examples**:
- `web/app/components/app/configuration/` directory structure
- `web/app/components/workflow/nodes/` per-node organization
### Pattern 3: Simplify Conditional Logic
**When**: Deep nesting (> 3 levels), complex ternaries, or multiple `if/else` chains.
```typescript
// ❌ Before: Deeply nested conditionals
const Template = useMemo(() => {
if (appDetail?.mode === AppModeEnum.CHAT) {
switch (locale) {
case LanguagesSupported[1]:
return <TemplateChatZh />
case LanguagesSupported[7]:
return <TemplateChatJa />
default:
return <TemplateChatEn />
}
}
if (appDetail?.mode === AppModeEnum.ADVANCED_CHAT) {
// Another 15 lines...
}
// More conditions...
}, [appDetail, locale])
// ✅ After: Use lookup tables + early returns
const TEMPLATE_MAP = {
[AppModeEnum.CHAT]: {
[LanguagesSupported[1]]: TemplateChatZh,
[LanguagesSupported[7]]: TemplateChatJa,
default: TemplateChatEn,
},
[AppModeEnum.ADVANCED_CHAT]: {
[LanguagesSupported[1]]: TemplateAdvancedChatZh,
// ...
},
}
const Template = useMemo(() => {
const modeTemplates = TEMPLATE_MAP[appDetail?.mode]
if (!modeTemplates) return null
const TemplateComponent = modeTemplates[locale] || modeTemplates.default
return <TemplateComponent appDetail={appDetail} />
}, [appDetail, locale])
```
### Pattern 4: Extract API/Data Logic
**When**: Component directly handles API calls, data transformation, or complex async operations.
**Dify Convention**: Use `@tanstack/react-query` hooks from `web/service/use-*.ts` or create custom data hooks.
```typescript
// ❌ Before: API logic in component
const MCPServiceCard = () => {
const [basicAppConfig, setBasicAppConfig] = useState({})
useEffect(() => {
if (isBasicApp && appId) {
(async () => {
const res = await fetchAppDetail({ url: '/apps', id: appId })
setBasicAppConfig(res?.model_config || {})
})()
}
}, [appId, isBasicApp])
// More API-related logic...
}
// ✅ After: Extract to data hook using React Query
// use-app-config.ts
import { useQuery } from '@tanstack/react-query'
import { get } from '@/service/base'
const NAME_SPACE = 'appConfig'
export const useAppConfig = (appId: string, isBasicApp: boolean) => {
return useQuery({
enabled: isBasicApp && !!appId,
queryKey: [NAME_SPACE, 'detail', appId],
queryFn: () => get<AppDetailResponse>(`/apps/${appId}`),
select: data => data?.model_config || {},
})
}
// Component becomes cleaner
const MCPServiceCard = () => {
const { data: config, isLoading } = useAppConfig(appId, isBasicApp)
// UI only
}
```
**React Query Best Practices in Dify**:
- Define `NAME_SPACE` for query key organization
- Use `enabled` option for conditional fetching
- Use `select` for data transformation
- Export invalidation hooks: `useInvalidXxx`
**Dify Examples**:
- `web/service/use-workflow.ts`
- `web/service/use-common.ts`
- `web/service/knowledge/use-dataset.ts`
- `web/service/knowledge/use-document.ts`
### Pattern 5: Extract Modal/Dialog Management
**When**: Component manages multiple modals with complex open/close states.
**Dify Convention**: Modals should be extracted with their state management.
```typescript
// ❌ Before: Multiple modal states in component
const AppInfo = () => {
const [showEditModal, setShowEditModal] = useState(false)
const [showDuplicateModal, setShowDuplicateModal] = useState(false)
const [showConfirmDelete, setShowConfirmDelete] = useState(false)
const [showSwitchModal, setShowSwitchModal] = useState(false)
const [showImportDSLModal, setShowImportDSLModal] = useState(false)
// 5+ more modal states...
}
// ✅ After: Extract to modal management hook
type ModalType = 'edit' | 'duplicate' | 'delete' | 'switch' | 'import' | null
const useAppInfoModals = () => {
const [activeModal, setActiveModal] = useState<ModalType>(null)
const openModal = useCallback((type: ModalType) => setActiveModal(type), [])
const closeModal = useCallback(() => setActiveModal(null), [])
return {
activeModal,
openModal,
closeModal,
isOpen: (type: ModalType) => activeModal === type,
}
}
```
### Pattern 6: Extract Form Logic
**When**: Complex form validation, submission handling, or field transformation.
**Dify Convention**: Use `@tanstack/react-form` patterns from `web/app/components/base/form/`.
```typescript
// ✅ Use existing form infrastructure
import { useAppForm } from '@/app/components/base/form'
const ConfigForm = () => {
const form = useAppForm({
defaultValues: { name: '', description: '' },
onSubmit: handleSubmit,
})
return <form.Provider>...</form.Provider>
}
```
## Dify-Specific Refactoring Guidelines
### 1. Context Provider Extraction
**When**: Component provides complex context values with multiple states.
```typescript
// ❌ Before: Large context value object
const value = {
appId, isAPIKeySet, isTrailFinished, mode, modelModeType,
promptMode, isAdvancedMode, isAgent, isOpenAI, isFunctionCall,
// 50+ more properties...
}
return <ConfigContext.Provider value={value}>...</ConfigContext.Provider>
// ✅ After: Split into domain-specific contexts
<ModelConfigProvider value={modelConfigValue}>
<DatasetConfigProvider value={datasetConfigValue}>
<UIConfigProvider value={uiConfigValue}>
{children}
</UIConfigProvider>
</DatasetConfigProvider>
</ModelConfigProvider>
```
**Dify Reference**: `web/context/` directory structure
### 2. Workflow Node Components
**When**: Refactoring workflow node components (`web/app/components/workflow/nodes/`).
**Conventions**:
- Keep node logic in `use-interactions.ts`
- Extract panel UI to separate files
- Use `_base` components for common patterns
```
nodes/<node-type>/
├── index.tsx # Node registration
├── node.tsx # Node visual component
├── panel.tsx # Configuration panel
├── use-interactions.ts # Node-specific hooks
└── types.ts # Type definitions
```
### 3. Configuration Components
**When**: Refactoring app configuration components.
**Conventions**:
- Separate config sections into subdirectories
- Use existing patterns from `web/app/components/app/configuration/`
- Keep feature toggles in dedicated components
### 4. Tool/Plugin Components
**When**: Refactoring tool-related components (`web/app/components/tools/`).
**Conventions**:
- Follow existing modal patterns
- Use service hooks from `web/service/use-tools.ts`
- Keep provider-specific logic isolated
## Refactoring Workflow
### Step 1: Generate Refactoring Prompt
```bash
pnpm refactor-component <path>
```
This command will:
- Analyze component complexity and features
- Identify specific refactoring actions needed
- Generate a prompt for AI assistant (auto-copied to clipboard on macOS)
- Provide detailed requirements based on detected patterns
### Step 2: Analyze Details
```bash
pnpm analyze-component <path> --json
```
Identify:
- Total complexity score
- Max function complexity
- Line count
- Features detected (state, effects, API, etc.)
### Step 3: Plan
Create a refactoring plan based on detected features:
| Detected Feature | Refactoring Action |
|------------------|-------------------|
| `hasState: true` + `hasEffects: true` | Extract custom hook |
| `hasAPI: true` | Extract data/service hook |
| `hasEvents: true` (many) | Extract event handlers |
| `lineCount > 300` | Split into sub-components |
| `maxComplexity > 50` | Simplify conditional logic |
### Step 4: Execute Incrementally
1. **Extract one piece at a time**
2. **Run lint, type-check, and tests after each extraction**
3. **Verify functionality before next step**
```
For each extraction:
┌────────────────────────────────────────┐
│ 1. Extract code │
│ 2. Run: pnpm lint:fix │
│ 3. Run: pnpm type-check:tsgo │
│ 4. Run: pnpm test │
│ 5. Test functionality manually │
│ 6. PASS? → Next extraction │
│ FAIL? → Fix before continuing │
└────────────────────────────────────────┘
```
### Step 5: Verify
After refactoring:
```bash
# Re-run refactor command to verify improvements
pnpm refactor-component <path>
# If complexity < 25 and lines < 200, you'll see:
# ✅ COMPONENT IS WELL-STRUCTURED
# For detailed metrics:
pnpm analyze-component <path> --json
# Target metrics:
# - complexity < 50
# - lineCount < 300
# - maxComplexity < 30
```
## Common Mistakes to Avoid
### ❌ Over-Engineering
```typescript
// ❌ Too many tiny hooks
const useButtonText = () => useState('Click')
const useButtonDisabled = () => useState(false)
const useButtonLoading = () => useState(false)
// ✅ Cohesive hook with related state
const useButtonState = () => {
const [text, setText] = useState('Click')
const [disabled, setDisabled] = useState(false)
const [loading, setLoading] = useState(false)
return { text, setText, disabled, setDisabled, loading, setLoading }
}
```
### ❌ Breaking Existing Patterns
- Follow existing directory structures
- Maintain naming conventions
- Preserve export patterns for compatibility
### ❌ Premature Abstraction
- Only extract when there's clear complexity benefit
- Don't create abstractions for single-use code
- Keep refactored code in the same domain area
## References
### Dify Codebase Examples
- **Hook extraction**: `web/app/components/app/configuration/hooks/`
- **Component splitting**: `web/app/components/app/configuration/`
- **Service hooks**: `web/service/use-*.ts`
- **Workflow patterns**: `web/app/components/workflow/hooks/`
- **Form patterns**: `web/app/components/base/form/`
### Related Skills
- `frontend-testing` - For testing refactored components
- `web/testing/testing.md` - Testing specification

View File

@ -0,0 +1,493 @@
# Complexity Reduction Patterns
This document provides patterns for reducing cognitive complexity in Dify React components.
## Understanding Complexity
### SonarJS Cognitive Complexity
The `pnpm analyze-component` tool uses SonarJS cognitive complexity metrics:
- **Total Complexity**: Sum of all functions' complexity in the file
- **Max Complexity**: Highest single function complexity
### What Increases Complexity
| Pattern | Complexity Impact |
|---------|-------------------|
| `if/else` | +1 per branch |
| Nested conditions | +1 per nesting level |
| `switch/case` | +1 per case |
| `for/while/do` | +1 per loop |
| `&&`/`\|\|` chains | +1 per operator |
| Nested callbacks | +1 per nesting level |
| `try/catch` | +1 per catch |
| Ternary expressions | +1 per nesting |
## Pattern 1: Replace Conditionals with Lookup Tables
**Before** (complexity: ~15):
```typescript
const Template = useMemo(() => {
if (appDetail?.mode === AppModeEnum.CHAT) {
switch (locale) {
case LanguagesSupported[1]:
return <TemplateChatZh appDetail={appDetail} />
case LanguagesSupported[7]:
return <TemplateChatJa appDetail={appDetail} />
default:
return <TemplateChatEn appDetail={appDetail} />
}
}
if (appDetail?.mode === AppModeEnum.ADVANCED_CHAT) {
switch (locale) {
case LanguagesSupported[1]:
return <TemplateAdvancedChatZh appDetail={appDetail} />
case LanguagesSupported[7]:
return <TemplateAdvancedChatJa appDetail={appDetail} />
default:
return <TemplateAdvancedChatEn appDetail={appDetail} />
}
}
if (appDetail?.mode === AppModeEnum.WORKFLOW) {
// Similar pattern...
}
return null
}, [appDetail, locale])
```
**After** (complexity: ~3):
```typescript
// Define lookup table outside component
const TEMPLATE_MAP: Record<AppModeEnum, Record<string, FC<TemplateProps>>> = {
[AppModeEnum.CHAT]: {
[LanguagesSupported[1]]: TemplateChatZh,
[LanguagesSupported[7]]: TemplateChatJa,
default: TemplateChatEn,
},
[AppModeEnum.ADVANCED_CHAT]: {
[LanguagesSupported[1]]: TemplateAdvancedChatZh,
[LanguagesSupported[7]]: TemplateAdvancedChatJa,
default: TemplateAdvancedChatEn,
},
[AppModeEnum.WORKFLOW]: {
[LanguagesSupported[1]]: TemplateWorkflowZh,
[LanguagesSupported[7]]: TemplateWorkflowJa,
default: TemplateWorkflowEn,
},
// ...
}
// Clean component logic
const Template = useMemo(() => {
if (!appDetail?.mode) return null
const templates = TEMPLATE_MAP[appDetail.mode]
if (!templates) return null
const TemplateComponent = templates[locale] ?? templates.default
return <TemplateComponent appDetail={appDetail} />
}, [appDetail, locale])
```
## Pattern 2: Use Early Returns
**Before** (complexity: ~10):
```typescript
const handleSubmit = () => {
if (isValid) {
if (hasChanges) {
if (isConnected) {
submitData()
} else {
showConnectionError()
}
} else {
showNoChangesMessage()
}
} else {
showValidationError()
}
}
```
**After** (complexity: ~4):
```typescript
const handleSubmit = () => {
if (!isValid) {
showValidationError()
return
}
if (!hasChanges) {
showNoChangesMessage()
return
}
if (!isConnected) {
showConnectionError()
return
}
submitData()
}
```
## Pattern 3: Extract Complex Conditions
**Before** (complexity: high):
```typescript
const canPublish = (() => {
if (mode !== AppModeEnum.COMPLETION) {
if (!isAdvancedMode)
return true
if (modelModeType === ModelModeType.completion) {
if (!hasSetBlockStatus.history || !hasSetBlockStatus.query)
return false
return true
}
return true
}
return !promptEmpty
})()
```
**After** (complexity: lower):
```typescript
// Extract to named functions
const canPublishInCompletionMode = () => !promptEmpty
const canPublishInChatMode = () => {
if (!isAdvancedMode) return true
if (modelModeType !== ModelModeType.completion) return true
return hasSetBlockStatus.history && hasSetBlockStatus.query
}
// Clean main logic
const canPublish = mode === AppModeEnum.COMPLETION
? canPublishInCompletionMode()
: canPublishInChatMode()
```
## Pattern 4: Replace Chained Ternaries
**Before** (complexity: ~5):
```typescript
const statusText = serverActivated
? t('status.running')
: serverPublished
? t('status.inactive')
: appUnpublished
? t('status.unpublished')
: t('status.notConfigured')
```
**After** (complexity: ~2):
```typescript
const getStatusText = () => {
if (serverActivated) return t('status.running')
if (serverPublished) return t('status.inactive')
if (appUnpublished) return t('status.unpublished')
return t('status.notConfigured')
}
const statusText = getStatusText()
```
Or use lookup:
```typescript
const STATUS_TEXT_MAP = {
running: 'status.running',
inactive: 'status.inactive',
unpublished: 'status.unpublished',
notConfigured: 'status.notConfigured',
} as const
const getStatusKey = (): keyof typeof STATUS_TEXT_MAP => {
if (serverActivated) return 'running'
if (serverPublished) return 'inactive'
if (appUnpublished) return 'unpublished'
return 'notConfigured'
}
const statusText = t(STATUS_TEXT_MAP[getStatusKey()])
```
## Pattern 5: Flatten Nested Loops
**Before** (complexity: high):
```typescript
const processData = (items: Item[]) => {
const results: ProcessedItem[] = []
for (const item of items) {
if (item.isValid) {
for (const child of item.children) {
if (child.isActive) {
for (const prop of child.properties) {
if (prop.value !== null) {
results.push({
itemId: item.id,
childId: child.id,
propValue: prop.value,
})
}
}
}
}
}
}
return results
}
```
**After** (complexity: lower):
```typescript
// Use functional approach
const processData = (items: Item[]) => {
return items
.filter(item => item.isValid)
.flatMap(item =>
item.children
.filter(child => child.isActive)
.flatMap(child =>
child.properties
.filter(prop => prop.value !== null)
.map(prop => ({
itemId: item.id,
childId: child.id,
propValue: prop.value,
}))
)
)
}
```
## Pattern 6: Extract Event Handler Logic
**Before** (complexity: high in component):
```typescript
const Component = () => {
const handleSelect = (data: DataSet[]) => {
if (isEqual(data.map(item => item.id), dataSets.map(item => item.id))) {
hideSelectDataSet()
return
}
formattingChangedDispatcher()
let newDatasets = data
if (data.find(item => !item.name)) {
const newSelected = produce(data, (draft) => {
data.forEach((item, index) => {
if (!item.name) {
const newItem = dataSets.find(i => i.id === item.id)
if (newItem)
draft[index] = newItem
}
})
})
setDataSets(newSelected)
newDatasets = newSelected
}
else {
setDataSets(data)
}
hideSelectDataSet()
// 40 more lines of logic...
}
return <div>...</div>
}
```
**After** (complexity: lower):
```typescript
// Extract to hook or utility
const useDatasetSelection = (dataSets: DataSet[], setDataSets: SetState<DataSet[]>) => {
const normalizeSelection = (data: DataSet[]) => {
const hasUnloadedItem = data.some(item => !item.name)
if (!hasUnloadedItem) return data
return produce(data, (draft) => {
data.forEach((item, index) => {
if (!item.name) {
const existing = dataSets.find(i => i.id === item.id)
if (existing) draft[index] = existing
}
})
})
}
const hasSelectionChanged = (newData: DataSet[]) => {
return !isEqual(
newData.map(item => item.id),
dataSets.map(item => item.id)
)
}
return { normalizeSelection, hasSelectionChanged }
}
// Component becomes cleaner
const Component = () => {
const { normalizeSelection, hasSelectionChanged } = useDatasetSelection(dataSets, setDataSets)
const handleSelect = (data: DataSet[]) => {
if (!hasSelectionChanged(data)) {
hideSelectDataSet()
return
}
formattingChangedDispatcher()
const normalized = normalizeSelection(data)
setDataSets(normalized)
hideSelectDataSet()
}
return <div>...</div>
}
```
## Pattern 7: Reduce Boolean Logic Complexity
**Before** (complexity: ~8):
```typescript
const toggleDisabled = hasInsufficientPermissions
|| appUnpublished
|| missingStartNode
|| triggerModeDisabled
|| (isAdvancedApp && !currentWorkflow?.graph)
|| (isBasicApp && !basicAppConfig.updated_at)
```
**After** (complexity: ~3):
```typescript
// Extract meaningful boolean functions
const isAppReady = () => {
if (isAdvancedApp) return !!currentWorkflow?.graph
return !!basicAppConfig.updated_at
}
const hasRequiredPermissions = () => {
return isCurrentWorkspaceEditor && !hasInsufficientPermissions
}
const canToggle = () => {
if (!hasRequiredPermissions()) return false
if (!isAppReady()) return false
if (missingStartNode) return false
if (triggerModeDisabled) return false
return true
}
const toggleDisabled = !canToggle()
```
## Pattern 8: Simplify useMemo/useCallback Dependencies
**Before** (complexity: multiple recalculations):
```typescript
const payload = useMemo(() => {
let parameters: Parameter[] = []
let outputParameters: OutputParameter[] = []
if (!published) {
parameters = (inputs || []).map((item) => ({
name: item.variable,
description: '',
form: 'llm',
required: item.required,
type: item.type,
}))
outputParameters = (outputs || []).map((item) => ({
name: item.variable,
description: '',
type: item.value_type,
}))
}
else if (detail && detail.tool) {
parameters = (inputs || []).map((item) => ({
// Complex transformation...
}))
outputParameters = (outputs || []).map((item) => ({
// Complex transformation...
}))
}
return {
icon: detail?.icon || icon,
label: detail?.label || name,
// ...more fields
}
}, [detail, published, workflowAppId, icon, name, description, inputs, outputs])
```
**After** (complexity: separated concerns):
```typescript
// Separate transformations
const useParameterTransform = (inputs: InputVar[], detail?: ToolDetail, published?: boolean) => {
return useMemo(() => {
if (!published) {
return inputs.map(item => ({
name: item.variable,
description: '',
form: 'llm',
required: item.required,
type: item.type,
}))
}
if (!detail?.tool) return []
return inputs.map(item => ({
name: item.variable,
required: item.required,
type: item.type === 'paragraph' ? 'string' : item.type,
description: detail.tool.parameters.find(p => p.name === item.variable)?.llm_description || '',
form: detail.tool.parameters.find(p => p.name === item.variable)?.form || 'llm',
}))
}, [inputs, detail, published])
}
// Component uses hook
const parameters = useParameterTransform(inputs, detail, published)
const outputParameters = useOutputTransform(outputs, detail, published)
const payload = useMemo(() => ({
icon: detail?.icon || icon,
label: detail?.label || name,
parameters,
outputParameters,
// ...
}), [detail, icon, name, parameters, outputParameters])
```
## Target Metrics After Refactoring
| Metric | Target |
|--------|--------|
| Total Complexity | < 50 |
| Max Function Complexity | < 30 |
| Function Length | < 30 lines |
| Nesting Depth | ≤ 3 levels |
| Conditional Chains | ≤ 3 conditions |

View File

@ -0,0 +1,477 @@
# Component Splitting Patterns
This document provides detailed guidance on splitting large components into smaller, focused components in Dify.
## When to Split Components
Split a component when you identify:
1. **Multiple UI sections** - Distinct visual areas with minimal coupling that can be composed independently
1. **Conditional rendering blocks** - Large `{condition && <JSX />}` blocks
1. **Repeated patterns** - Similar UI structures used multiple times
1. **300+ lines** - Component exceeds manageable size
1. **Modal clusters** - Multiple modals rendered in one component
## Splitting Strategies
### Strategy 1: Section-Based Splitting
Identify visual sections and extract each as a component.
```typescript
// ❌ Before: Monolithic component (500+ lines)
const ConfigurationPage = () => {
return (
<div>
{/* Header Section - 50 lines */}
<div className="header">
<h1>{t('configuration.title')}</h1>
<div className="actions">
{isAdvancedMode && <Badge>Advanced</Badge>}
<ModelParameterModal ... />
<AppPublisher ... />
</div>
</div>
{/* Config Section - 200 lines */}
<div className="config">
<Config />
</div>
{/* Debug Section - 150 lines */}
<div className="debug">
<Debug ... />
</div>
{/* Modals Section - 100 lines */}
{showSelectDataSet && <SelectDataSet ... />}
{showHistoryModal && <EditHistoryModal ... />}
{showUseGPT4Confirm && <Confirm ... />}
</div>
)
}
// ✅ After: Split into focused components
// configuration/
// ├── index.tsx (orchestration)
// ├── configuration-header.tsx
// ├── configuration-content.tsx
// ├── configuration-debug.tsx
// └── configuration-modals.tsx
// configuration-header.tsx
interface ConfigurationHeaderProps {
isAdvancedMode: boolean
onPublish: () => void
}
const ConfigurationHeader: FC<ConfigurationHeaderProps> = ({
isAdvancedMode,
onPublish,
}) => {
const { t } = useTranslation()
return (
<div className="header">
<h1>{t('configuration.title')}</h1>
<div className="actions">
{isAdvancedMode && <Badge>Advanced</Badge>}
<ModelParameterModal ... />
<AppPublisher onPublish={onPublish} />
</div>
</div>
)
}
// index.tsx (orchestration only)
const ConfigurationPage = () => {
const { modelConfig, setModelConfig } = useModelConfig()
const { activeModal, openModal, closeModal } = useModalState()
return (
<div>
<ConfigurationHeader
isAdvancedMode={isAdvancedMode}
onPublish={handlePublish}
/>
<ConfigurationContent
modelConfig={modelConfig}
onConfigChange={setModelConfig}
/>
{!isMobile && (
<ConfigurationDebug
inputs={inputs}
onSetting={handleSetting}
/>
)}
<ConfigurationModals
activeModal={activeModal}
onClose={closeModal}
/>
</div>
)
}
```
### Strategy 2: Conditional Block Extraction
Extract large conditional rendering blocks.
```typescript
// ❌ Before: Large conditional blocks
const AppInfo = () => {
return (
<div>
{expand ? (
<div className="expanded">
{/* 100 lines of expanded view */}
</div>
) : (
<div className="collapsed">
{/* 50 lines of collapsed view */}
</div>
)}
</div>
)
}
// ✅ After: Separate view components
const AppInfoExpanded: FC<AppInfoViewProps> = ({ appDetail, onAction }) => {
return (
<div className="expanded">
{/* Clean, focused expanded view */}
</div>
)
}
const AppInfoCollapsed: FC<AppInfoViewProps> = ({ appDetail, onAction }) => {
return (
<div className="collapsed">
{/* Clean, focused collapsed view */}
</div>
)
}
const AppInfo = () => {
return (
<div>
{expand
? <AppInfoExpanded appDetail={appDetail} onAction={handleAction} />
: <AppInfoCollapsed appDetail={appDetail} onAction={handleAction} />
}
</div>
)
}
```
### Strategy 3: Modal Extraction
Extract modals with their trigger logic.
```typescript
// ❌ Before: Multiple modals in one component
const AppInfo = () => {
const [showEdit, setShowEdit] = useState(false)
const [showDuplicate, setShowDuplicate] = useState(false)
const [showDelete, setShowDelete] = useState(false)
const [showSwitch, setShowSwitch] = useState(false)
const onEdit = async (data) => { /* 20 lines */ }
const onDuplicate = async (data) => { /* 20 lines */ }
const onDelete = async () => { /* 15 lines */ }
return (
<div>
{/* Main content */}
{showEdit && <EditModal onConfirm={onEdit} onClose={() => setShowEdit(false)} />}
{showDuplicate && <DuplicateModal onConfirm={onDuplicate} onClose={() => setShowDuplicate(false)} />}
{showDelete && <DeleteConfirm onConfirm={onDelete} onClose={() => setShowDelete(false)} />}
{showSwitch && <SwitchModal ... />}
</div>
)
}
// ✅ After: Modal manager component
// app-info-modals.tsx
type ModalType = 'edit' | 'duplicate' | 'delete' | 'switch' | null
interface AppInfoModalsProps {
appDetail: AppDetail
activeModal: ModalType
onClose: () => void
onSuccess: () => void
}
const AppInfoModals: FC<AppInfoModalsProps> = ({
appDetail,
activeModal,
onClose,
onSuccess,
}) => {
const handleEdit = async (data) => { /* logic */ }
const handleDuplicate = async (data) => { /* logic */ }
const handleDelete = async () => { /* logic */ }
return (
<>
{activeModal === 'edit' && (
<EditModal
appDetail={appDetail}
onConfirm={handleEdit}
onClose={onClose}
/>
)}
{activeModal === 'duplicate' && (
<DuplicateModal
appDetail={appDetail}
onConfirm={handleDuplicate}
onClose={onClose}
/>
)}
{activeModal === 'delete' && (
<DeleteConfirm
onConfirm={handleDelete}
onClose={onClose}
/>
)}
{activeModal === 'switch' && (
<SwitchModal
appDetail={appDetail}
onClose={onClose}
/>
)}
</>
)
}
// Parent component
const AppInfo = () => {
const { activeModal, openModal, closeModal } = useModalState()
return (
<div>
{/* Main content with openModal triggers */}
<Button onClick={() => openModal('edit')}>Edit</Button>
<AppInfoModals
appDetail={appDetail}
activeModal={activeModal}
onClose={closeModal}
onSuccess={handleSuccess}
/>
</div>
)
}
```
### Strategy 4: List Item Extraction
Extract repeated item rendering.
```typescript
// ❌ Before: Inline item rendering
const OperationsList = () => {
return (
<div>
{operations.map(op => (
<div key={op.id} className="operation-item">
<span className="icon">{op.icon}</span>
<span className="title">{op.title}</span>
<span className="description">{op.description}</span>
<button onClick={() => op.onClick()}>
{op.actionLabel}
</button>
{op.badge && <Badge>{op.badge}</Badge>}
{/* More complex rendering... */}
</div>
))}
</div>
)
}
// ✅ After: Extracted item component
interface OperationItemProps {
operation: Operation
onAction: (id: string) => void
}
const OperationItem: FC<OperationItemProps> = ({ operation, onAction }) => {
return (
<div className="operation-item">
<span className="icon">{operation.icon}</span>
<span className="title">{operation.title}</span>
<span className="description">{operation.description}</span>
<button onClick={() => onAction(operation.id)}>
{operation.actionLabel}
</button>
{operation.badge && <Badge>{operation.badge}</Badge>}
</div>
)
}
// Maps operations to OperationItem rows. A single memoised handleAction
// callback is shared by every item instead of creating one closure per row.
const OperationsList = () => {
  const handleAction = useCallback((id: string) => {
    operations.find(o => o.id === id)?.onClick()
  }, [operations])
  return (
    <div>
      {operations.map(op => (
        <OperationItem
          key={op.id}
          operation={op}
          onAction={handleAction}
        />
      ))}
    </div>
  )
}
```
## Directory Structure Patterns
### Pattern A: Flat Structure (Simple Components)
For components with 2-3 sub-components:
```
component-name/
├── index.tsx # Main component
├── sub-component-a.tsx
├── sub-component-b.tsx
└── types.ts # Shared types
```
### Pattern B: Nested Structure (Complex Components)
For components with many sub-components:
```
component-name/
├── index.tsx # Main orchestration
├── types.ts # Shared types
├── hooks/
│ ├── use-feature-a.ts
│ └── use-feature-b.ts
├── components/
│ ├── header/
│ │ └── index.tsx
│ ├── content/
│ │ └── index.tsx
│ └── modals/
│ └── index.tsx
└── utils/
└── helpers.ts
```
### Pattern C: Feature-Based Structure (Dify Standard)
Following Dify's existing patterns:
```
configuration/
├── index.tsx # Main page component
├── base/ # Base/shared components
│ ├── feature-panel/
│ ├── group-name/
│ └── operation-btn/
├── config/ # Config section
│ ├── index.tsx
│ ├── agent/
│ └── automatic/
├── dataset-config/ # Dataset section
│ ├── index.tsx
│ ├── card-item/
│ └── params-config/
├── debug/ # Debug section
│ ├── index.tsx
│ └── hooks.tsx
└── hooks/ # Shared hooks
└── use-advanced-prompt-config.ts
```
## Props Design
### Minimal Props Principle
Pass only what's needed:
```typescript
// ❌ Bad: Passing entire objects when only some fields needed
<ConfigHeader appDetail={appDetail} modelConfig={modelConfig} />
// ✅ Good: Destructure to minimum required
<ConfigHeader
appName={appDetail.name}
isAdvancedMode={modelConfig.isAdvanced}
onPublish={handlePublish}
/>
```
### Callback Props Pattern
Use callbacks for child-to-parent communication:
```typescript
// Parent
const Parent = () => {
const [value, setValue] = useState('')
return (
<Child
value={value}
onChange={setValue}
onSubmit={handleSubmit}
/>
)
}
// Child
interface ChildProps {
value: string
onChange: (value: string) => void
onSubmit: () => void
}
const Child: FC<ChildProps> = ({ value, onChange, onSubmit }) => {
return (
<div>
<input value={value} onChange={e => onChange(e.target.value)} />
<button onClick={onSubmit}>Submit</button>
</div>
)
}
```
### Render Props for Flexibility
When sub-components need parent context:
```typescript
interface ListProps<T> {
items: T[]
renderItem: (item: T, index: number) => React.ReactNode
renderEmpty?: () => React.ReactNode
}
// Generic list renderer. When there are no items AND an empty-state renderer
// was supplied, that renderer wins; otherwise the (possibly empty) item list
// is rendered.
function List<T>({ items, renderItem, renderEmpty }: ListProps<T>) {
  const isEmpty = items.length === 0
  if (isEmpty && renderEmpty)
    return <>{renderEmpty()}</>
  return (
    <div>
      {items.map((item, index) => renderItem(item, index))}
    </div>
  )
}
// Usage
<List
items={operations}
renderItem={(op, i) => <OperationItem key={i} operation={op} />}
renderEmpty={() => <EmptyState message="No operations" />}
/>
```

View File

@ -0,0 +1,317 @@
# Hook Extraction Patterns
This document provides detailed guidance on extracting custom hooks from complex components in Dify.
## When to Extract Hooks
Extract a custom hook when you identify:
1. **Coupled state groups** - Multiple `useState` hooks that are always used together
2. **Complex effects** - `useEffect` with multiple dependencies or cleanup logic
3. **Business logic** - Data transformations, validations, or calculations
4. **Reusable patterns** - Logic that appears in multiple components
## Extraction Process
### Step 1: Identify State Groups
Look for state variables that are logically related:
```typescript
// ❌ These belong together - extract to hook
const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
const [completionParams, setCompletionParams] = useState<FormValue>({})
const [modelModeType, setModelModeType] = useState<ModelModeType>(...)
// These are model-related state that should be in useModelConfig()
```
### Step 2: Identify Related Effects
Find effects that modify the grouped state:
```typescript
// ❌ These effects belong with the state above
useEffect(() => {
if (hasFetchedDetail && !modelModeType) {
const mode = currModel?.model_properties.mode
if (mode) {
const newModelConfig = produce(modelConfig, (draft) => {
draft.mode = mode
})
setModelConfig(newModelConfig)
}
}
}, [textGenerationModelList, hasFetchedDetail, modelModeType, currModel])
```
### Step 3: Create the Hook
```typescript
// hooks/use-model-config.ts
import type { FormValue } from '@/app/components/header/account-setting/model-provider-page/declarations'
import type { ModelConfig } from '@/models/debug'
import { produce } from 'immer'
import { useEffect, useState } from 'react'
import { ModelModeType } from '@/types/app'
interface UseModelConfigParams {
initialConfig?: Partial<ModelConfig>
currModel?: { model_properties?: { mode?: ModelModeType } }
hasFetchedDetail: boolean
}
interface UseModelConfigReturn {
modelConfig: ModelConfig
setModelConfig: (config: ModelConfig) => void
completionParams: FormValue
setCompletionParams: (params: FormValue) => void
modelModeType: ModelModeType
}
// Groups model-related state (config, completion params, derived mode) into
// one hook so the consuming component can focus on UI.
// Params: initialConfig overrides defaults; currModel supplies the provider's
// mode for legacy app data; hasFetchedDetail gates the backfill effect.
export const useModelConfig = ({
  initialConfig,
  currModel,
  hasFetchedDetail,
}: UseModelConfigParams): UseModelConfigReturn => {
  const [modelConfig, setModelConfig] = useState<ModelConfig>({
    provider: 'langgenius/openai/openai',
    model_id: 'gpt-3.5-turbo',
    mode: ModelModeType.unset,
    // ... default values
    ...initialConfig,
  })
  const [completionParams, setCompletionParams] = useState<FormValue>({})
  // Derived, not duplicated state: always reflects the current config.
  const modelModeType = modelConfig.mode
  // Fill old app data missing model mode.
  // FIX: use the functional updater so this effect never reads a stale
  // `modelConfig` from its closure — the original passed `modelConfig`
  // directly to `produce` while omitting it from the deps array.
  useEffect(() => {
    if (hasFetchedDetail && !modelModeType) {
      const mode = currModel?.model_properties?.mode
      if (mode) {
        setModelConfig(prev => produce(prev, (draft) => {
          draft.mode = mode
        }))
      }
    }
  }, [hasFetchedDetail, modelModeType, currModel])
  return {
    modelConfig,
    setModelConfig,
    completionParams,
    setCompletionParams,
    modelModeType,
  }
}
```
### Step 4: Update Component
```typescript
// Before: 50+ lines of state management
const Configuration: FC = () => {
const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
// ... lots of related state and effects
}
// After: Clean component
const Configuration: FC = () => {
const {
modelConfig,
setModelConfig,
completionParams,
setCompletionParams,
modelModeType,
} = useModelConfig({
currModel,
hasFetchedDetail,
})
// Component now focuses on UI
}
```
## Naming Conventions
### Hook Names
- Use `use` prefix: `useModelConfig`, `useDatasetConfig`
- Be specific: `useAdvancedPromptConfig` not `usePrompt`
- Include domain: `useWorkflowVariables`, `useMCPServer`
### File Names
- Kebab-case: `use-model-config.ts`
- Place in `hooks/` subdirectory when multiple hooks exist
- Place alongside component for single-use hooks
### Return Type Names
- Suffix with `Return`: `UseModelConfigReturn`
- Suffix params with `Params`: `UseModelConfigParams`
## Common Hook Patterns in Dify
### 1. Data Fetching Hook (React Query)
```typescript
// Pattern: Use @tanstack/react-query for data fetching
import { useQuery, useQueryClient } from '@tanstack/react-query'
import { get } from '@/service/base'
import { useInvalid } from '@/service/use-base'
const NAME_SPACE = 'appConfig'
// Query keys for cache management
export const appConfigQueryKeys = {
detail: (appId: string) => [NAME_SPACE, 'detail', appId] as const,
}
// Main data hook
// Fetches the app detail and exposes only the model_config slice via `select`
// (falls back to null when the response has no model_config).
export const useAppConfig = (appId: string) => {
  return useQuery({
    enabled: !!appId, // skip the request until an app id is available
    queryKey: appConfigQueryKeys.detail(appId),
    queryFn: () => get<AppDetailResponse>(`/apps/${appId}`),
    select: data => data?.model_config || null,
  })
}
// Invalidation hook for refreshing data
// Invalidates every query under NAME_SPACE, so all appConfig queries refetch.
export const useInvalidAppConfig = () => {
  return useInvalid([NAME_SPACE])
}
// Usage in component
const Component = () => {
const { data: config, isLoading, error, refetch } = useAppConfig(appId)
const invalidAppConfig = useInvalidAppConfig()
const handleRefresh = () => {
invalidAppConfig() // Invalidates cache and triggers refetch
}
return <div>...</div>
}
```
### 2. Form State Hook
```typescript
// Pattern: Form state + validation + submission
export const useConfigForm = (initialValues: ConfigFormValues) => {
  const [values, setValues] = useState(initialValues)
  const [errors, setErrors] = useState<Record<string, string>>({})
  const [isSubmitting, setIsSubmitting] = useState(false)
  // Validate current values; returns true when the form has no errors.
  const validate = useCallback(() => {
    const newErrors: Record<string, string> = {}
    if (!values.name) newErrors.name = 'Name is required'
    setErrors(newErrors)
    return Object.keys(newErrors).length === 0
  }, [values])
  // FIX: `unknown` instead of `any` — callers can still pass any value
  // (backward-compatible), but consumers must narrow before using it.
  const handleChange = useCallback((field: string, value: unknown) => {
    setValues(prev => ({ ...prev, [field]: value }))
  }, [])
  // Runs validation, then submits; isSubmitting is always cleared, even when
  // onSubmit throws, thanks to the finally block.
  const handleSubmit = useCallback(async (onSubmit: (values: ConfigFormValues) => Promise<void>) => {
    if (!validate()) return
    setIsSubmitting(true)
    try {
      await onSubmit(values)
    } finally {
      setIsSubmitting(false)
    }
  }, [values, validate])
  return { values, errors, isSubmitting, handleChange, handleSubmit }
}
```
### 3. Modal State Hook
```typescript
// Pattern: Multiple modal management
type ModalType = 'edit' | 'delete' | 'duplicate' | null
// Manages which of several mutually-exclusive modals is open, plus optional
// payload data for the active modal.
export const useModalState = () => {
  const [activeModal, setActiveModal] = useState<ModalType>(null)
  // FIX: `unknown` instead of `any` — consumers must narrow modalData
  // before use; callers of openModal are unaffected (backward-compatible).
  const [modalData, setModalData] = useState<unknown>(null)
  const openModal = useCallback((type: ModalType, data?: unknown) => {
    setActiveModal(type)
    setModalData(data)
  }, [])
  const closeModal = useCallback(() => {
    setActiveModal(null)
    setModalData(null)
  }, [])
  return {
    activeModal,
    modalData,
    openModal,
    closeModal,
    // Convenience predicate; memoised so it is referentially stable per modal state.
    isOpen: useCallback((type: ModalType) => activeModal === type, [activeModal]),
  }
}
```
### 4. Toggle/Boolean Hook
```typescript
// Pattern: Boolean state with convenience methods
// Boolean state with stable convenience setters. All callbacks are created
// once (empty deps), so they are safe to pass down as props or effect deps.
export const useToggle = (initialValue = false) => {
  const [value, setValue] = useState(initialValue)
  const setTrue = useCallback(() => { setValue(true) }, [])
  const setFalse = useCallback(() => { setValue(false) }, [])
  const toggle = useCallback(() => { setValue(current => !current) }, [])
  return [value, { toggle, setTrue, setFalse, set: setValue }] as const
}
// Usage
const [isExpanded, { toggle, setTrue: expand, setFalse: collapse }] = useToggle()
```
## Testing Extracted Hooks
After extraction, test hooks in isolation:
```typescript
// use-model-config.spec.ts
import { renderHook, act } from '@testing-library/react'
import { useModelConfig } from './use-model-config'
describe('useModelConfig', () => {
it('should initialize with default values', () => {
const { result } = renderHook(() => useModelConfig({
hasFetchedDetail: false,
}))
expect(result.current.modelConfig.provider).toBe('langgenius/openai/openai')
expect(result.current.modelModeType).toBe(ModelModeType.unset)
})
it('should update model config', () => {
const { result } = renderHook(() => useModelConfig({
hasFetchedDetail: true,
}))
act(() => {
result.current.setModelConfig({
...result.current.modelConfig,
model_id: 'gpt-4',
})
})
expect(result.current.modelConfig.model_id).toBe('gpt-4')
})
})
```

View File

@ -1,13 +1,13 @@
--- ---
name: Dify Frontend Testing name: frontend-testing
description: Generate Jest + React Testing Library tests for Dify frontend components, hooks, and utilities. Triggers on testing, spec files, coverage, Jest, RTL, unit tests, integration tests, or write/review test requests. description: Generate Vitest + React Testing Library tests for Dify frontend components, hooks, and utilities. Triggers on testing, spec files, coverage, Vitest, RTL, unit tests, integration tests, or write/review test requests.
--- ---
# Dify Frontend Testing Skill # Dify Frontend Testing Skill
This skill enables Claude to generate high-quality, comprehensive frontend tests for the Dify project following established conventions and best practices. This skill enables Claude to generate high-quality, comprehensive frontend tests for the Dify project following established conventions and best practices.
> **⚠️ Authoritative Source**: This skill is derived from `web/testing/testing.md`. When in doubt, always refer to that document as the canonical specification. > **⚠️ Authoritative Source**: This skill is derived from `web/testing/testing.md`. Use Vitest mock/timer APIs (`vi.*`).
## When to Apply This Skill ## When to Apply This Skill
@ -15,7 +15,7 @@ Apply this skill when the user:
- Asks to **write tests** for a component, hook, or utility - Asks to **write tests** for a component, hook, or utility
- Asks to **review existing tests** for completeness - Asks to **review existing tests** for completeness
- Mentions **Jest**, **React Testing Library**, **RTL**, or **spec files** - Mentions **Vitest**, **React Testing Library**, **RTL**, or **spec files**
- Requests **test coverage** improvement - Requests **test coverage** improvement
- Uses `pnpm analyze-component` output as context - Uses `pnpm analyze-component` output as context
- Mentions **testing**, **unit tests**, or **integration tests** for frontend code - Mentions **testing**, **unit tests**, or **integration tests** for frontend code
@ -33,9 +33,9 @@ Apply this skill when the user:
| Tool | Version | Purpose | | Tool | Version | Purpose |
|------|---------|---------| |------|---------|---------|
| Jest | 29.7 | Test runner | | Vitest | 4.0.16 | Test runner |
| React Testing Library | 16.0 | Component testing | | React Testing Library | 16.0 | Component testing |
| happy-dom | - | Test environment | | jsdom | - | Test environment |
| nock | 14.0 | HTTP mocking | | nock | 14.0 | HTTP mocking |
| TypeScript | 5.x | Type safety | | TypeScript | 5.x | Type safety |
@ -46,13 +46,13 @@ Apply this skill when the user:
pnpm test pnpm test
# Watch mode # Watch mode
pnpm test -- --watch pnpm test:watch
# Run specific file # Run specific file
pnpm test -- path/to/file.spec.tsx pnpm test path/to/file.spec.tsx
# Generate coverage report # Generate coverage report
pnpm test -- --coverage pnpm test:coverage
# Analyze component complexity # Analyze component complexity
pnpm analyze-component <path> pnpm analyze-component <path>
@ -77,9 +77,9 @@ import Component from './index'
// import { ChildComponent } from './child-component' // import { ChildComponent } from './child-component'
// ✅ Mock external dependencies only // ✅ Mock external dependencies only
jest.mock('@/service/api') vi.mock('@/service/api')
jest.mock('next/navigation', () => ({ vi.mock('next/navigation', () => ({
useRouter: () => ({ push: jest.fn() }), useRouter: () => ({ push: vi.fn() }),
usePathname: () => '/test', usePathname: () => '/test',
})) }))
@ -88,7 +88,7 @@ let mockSharedState = false
describe('ComponentName', () => { describe('ComponentName', () => {
beforeEach(() => { beforeEach(() => {
jest.clearAllMocks() // ✅ Reset mocks BEFORE each test vi.clearAllMocks() // ✅ Reset mocks BEFORE each test
mockSharedState = false // ✅ Reset shared state mockSharedState = false // ✅ Reset shared state
}) })
@ -117,7 +117,7 @@ describe('ComponentName', () => {
// User Interactions // User Interactions
describe('User Interactions', () => { describe('User Interactions', () => {
it('should handle click events', () => { it('should handle click events', () => {
const handleClick = jest.fn() const handleClick = vi.fn()
render(<Component onClick={handleClick} />) render(<Component onClick={handleClick} />)
fireEvent.click(screen.getByRole('button')) fireEvent.click(screen.getByRole('button'))
@ -155,7 +155,7 @@ describe('ComponentName', () => {
For each file: For each file:
┌────────────────────────────────────────┐ ┌────────────────────────────────────────┐
│ 1. Write test │ │ 1. Write test │
│ 2. Run: pnpm test -- <file>.spec.tsx │ │ 2. Run: pnpm test <file>.spec.tsx
│ 3. PASS? → Mark complete, next file │ │ 3. PASS? → Mark complete, next file │
│ FAIL? → Fix first, then continue │ │ FAIL? → Fix first, then continue │
└────────────────────────────────────────┘ └────────────────────────────────────────┘
@ -178,7 +178,7 @@ Process in this order for multi-file testing:
- **500+ lines**: Consider splitting before testing - **500+ lines**: Consider splitting before testing
- **Many dependencies**: Extract logic into hooks first - **Many dependencies**: Extract logic into hooks first
> 📖 See `guides/workflow.md` for complete workflow details and todo list format. > 📖 See `references/workflow.md` for complete workflow details and todo list format.
## Testing Strategy ## Testing Strategy
@ -289,17 +289,18 @@ For each test file generated, aim for:
- ✅ **>95%** branch coverage - ✅ **>95%** branch coverage
- ✅ **>95%** line coverage - ✅ **>95%** line coverage
> **Note**: For multi-file directories, process one file at a time with full coverage each. See `guides/workflow.md`. > **Note**: For multi-file directories, process one file at a time with full coverage each. See `references/workflow.md`.
## Detailed Guides ## Detailed Guides
For more detailed information, refer to: For more detailed information, refer to:
- `guides/workflow.md` - **Incremental testing workflow** (MUST READ for multi-file testing) - `references/workflow.md` - **Incremental testing workflow** (MUST READ for multi-file testing)
- `guides/mocking.md` - Mock patterns and best practices - `references/mocking.md` - Mock patterns and best practices
- `guides/async-testing.md` - Async operations and API calls - `references/async-testing.md` - Async operations and API calls
- `guides/domain-components.md` - Workflow, Dataset, Configuration testing - `references/domain-components.md` - Workflow, Dataset, Configuration testing
- `guides/common-patterns.md` - Frequently used testing patterns - `references/common-patterns.md` - Frequently used testing patterns
- `references/checklist.md` - Test generation checklist and validation steps
## Authoritative References ## Authoritative References
@ -315,7 +316,7 @@ For more detailed information, refer to:
### Project Configuration ### Project Configuration
- `web/jest.config.ts` - Jest configuration - `web/vitest.config.ts` - Vitest configuration
- `web/jest.setup.ts` - Test environment setup - `web/vitest.setup.ts` - Test environment setup
- `web/testing/analyze-component.js` - Component analysis tool - `web/scripts/analyze-component.js` - Component analysis tool
- `web/__mocks__/react-i18next.ts` - Shared i18n mock (auto-loaded by Jest, no explicit mock needed; override locally only for custom translations) - Modules are not mocked automatically. Global mocks live in `web/vitest.setup.ts` (for example `react-i18next`, `next/image`); mock other modules like `ky` or `mime` locally in test files.

View File

@ -23,14 +23,14 @@ import userEvent from '@testing-library/user-event'
// ============================================================================ // ============================================================================
// Mocks // Mocks
// ============================================================================ // ============================================================================
// WHY: Mocks must be hoisted to top of file (Jest requirement). // WHY: Mocks must be hoisted to top of file (Vitest requirement).
// They run BEFORE imports, so keep them before component imports. // They run BEFORE imports, so keep them before component imports.
// i18n (automatically mocked) // i18n (automatically mocked)
// WHY: Shared mock at web/__mocks__/react-i18next.ts is auto-loaded by Jest // WHY: Global mock in web/vitest.setup.ts is auto-loaded by Vitest setup
// No explicit mock needed - it returns translation keys as-is // No explicit mock needed - it returns translation keys as-is
// Override only if custom translations are required: // Override only if custom translations are required:
// jest.mock('react-i18next', () => ({ // vi.mock('react-i18next', () => ({
// useTranslation: () => ({ // useTranslation: () => ({
// t: (key: string) => { // t: (key: string) => {
// const customTranslations: Record<string, string> = { // const customTranslations: Record<string, string> = {
@ -43,17 +43,17 @@ import userEvent from '@testing-library/user-event'
// Router (if component uses useRouter, usePathname, useSearchParams) // Router (if component uses useRouter, usePathname, useSearchParams)
// WHY: Isolates tests from Next.js routing, enables testing navigation behavior // WHY: Isolates tests from Next.js routing, enables testing navigation behavior
// const mockPush = jest.fn() // const mockPush = vi.fn()
// jest.mock('next/navigation', () => ({ // vi.mock('next/navigation', () => ({
// useRouter: () => ({ push: mockPush }), // useRouter: () => ({ push: mockPush }),
// usePathname: () => '/test-path', // usePathname: () => '/test-path',
// })) // }))
// API services (if component fetches data) // API services (if component fetches data)
// WHY: Prevents real network calls, enables testing all states (loading/success/error) // WHY: Prevents real network calls, enables testing all states (loading/success/error)
// jest.mock('@/service/api') // vi.mock('@/service/api')
// import * as api from '@/service/api' // import * as api from '@/service/api'
// const mockedApi = api as jest.Mocked<typeof api> // const mockedApi = vi.mocked(api)
// Shared mock state (for portal/dropdown components) // Shared mock state (for portal/dropdown components)
// WHY: Portal components like PortalToFollowElem need shared state between // WHY: Portal components like PortalToFollowElem need shared state between
@ -98,7 +98,7 @@ describe('ComponentName', () => {
// - Prevents mock call history from leaking between tests // - Prevents mock call history from leaking between tests
// - MUST be beforeEach (not afterEach) to reset BEFORE assertions like toHaveBeenCalledTimes // - MUST be beforeEach (not afterEach) to reset BEFORE assertions like toHaveBeenCalledTimes
beforeEach(() => { beforeEach(() => {
jest.clearAllMocks() vi.clearAllMocks()
// Reset shared mock state if used (CRITICAL for portal/dropdown tests) // Reset shared mock state if used (CRITICAL for portal/dropdown tests)
// mockOpenState = false // mockOpenState = false
}) })
@ -155,7 +155,7 @@ describe('ComponentName', () => {
// - userEvent simulates real user behavior (focus, hover, then click) // - userEvent simulates real user behavior (focus, hover, then click)
// - fireEvent is lower-level, doesn't trigger all browser events // - fireEvent is lower-level, doesn't trigger all browser events
// const user = userEvent.setup() // const user = userEvent.setup()
// const handleClick = jest.fn() // const handleClick = vi.fn()
// render(<ComponentName onClick={handleClick} />) // render(<ComponentName onClick={handleClick} />)
// //
// await user.click(screen.getByRole('button')) // await user.click(screen.getByRole('button'))
@ -165,7 +165,7 @@ describe('ComponentName', () => {
it('should call onChange when value changes', async () => { it('should call onChange when value changes', async () => {
// const user = userEvent.setup() // const user = userEvent.setup()
// const handleChange = jest.fn() // const handleChange = vi.fn()
// render(<ComponentName onChange={handleChange} />) // render(<ComponentName onChange={handleChange} />)
// //
// await user.type(screen.getByRole('textbox'), 'new value') // await user.type(screen.getByRole('textbox'), 'new value')
@ -198,7 +198,7 @@ describe('ComponentName', () => {
}) })
// -------------------------------------------------------------------------- // --------------------------------------------------------------------------
// Async Operations (if component fetches data - useSWR, useQuery, fetch) // Async Operations (if component fetches data - useQuery, fetch)
// -------------------------------------------------------------------------- // --------------------------------------------------------------------------
// WHY: Async operations have 3 states users experience: loading, success, error // WHY: Async operations have 3 states users experience: loading, success, error
describe('Async Operations', () => { describe('Async Operations', () => {

View File

@ -15,9 +15,9 @@ import { renderHook, act, waitFor } from '@testing-library/react'
// ============================================================================ // ============================================================================
// API services (if hook fetches data) // API services (if hook fetches data)
// jest.mock('@/service/api') // vi.mock('@/service/api')
// import * as api from '@/service/api' // import * as api from '@/service/api'
// const mockedApi = api as jest.Mocked<typeof api> // const mockedApi = vi.mocked(api)
// ============================================================================ // ============================================================================
// Test Helpers // Test Helpers
@ -38,7 +38,7 @@ import { renderHook, act, waitFor } from '@testing-library/react'
describe('useHookName', () => { describe('useHookName', () => {
beforeEach(() => { beforeEach(() => {
jest.clearAllMocks() vi.clearAllMocks()
}) })
// -------------------------------------------------------------------------- // --------------------------------------------------------------------------
@ -145,7 +145,7 @@ describe('useHookName', () => {
// -------------------------------------------------------------------------- // --------------------------------------------------------------------------
describe('Side Effects', () => { describe('Side Effects', () => {
it('should call callback when value changes', () => { it('should call callback when value changes', () => {
// const callback = jest.fn() // const callback = vi.fn()
// const { result } = renderHook(() => useHookName({ onChange: callback })) // const { result } = renderHook(() => useHookName({ onChange: callback }))
// //
// act(() => { // act(() => {
@ -156,9 +156,9 @@ describe('useHookName', () => {
}) })
it('should cleanup on unmount', () => { it('should cleanup on unmount', () => {
// const cleanup = jest.fn() // const cleanup = vi.fn()
// jest.spyOn(window, 'addEventListener') // vi.spyOn(window, 'addEventListener')
// jest.spyOn(window, 'removeEventListener') // vi.spyOn(window, 'removeEventListener')
// //
// const { unmount } = renderHook(() => useHookName()) // const { unmount } = renderHook(() => useHookName())
// //

View File

@ -49,7 +49,7 @@ import userEvent from '@testing-library/user-event'
it('should submit form', async () => { it('should submit form', async () => {
const user = userEvent.setup() const user = userEvent.setup()
const onSubmit = jest.fn() const onSubmit = vi.fn()
render(<Form onSubmit={onSubmit} />) render(<Form onSubmit={onSubmit} />)
@ -77,15 +77,15 @@ it('should submit form', async () => {
```typescript ```typescript
describe('Debounced Search', () => { describe('Debounced Search', () => {
beforeEach(() => { beforeEach(() => {
jest.useFakeTimers() vi.useFakeTimers()
}) })
afterEach(() => { afterEach(() => {
jest.useRealTimers() vi.useRealTimers()
}) })
it('should debounce search input', async () => { it('should debounce search input', async () => {
const onSearch = jest.fn() const onSearch = vi.fn()
render(<SearchInput onSearch={onSearch} debounceMs={300} />) render(<SearchInput onSearch={onSearch} debounceMs={300} />)
// Type in the input // Type in the input
@ -95,7 +95,7 @@ describe('Debounced Search', () => {
expect(onSearch).not.toHaveBeenCalled() expect(onSearch).not.toHaveBeenCalled()
// Advance timers // Advance timers
jest.advanceTimersByTime(300) vi.advanceTimersByTime(300)
// Now search is called // Now search is called
expect(onSearch).toHaveBeenCalledWith('query') expect(onSearch).toHaveBeenCalledWith('query')
@ -107,8 +107,8 @@ describe('Debounced Search', () => {
```typescript ```typescript
it('should retry on failure', async () => { it('should retry on failure', async () => {
jest.useFakeTimers() vi.useFakeTimers()
const fetchData = jest.fn() const fetchData = vi.fn()
.mockRejectedValueOnce(new Error('Network error')) .mockRejectedValueOnce(new Error('Network error'))
.mockResolvedValueOnce({ data: 'success' }) .mockResolvedValueOnce({ data: 'success' })
@ -120,7 +120,7 @@ it('should retry on failure', async () => {
}) })
// Advance timer for retry // Advance timer for retry
jest.advanceTimersByTime(1000) vi.advanceTimersByTime(1000)
// Second call succeeds // Second call succeeds
await waitFor(() => { await waitFor(() => {
@ -128,7 +128,7 @@ it('should retry on failure', async () => {
expect(screen.getByText('success')).toBeInTheDocument() expect(screen.getByText('success')).toBeInTheDocument()
}) })
jest.useRealTimers() vi.useRealTimers()
}) })
``` ```
@ -136,19 +136,19 @@ it('should retry on failure', async () => {
```typescript ```typescript
// Run all pending timers // Run all pending timers
jest.runAllTimers() vi.runAllTimers()
// Run only pending timers (not new ones created during execution) // Run only pending timers (not new ones created during execution)
jest.runOnlyPendingTimers() vi.runOnlyPendingTimers()
// Advance by specific time // Advance by specific time
jest.advanceTimersByTime(1000) vi.advanceTimersByTime(1000)
// Get current fake time // Get current fake time
jest.now() Date.now()
// Clear all timers // Clear all timers
jest.clearAllTimers() vi.clearAllTimers()
``` ```
## API Testing Patterns ## API Testing Patterns
@ -158,7 +158,7 @@ jest.clearAllTimers()
```typescript ```typescript
describe('DataFetcher', () => { describe('DataFetcher', () => {
beforeEach(() => { beforeEach(() => {
jest.clearAllMocks() vi.clearAllMocks()
}) })
it('should show loading state', () => { it('should show loading state', () => {
@ -241,7 +241,7 @@ it('should submit form and show success', async () => {
```typescript ```typescript
it('should fetch data on mount', async () => { it('should fetch data on mount', async () => {
const fetchData = jest.fn().mockResolvedValue({ data: 'test' }) const fetchData = vi.fn().mockResolvedValue({ data: 'test' })
render(<ComponentWithEffect fetchData={fetchData} />) render(<ComponentWithEffect fetchData={fetchData} />)
@ -255,7 +255,7 @@ it('should fetch data on mount', async () => {
```typescript ```typescript
it('should refetch when id changes', async () => { it('should refetch when id changes', async () => {
const fetchData = jest.fn().mockResolvedValue({ data: 'test' }) const fetchData = vi.fn().mockResolvedValue({ data: 'test' })
const { rerender } = render(<ComponentWithEffect id="1" fetchData={fetchData} />) const { rerender } = render(<ComponentWithEffect id="1" fetchData={fetchData} />)
@ -276,8 +276,8 @@ it('should refetch when id changes', async () => {
```typescript ```typescript
it('should cleanup subscription on unmount', () => { it('should cleanup subscription on unmount', () => {
const subscribe = jest.fn() const subscribe = vi.fn()
const unsubscribe = jest.fn() const unsubscribe = vi.fn()
subscribe.mockReturnValue(unsubscribe) subscribe.mockReturnValue(unsubscribe)
const { unmount } = render(<SubscriptionComponent subscribe={subscribe} />) const { unmount } = render(<SubscriptionComponent subscribe={subscribe} />)
@ -332,14 +332,14 @@ expect(description).toBeInTheDocument()
```typescript ```typescript
// Bad - fake timers don't work well with real Promises // Bad - fake timers don't work well with real Promises
jest.useFakeTimers() vi.useFakeTimers()
await waitFor(() => { await waitFor(() => {
expect(screen.getByText('Data')).toBeInTheDocument() expect(screen.getByText('Data')).toBeInTheDocument()
}) // May timeout! }) // May timeout!
// Good - use runAllTimers or advanceTimersByTime // Good - use runAllTimers or advanceTimersByTime
jest.useFakeTimers() vi.useFakeTimers()
render(<Component />) render(<Component />)
jest.runAllTimers() vi.runAllTimers()
expect(screen.getByText('Data')).toBeInTheDocument() expect(screen.getByText('Data')).toBeInTheDocument()
``` ```

View File

@ -74,9 +74,9 @@ Use this checklist when generating or reviewing tests for Dify frontend componen
### Mocks ### Mocks
- [ ] **DO NOT mock base components** (`@/app/components/base/*`) - [ ] **DO NOT mock base components** (`@/app/components/base/*`)
- [ ] `jest.clearAllMocks()` in `beforeEach` (not `afterEach`) - [ ] `vi.clearAllMocks()` in `beforeEach` (not `afterEach`)
- [ ] Shared mock state reset in `beforeEach` - [ ] Shared mock state reset in `beforeEach`
- [ ] i18n uses shared mock (auto-loaded); only override locally for custom translations - [ ] i18n uses global mock (auto-loaded in `web/vitest.setup.ts`); only override locally for custom translations
- [ ] Router mocks match actual Next.js API - [ ] Router mocks match actual Next.js API
- [ ] Mocks reflect actual component conditional behavior - [ ] Mocks reflect actual component conditional behavior
- [ ] Only mock: API services, complex context providers, third-party libs - [ ] Only mock: API services, complex context providers, third-party libs
@ -114,15 +114,15 @@ For the current file being tested:
**Run these checks after EACH test file, not just at the end:** **Run these checks after EACH test file, not just at the end:**
- [ ] Run `pnpm test -- path/to/file.spec.tsx` - **MUST PASS before next file** - [ ] Run `pnpm test path/to/file.spec.tsx` - **MUST PASS before next file**
- [ ] Fix any failures immediately - [ ] Fix any failures immediately
- [ ] Mark file as complete in todo list - [ ] Mark file as complete in todo list
- [ ] Only then proceed to next file - [ ] Only then proceed to next file
### After All Files Complete ### After All Files Complete
- [ ] Run full directory test: `pnpm test -- path/to/directory/` - [ ] Run full directory test: `pnpm test path/to/directory/`
- [ ] Check coverage report: `pnpm test -- --coverage` - [ ] Check coverage report: `pnpm test:coverage`
- [ ] Run `pnpm lint:fix` on all test files - [ ] Run `pnpm lint:fix` on all test files
- [ ] Run `pnpm type-check:tsgo` - [ ] Run `pnpm type-check:tsgo`
@ -132,10 +132,10 @@ For the current file being tested:
```typescript ```typescript
// ❌ Mock doesn't match actual behavior // ❌ Mock doesn't match actual behavior
jest.mock('./Component', () => () => <div>Mocked</div>) vi.mock('./Component', () => () => <div>Mocked</div>)
// ✅ Mock matches actual conditional logic // ✅ Mock matches actual conditional logic
jest.mock('./Component', () => ({ isOpen }: any) => vi.mock('./Component', () => ({ isOpen }: any) =>
isOpen ? <div>Content</div> : null isOpen ? <div>Content</div> : null
) )
``` ```
@ -145,7 +145,7 @@ jest.mock('./Component', () => ({ isOpen }: any) =>
```typescript ```typescript
// ❌ Shared state not reset // ❌ Shared state not reset
let mockState = false let mockState = false
jest.mock('./useHook', () => () => mockState) vi.mock('./useHook', () => () => mockState)
// ✅ Reset in beforeEach // ✅ Reset in beforeEach
beforeEach(() => { beforeEach(() => {
@ -186,16 +186,16 @@ Always test these scenarios:
```bash ```bash
# Run specific test # Run specific test
pnpm test -- path/to/file.spec.tsx pnpm test path/to/file.spec.tsx
# Run with coverage # Run with coverage
pnpm test -- --coverage path/to/file.spec.tsx pnpm test:coverage path/to/file.spec.tsx
# Watch mode # Watch mode
pnpm test -- --watch path/to/file.spec.tsx pnpm test:watch path/to/file.spec.tsx
# Update snapshots (use sparingly) # Update snapshots (use sparingly)
pnpm test -- -u path/to/file.spec.tsx pnpm test -u path/to/file.spec.tsx
# Analyze component # Analyze component
pnpm analyze-component path/to/component.tsx pnpm analyze-component path/to/component.tsx

View File

@ -126,7 +126,7 @@ describe('Counter', () => {
describe('ControlledInput', () => { describe('ControlledInput', () => {
it('should call onChange with new value', async () => { it('should call onChange with new value', async () => {
const user = userEvent.setup() const user = userEvent.setup()
const handleChange = jest.fn() const handleChange = vi.fn()
render(<ControlledInput value="" onChange={handleChange} />) render(<ControlledInput value="" onChange={handleChange} />)
@ -136,7 +136,7 @@ describe('ControlledInput', () => {
}) })
it('should display controlled value', () => { it('should display controlled value', () => {
render(<ControlledInput value="controlled" onChange={jest.fn()} />) render(<ControlledInput value="controlled" onChange={vi.fn()} />)
expect(screen.getByRole('textbox')).toHaveValue('controlled') expect(screen.getByRole('textbox')).toHaveValue('controlled')
}) })
@ -195,7 +195,7 @@ describe('ItemList', () => {
it('should handle item selection', async () => { it('should handle item selection', async () => {
const user = userEvent.setup() const user = userEvent.setup()
const onSelect = jest.fn() const onSelect = vi.fn()
render(<ItemList items={items} onSelect={onSelect} />) render(<ItemList items={items} onSelect={onSelect} />)
@ -217,20 +217,20 @@ describe('ItemList', () => {
```typescript ```typescript
describe('Modal', () => { describe('Modal', () => {
it('should not render when closed', () => { it('should not render when closed', () => {
render(<Modal isOpen={false} onClose={jest.fn()} />) render(<Modal isOpen={false} onClose={vi.fn()} />)
expect(screen.queryByRole('dialog')).not.toBeInTheDocument() expect(screen.queryByRole('dialog')).not.toBeInTheDocument()
}) })
it('should render when open', () => { it('should render when open', () => {
render(<Modal isOpen={true} onClose={jest.fn()} />) render(<Modal isOpen={true} onClose={vi.fn()} />)
expect(screen.getByRole('dialog')).toBeInTheDocument() expect(screen.getByRole('dialog')).toBeInTheDocument()
}) })
it('should call onClose when clicking overlay', async () => { it('should call onClose when clicking overlay', async () => {
const user = userEvent.setup() const user = userEvent.setup()
const handleClose = jest.fn() const handleClose = vi.fn()
render(<Modal isOpen={true} onClose={handleClose} />) render(<Modal isOpen={true} onClose={handleClose} />)
@ -241,7 +241,7 @@ describe('Modal', () => {
it('should call onClose when pressing Escape', async () => { it('should call onClose when pressing Escape', async () => {
const user = userEvent.setup() const user = userEvent.setup()
const handleClose = jest.fn() const handleClose = vi.fn()
render(<Modal isOpen={true} onClose={handleClose} />) render(<Modal isOpen={true} onClose={handleClose} />)
@ -254,7 +254,7 @@ describe('Modal', () => {
const user = userEvent.setup() const user = userEvent.setup()
render( render(
<Modal isOpen={true} onClose={jest.fn()}> <Modal isOpen={true} onClose={vi.fn()}>
<button>First</button> <button>First</button>
<button>Second</button> <button>Second</button>
</Modal> </Modal>
@ -279,7 +279,7 @@ describe('Modal', () => {
describe('LoginForm', () => { describe('LoginForm', () => {
it('should submit valid form', async () => { it('should submit valid form', async () => {
const user = userEvent.setup() const user = userEvent.setup()
const onSubmit = jest.fn() const onSubmit = vi.fn()
render(<LoginForm onSubmit={onSubmit} />) render(<LoginForm onSubmit={onSubmit} />)
@ -296,7 +296,7 @@ describe('LoginForm', () => {
it('should show validation errors', async () => { it('should show validation errors', async () => {
const user = userEvent.setup() const user = userEvent.setup()
render(<LoginForm onSubmit={jest.fn()} />) render(<LoginForm onSubmit={vi.fn()} />)
// Submit empty form // Submit empty form
await user.click(screen.getByRole('button', { name: /sign in/i })) await user.click(screen.getByRole('button', { name: /sign in/i }))
@ -308,7 +308,7 @@ describe('LoginForm', () => {
it('should validate email format', async () => { it('should validate email format', async () => {
const user = userEvent.setup() const user = userEvent.setup()
render(<LoginForm onSubmit={jest.fn()} />) render(<LoginForm onSubmit={vi.fn()} />)
await user.type(screen.getByLabelText(/email/i), 'invalid-email') await user.type(screen.getByLabelText(/email/i), 'invalid-email')
await user.click(screen.getByRole('button', { name: /sign in/i })) await user.click(screen.getByRole('button', { name: /sign in/i }))
@ -318,7 +318,7 @@ describe('LoginForm', () => {
it('should disable submit button while submitting', async () => { it('should disable submit button while submitting', async () => {
const user = userEvent.setup() const user = userEvent.setup()
const onSubmit = jest.fn(() => new Promise(resolve => setTimeout(resolve, 100))) const onSubmit = vi.fn(() => new Promise(resolve => setTimeout(resolve, 100)))
render(<LoginForm onSubmit={onSubmit} />) render(<LoginForm onSubmit={onSubmit} />)
@ -407,7 +407,7 @@ it('test 1', () => {
// Good - cleanup is automatic with RTL, but reset mocks // Good - cleanup is automatic with RTL, but reset mocks
beforeEach(() => { beforeEach(() => {
jest.clearAllMocks() vi.clearAllMocks()
}) })
``` ```

View File

@ -23,7 +23,7 @@ import NodeConfigPanel from './node-config-panel'
import { createMockNode, createMockWorkflowContext } from '@/__mocks__/workflow' import { createMockNode, createMockWorkflowContext } from '@/__mocks__/workflow'
// Mock workflow context // Mock workflow context
jest.mock('@/app/components/workflow/hooks', () => ({ vi.mock('@/app/components/workflow/hooks', () => ({
useWorkflowStore: () => mockWorkflowStore, useWorkflowStore: () => mockWorkflowStore,
useNodesInteractions: () => mockNodesInteractions, useNodesInteractions: () => mockNodesInteractions,
})) }))
@ -31,21 +31,21 @@ jest.mock('@/app/components/workflow/hooks', () => ({
let mockWorkflowStore = { let mockWorkflowStore = {
nodes: [], nodes: [],
edges: [], edges: [],
updateNode: jest.fn(), updateNode: vi.fn(),
} }
let mockNodesInteractions = { let mockNodesInteractions = {
handleNodeSelect: jest.fn(), handleNodeSelect: vi.fn(),
handleNodeDelete: jest.fn(), handleNodeDelete: vi.fn(),
} }
describe('NodeConfigPanel', () => { describe('NodeConfigPanel', () => {
beforeEach(() => { beforeEach(() => {
jest.clearAllMocks() vi.clearAllMocks()
mockWorkflowStore = { mockWorkflowStore = {
nodes: [], nodes: [],
edges: [], edges: [],
updateNode: jest.fn(), updateNode: vi.fn(),
} }
}) })
@ -161,23 +161,23 @@ import { render, screen, fireEvent, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event' import userEvent from '@testing-library/user-event'
import DocumentUploader from './document-uploader' import DocumentUploader from './document-uploader'
jest.mock('@/service/datasets', () => ({ vi.mock('@/service/datasets', () => ({
uploadDocument: jest.fn(), uploadDocument: vi.fn(),
parseDocument: jest.fn(), parseDocument: vi.fn(),
})) }))
import * as datasetService from '@/service/datasets' import * as datasetService from '@/service/datasets'
const mockedService = datasetService as jest.Mocked<typeof datasetService> const mockedService = vi.mocked(datasetService)
describe('DocumentUploader', () => { describe('DocumentUploader', () => {
beforeEach(() => { beforeEach(() => {
jest.clearAllMocks() vi.clearAllMocks()
}) })
describe('File Upload', () => { describe('File Upload', () => {
it('should accept valid file types', async () => { it('should accept valid file types', async () => {
const user = userEvent.setup() const user = userEvent.setup()
const onUpload = jest.fn() const onUpload = vi.fn()
mockedService.uploadDocument.mockResolvedValue({ id: 'doc-1' }) mockedService.uploadDocument.mockResolvedValue({ id: 'doc-1' })
render(<DocumentUploader onUpload={onUpload} />) render(<DocumentUploader onUpload={onUpload} />)
@ -326,14 +326,14 @@ describe('DocumentList', () => {
describe('Search & Filtering', () => { describe('Search & Filtering', () => {
it('should filter by search query', async () => { it('should filter by search query', async () => {
const user = userEvent.setup() const user = userEvent.setup()
jest.useFakeTimers() vi.useFakeTimers()
render(<DocumentList datasetId="ds-1" />) render(<DocumentList datasetId="ds-1" />)
await user.type(screen.getByPlaceholderText(/search/i), 'test query') await user.type(screen.getByPlaceholderText(/search/i), 'test query')
// Debounce // Debounce
jest.advanceTimersByTime(300) vi.advanceTimersByTime(300)
await waitFor(() => { await waitFor(() => {
expect(mockedService.getDocuments).toHaveBeenCalledWith( expect(mockedService.getDocuments).toHaveBeenCalledWith(
@ -342,7 +342,7 @@ describe('DocumentList', () => {
) )
}) })
jest.useRealTimers() vi.useRealTimers()
}) })
}) })
}) })
@ -367,13 +367,13 @@ import { render, screen, fireEvent, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event' import userEvent from '@testing-library/user-event'
import AppConfigForm from './app-config-form' import AppConfigForm from './app-config-form'
jest.mock('@/service/apps', () => ({ vi.mock('@/service/apps', () => ({
updateAppConfig: jest.fn(), updateAppConfig: vi.fn(),
getAppConfig: jest.fn(), getAppConfig: vi.fn(),
})) }))
import * as appService from '@/service/apps' import * as appService from '@/service/apps'
const mockedService = appService as jest.Mocked<typeof appService> const mockedService = vi.mocked(appService)
describe('AppConfigForm', () => { describe('AppConfigForm', () => {
const defaultConfig = { const defaultConfig = {
@ -384,7 +384,7 @@ describe('AppConfigForm', () => {
} }
beforeEach(() => { beforeEach(() => {
jest.clearAllMocks() vi.clearAllMocks()
mockedService.getAppConfig.mockResolvedValue(defaultConfig) mockedService.getAppConfig.mockResolvedValue(defaultConfig)
}) })

View File

@ -19,8 +19,8 @@
```typescript ```typescript
// ❌ WRONG: Don't mock base components // ❌ WRONG: Don't mock base components
jest.mock('@/app/components/base/loading', () => () => <div>Loading</div>) vi.mock('@/app/components/base/loading', () => () => <div>Loading</div>)
jest.mock('@/app/components/base/button', () => ({ children }: any) => <button>{children}</button>) vi.mock('@/app/components/base/button', () => ({ children }: any) => <button>{children}</button>)
// ✅ CORRECT: Import and use real base components // ✅ CORRECT: Import and use real base components
import Loading from '@/app/components/base/loading' import Loading from '@/app/components/base/loading'
@ -41,20 +41,23 @@ Only mock these categories:
| Location | Purpose | | Location | Purpose |
|----------|---------| |----------|---------|
| `web/__mocks__/` | Reusable mocks shared across multiple test files | | `web/vitest.setup.ts` | Global mocks shared by all tests (for example `react-i18next`, `next/image`) |
| Test file | Test-specific mocks, inline with `jest.mock()` | | `web/__mocks__/` | Reusable mock factories shared across multiple test files |
| Test file | Test-specific mocks, inline with `vi.mock()` |
Modules are not mocked automatically. Use `vi.mock` in test files, or add global mocks in `web/vitest.setup.ts`.
## Essential Mocks ## Essential Mocks
### 1. i18n (Auto-loaded via Shared Mock) ### 1. i18n (Auto-loaded via Global Mock)
A shared mock is available at `web/__mocks__/react-i18next.ts` and is auto-loaded by Jest. A global mock is defined in `web/vitest.setup.ts` and is auto-loaded by Vitest setup.
**No explicit mock needed** for most tests - it returns translation keys as-is. **No explicit mock needed** for most tests - it returns translation keys as-is.
For tests requiring custom translations, override the mock: For tests requiring custom translations, override the mock:
```typescript ```typescript
jest.mock('react-i18next', () => ({ vi.mock('react-i18next', () => ({
useTranslation: () => ({ useTranslation: () => ({
t: (key: string) => { t: (key: string) => {
const translations: Record<string, string> = { const translations: Record<string, string> = {
@ -69,15 +72,15 @@ jest.mock('react-i18next', () => ({
### 2. Next.js Router ### 2. Next.js Router
```typescript ```typescript
const mockPush = jest.fn() const mockPush = vi.fn()
const mockReplace = jest.fn() const mockReplace = vi.fn()
jest.mock('next/navigation', () => ({ vi.mock('next/navigation', () => ({
useRouter: () => ({ useRouter: () => ({
push: mockPush, push: mockPush,
replace: mockReplace, replace: mockReplace,
back: jest.fn(), back: vi.fn(),
prefetch: jest.fn(), prefetch: vi.fn(),
}), }),
usePathname: () => '/current-path', usePathname: () => '/current-path',
useSearchParams: () => new URLSearchParams('?key=value'), useSearchParams: () => new URLSearchParams('?key=value'),
@ -85,7 +88,7 @@ jest.mock('next/navigation', () => ({
describe('Component', () => { describe('Component', () => {
beforeEach(() => { beforeEach(() => {
jest.clearAllMocks() vi.clearAllMocks()
}) })
it('should navigate on click', () => { it('should navigate on click', () => {
@ -102,7 +105,7 @@ describe('Component', () => {
// ⚠️ Important: Use shared state for components that depend on each other // ⚠️ Important: Use shared state for components that depend on each other
let mockPortalOpenState = false let mockPortalOpenState = false
jest.mock('@/app/components/base/portal-to-follow-elem', () => ({ vi.mock('@/app/components/base/portal-to-follow-elem', () => ({
PortalToFollowElem: ({ children, open, ...props }: any) => { PortalToFollowElem: ({ children, open, ...props }: any) => {
mockPortalOpenState = open || false // Update shared state mockPortalOpenState = open || false // Update shared state
return <div data-testid="portal" data-open={open}>{children}</div> return <div data-testid="portal" data-open={open}>{children}</div>
@ -119,7 +122,7 @@ jest.mock('@/app/components/base/portal-to-follow-elem', () => ({
describe('Component', () => { describe('Component', () => {
beforeEach(() => { beforeEach(() => {
jest.clearAllMocks() vi.clearAllMocks()
mockPortalOpenState = false // ✅ Reset shared state mockPortalOpenState = false // ✅ Reset shared state
}) })
}) })
@ -130,13 +133,13 @@ describe('Component', () => {
```typescript ```typescript
import * as api from '@/service/api' import * as api from '@/service/api'
jest.mock('@/service/api') vi.mock('@/service/api')
const mockedApi = api as jest.Mocked<typeof api> const mockedApi = vi.mocked(api)
describe('Component', () => { describe('Component', () => {
beforeEach(() => { beforeEach(() => {
jest.clearAllMocks() vi.clearAllMocks()
// Setup default mock implementation // Setup default mock implementation
mockedApi.fetchData.mockResolvedValue({ data: [] }) mockedApi.fetchData.mockResolvedValue({ data: [] })
@ -239,32 +242,9 @@ describe('Component with Context', () => {
}) })
``` ```
### 7. SWR / React Query ### 7. React Query
```typescript ```typescript
// SWR
jest.mock('swr', () => ({
__esModule: true,
default: jest.fn(),
}))
import useSWR from 'swr'
const mockedUseSWR = useSWR as jest.Mock
describe('Component with SWR', () => {
it('should show loading state', () => {
mockedUseSWR.mockReturnValue({
data: undefined,
error: undefined,
isLoading: true,
})
render(<Component />)
expect(screen.getByText(/loading/i)).toBeInTheDocument()
})
})
// React Query
import { QueryClient, QueryClientProvider } from '@tanstack/react-query' import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
const createTestQueryClient = () => new QueryClient({ const createTestQueryClient = () => new QueryClient({

View File

@ -35,7 +35,7 @@ When testing a **single component, hook, or utility**:
2. Run `pnpm analyze-component <path>` (if available) 2. Run `pnpm analyze-component <path>` (if available)
3. Check complexity score and features detected 3. Check complexity score and features detected
4. Write the test file 4. Write the test file
5. Run test: `pnpm test -- <file>.spec.tsx` 5. Run test: `pnpm test <file>.spec.tsx`
6. Fix any failures 6. Fix any failures
7. Verify coverage meets goals (100% function, >95% branch) 7. Verify coverage meets goals (100% function, >95% branch)
``` ```
@ -80,7 +80,7 @@ Process files in this recommended order:
``` ```
┌─────────────────────────────────────────────┐ ┌─────────────────────────────────────────────┐
│ 1. Write test file │ │ 1. Write test file │
│ 2. Run: pnpm test -- <file>.spec.tsx │ │ 2. Run: pnpm test <file>.spec.tsx
│ 3. If FAIL → Fix immediately, re-run │ │ 3. If FAIL → Fix immediately, re-run │
│ 4. If PASS → Mark complete in todo list │ │ 4. If PASS → Mark complete in todo list │
│ 5. ONLY THEN proceed to next file │ │ 5. ONLY THEN proceed to next file │
@ -95,10 +95,10 @@ After all individual tests pass:
```bash ```bash
# Run all tests in the directory together # Run all tests in the directory together
pnpm test -- path/to/directory/ pnpm test path/to/directory/
# Check coverage # Check coverage
pnpm test -- --coverage path/to/directory/ pnpm test:coverage path/to/directory/
``` ```
## Component Complexity Guidelines ## Component Complexity Guidelines
@ -201,9 +201,9 @@ Run pnpm test ← Multiple failures, hard to debug
``` ```
# GOOD: Incremental with verification # GOOD: Incremental with verification
Write component-a.spec.tsx Write component-a.spec.tsx
Run pnpm test -- component-a.spec.tsx ✅ Run pnpm test component-a.spec.tsx ✅
Write component-b.spec.tsx Write component-b.spec.tsx
Run pnpm test -- component-b.spec.tsx ✅ Run pnpm test component-b.spec.tsx ✅
...continue... ...continue...
``` ```

1
.codex/skills Symbolic link
View File

@ -0,0 +1 @@
../.claude/skills

View File

@ -6,6 +6,9 @@
"context": "..", "context": "..",
"dockerfile": "Dockerfile" "dockerfile": "Dockerfile"
}, },
"mounts": [
"source=dify-dev-tmp,target=/tmp,type=volume"
],
"features": { "features": {
"ghcr.io/devcontainers/features/node:1": { "ghcr.io/devcontainers/features/node:1": {
"nodeGypDependencies": true, "nodeGypDependencies": true,
@ -34,19 +37,13 @@
}, },
"postStartCommand": "./.devcontainer/post_start_command.sh", "postStartCommand": "./.devcontainer/post_start_command.sh",
"postCreateCommand": "./.devcontainer/post_create_command.sh" "postCreateCommand": "./.devcontainer/post_create_command.sh"
// Features to add to the dev container. More info: https://containers.dev/features. // Features to add to the dev container. More info: https://containers.dev/features.
// "features": {}, // "features": {},
// Use 'forwardPorts' to make a list of ports inside the container available locally. // Use 'forwardPorts' to make a list of ports inside the container available locally.
// "forwardPorts": [], // "forwardPorts": [],
// Use 'postCreateCommand' to run commands after the container is created. // Use 'postCreateCommand' to run commands after the container is created.
// "postCreateCommand": "python --version", // "postCreateCommand": "python --version",
// Configure tool-specific properties. // Configure tool-specific properties.
// "customizations": {}, // "customizations": {},
// Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root. // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root.
// "remoteUser": "root" }
}

View File

@ -1,12 +1,13 @@
#!/bin/bash #!/bin/bash
WORKSPACE_ROOT=$(pwd) WORKSPACE_ROOT=$(pwd)
export COREPACK_ENABLE_DOWNLOAD_PROMPT=0
corepack enable corepack enable
cd web && pnpm install cd web && pnpm install
pipx install uv pipx install uv
echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor\"" >> ~/.bashrc echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention\"" >> ~/.bashrc
echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev\"" >> ~/.bashrc echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev\"" >> ~/.bashrc
echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc

313
.github/CODEOWNERS vendored
View File

@ -6,229 +6,244 @@
* @crazywoola @laipz8200 @Yeuoly * @crazywoola @laipz8200 @Yeuoly
# CODEOWNERS file
/.github/CODEOWNERS @laipz8200 @crazywoola
# Docs
/docs/ @crazywoola
# Backend (default owner, more specific rules below will override) # Backend (default owner, more specific rules below will override)
api/ @QuantumGhost /api/ @QuantumGhost
# Backend - MCP # Backend - MCP
api/core/mcp/ @Nov1c444 /api/core/mcp/ @Nov1c444
api/core/entities/mcp_provider.py @Nov1c444 /api/core/entities/mcp_provider.py @Nov1c444
api/services/tools/mcp_tools_manage_service.py @Nov1c444 /api/services/tools/mcp_tools_manage_service.py @Nov1c444
api/controllers/mcp/ @Nov1c444 /api/controllers/mcp/ @Nov1c444
api/controllers/console/app/mcp_server.py @Nov1c444 /api/controllers/console/app/mcp_server.py @Nov1c444
api/tests/**/*mcp* @Nov1c444 /api/tests/**/*mcp* @Nov1c444
# Backend - Workflow - Engine (Core graph execution engine) # Backend - Workflow - Engine (Core graph execution engine)
api/core/workflow/graph_engine/ @laipz8200 @QuantumGhost /api/core/workflow/graph_engine/ @laipz8200 @QuantumGhost
api/core/workflow/runtime/ @laipz8200 @QuantumGhost /api/core/workflow/runtime/ @laipz8200 @QuantumGhost
api/core/workflow/graph/ @laipz8200 @QuantumGhost /api/core/workflow/graph/ @laipz8200 @QuantumGhost
api/core/workflow/graph_events/ @laipz8200 @QuantumGhost /api/core/workflow/graph_events/ @laipz8200 @QuantumGhost
api/core/workflow/node_events/ @laipz8200 @QuantumGhost /api/core/workflow/node_events/ @laipz8200 @QuantumGhost
api/core/model_runtime/ @laipz8200 @QuantumGhost /api/core/model_runtime/ @laipz8200 @QuantumGhost
# Backend - Workflow - Nodes (Agent, Iteration, Loop, LLM) # Backend - Workflow - Nodes (Agent, Iteration, Loop, LLM)
api/core/workflow/nodes/agent/ @Nov1c444 /api/core/workflow/nodes/agent/ @Nov1c444
api/core/workflow/nodes/iteration/ @Nov1c444 /api/core/workflow/nodes/iteration/ @Nov1c444
api/core/workflow/nodes/loop/ @Nov1c444 /api/core/workflow/nodes/loop/ @Nov1c444
api/core/workflow/nodes/llm/ @Nov1c444 /api/core/workflow/nodes/llm/ @Nov1c444
# Backend - RAG (Retrieval Augmented Generation) # Backend - RAG (Retrieval Augmented Generation)
api/core/rag/ @JohnJyong /api/core/rag/ @JohnJyong
api/services/rag_pipeline/ @JohnJyong /api/services/rag_pipeline/ @JohnJyong
api/services/dataset_service.py @JohnJyong /api/services/dataset_service.py @JohnJyong
api/services/knowledge_service.py @JohnJyong /api/services/knowledge_service.py @JohnJyong
api/services/external_knowledge_service.py @JohnJyong /api/services/external_knowledge_service.py @JohnJyong
api/services/hit_testing_service.py @JohnJyong /api/services/hit_testing_service.py @JohnJyong
api/services/metadata_service.py @JohnJyong /api/services/metadata_service.py @JohnJyong
api/services/vector_service.py @JohnJyong /api/services/vector_service.py @JohnJyong
api/services/entities/knowledge_entities/ @JohnJyong /api/services/entities/knowledge_entities/ @JohnJyong
api/services/entities/external_knowledge_entities/ @JohnJyong /api/services/entities/external_knowledge_entities/ @JohnJyong
api/controllers/console/datasets/ @JohnJyong /api/controllers/console/datasets/ @JohnJyong
api/controllers/service_api/dataset/ @JohnJyong /api/controllers/service_api/dataset/ @JohnJyong
api/models/dataset.py @JohnJyong /api/models/dataset.py @JohnJyong
api/tasks/rag_pipeline/ @JohnJyong /api/tasks/rag_pipeline/ @JohnJyong
api/tasks/add_document_to_index_task.py @JohnJyong /api/tasks/add_document_to_index_task.py @JohnJyong
api/tasks/batch_clean_document_task.py @JohnJyong /api/tasks/batch_clean_document_task.py @JohnJyong
api/tasks/clean_document_task.py @JohnJyong /api/tasks/clean_document_task.py @JohnJyong
api/tasks/clean_notion_document_task.py @JohnJyong /api/tasks/clean_notion_document_task.py @JohnJyong
api/tasks/document_indexing_task.py @JohnJyong /api/tasks/document_indexing_task.py @JohnJyong
api/tasks/document_indexing_sync_task.py @JohnJyong /api/tasks/document_indexing_sync_task.py @JohnJyong
api/tasks/document_indexing_update_task.py @JohnJyong /api/tasks/document_indexing_update_task.py @JohnJyong
api/tasks/duplicate_document_indexing_task.py @JohnJyong /api/tasks/duplicate_document_indexing_task.py @JohnJyong
api/tasks/recover_document_indexing_task.py @JohnJyong /api/tasks/recover_document_indexing_task.py @JohnJyong
api/tasks/remove_document_from_index_task.py @JohnJyong /api/tasks/remove_document_from_index_task.py @JohnJyong
api/tasks/retry_document_indexing_task.py @JohnJyong /api/tasks/retry_document_indexing_task.py @JohnJyong
api/tasks/sync_website_document_indexing_task.py @JohnJyong /api/tasks/sync_website_document_indexing_task.py @JohnJyong
api/tasks/batch_create_segment_to_index_task.py @JohnJyong /api/tasks/batch_create_segment_to_index_task.py @JohnJyong
api/tasks/create_segment_to_index_task.py @JohnJyong /api/tasks/create_segment_to_index_task.py @JohnJyong
api/tasks/delete_segment_from_index_task.py @JohnJyong /api/tasks/delete_segment_from_index_task.py @JohnJyong
api/tasks/disable_segment_from_index_task.py @JohnJyong /api/tasks/disable_segment_from_index_task.py @JohnJyong
api/tasks/disable_segments_from_index_task.py @JohnJyong /api/tasks/disable_segments_from_index_task.py @JohnJyong
api/tasks/enable_segment_to_index_task.py @JohnJyong /api/tasks/enable_segment_to_index_task.py @JohnJyong
api/tasks/enable_segments_to_index_task.py @JohnJyong /api/tasks/enable_segments_to_index_task.py @JohnJyong
api/tasks/clean_dataset_task.py @JohnJyong /api/tasks/clean_dataset_task.py @JohnJyong
api/tasks/deal_dataset_index_update_task.py @JohnJyong /api/tasks/deal_dataset_index_update_task.py @JohnJyong
api/tasks/deal_dataset_vector_index_task.py @JohnJyong /api/tasks/deal_dataset_vector_index_task.py @JohnJyong
# Backend - Plugins # Backend - Plugins
api/core/plugin/ @Mairuis @Yeuoly @Stream29 /api/core/plugin/ @Mairuis @Yeuoly @Stream29
api/services/plugin/ @Mairuis @Yeuoly @Stream29 /api/services/plugin/ @Mairuis @Yeuoly @Stream29
api/controllers/console/workspace/plugin.py @Mairuis @Yeuoly @Stream29 /api/controllers/console/workspace/plugin.py @Mairuis @Yeuoly @Stream29
api/controllers/inner_api/plugin/ @Mairuis @Yeuoly @Stream29 /api/controllers/inner_api/plugin/ @Mairuis @Yeuoly @Stream29
api/tasks/process_tenant_plugin_autoupgrade_check_task.py @Mairuis @Yeuoly @Stream29 /api/tasks/process_tenant_plugin_autoupgrade_check_task.py @Mairuis @Yeuoly @Stream29
# Backend - Trigger/Schedule/Webhook # Backend - Trigger/Schedule/Webhook
api/controllers/trigger/ @Mairuis @Yeuoly /api/controllers/trigger/ @Mairuis @Yeuoly
api/controllers/console/app/workflow_trigger.py @Mairuis @Yeuoly /api/controllers/console/app/workflow_trigger.py @Mairuis @Yeuoly
api/controllers/console/workspace/trigger_providers.py @Mairuis @Yeuoly /api/controllers/console/workspace/trigger_providers.py @Mairuis @Yeuoly
api/core/trigger/ @Mairuis @Yeuoly /api/core/trigger/ @Mairuis @Yeuoly
api/core/app/layers/trigger_post_layer.py @Mairuis @Yeuoly /api/core/app/layers/trigger_post_layer.py @Mairuis @Yeuoly
api/services/trigger/ @Mairuis @Yeuoly /api/services/trigger/ @Mairuis @Yeuoly
api/models/trigger.py @Mairuis @Yeuoly /api/models/trigger.py @Mairuis @Yeuoly
api/fields/workflow_trigger_fields.py @Mairuis @Yeuoly /api/fields/workflow_trigger_fields.py @Mairuis @Yeuoly
api/repositories/workflow_trigger_log_repository.py @Mairuis @Yeuoly /api/repositories/workflow_trigger_log_repository.py @Mairuis @Yeuoly
api/repositories/sqlalchemy_workflow_trigger_log_repository.py @Mairuis @Yeuoly /api/repositories/sqlalchemy_workflow_trigger_log_repository.py @Mairuis @Yeuoly
api/libs/schedule_utils.py @Mairuis @Yeuoly /api/libs/schedule_utils.py @Mairuis @Yeuoly
api/services/workflow/scheduler.py @Mairuis @Yeuoly /api/services/workflow/scheduler.py @Mairuis @Yeuoly
api/schedule/trigger_provider_refresh_task.py @Mairuis @Yeuoly /api/schedule/trigger_provider_refresh_task.py @Mairuis @Yeuoly
api/schedule/workflow_schedule_task.py @Mairuis @Yeuoly /api/schedule/workflow_schedule_task.py @Mairuis @Yeuoly
api/tasks/trigger_processing_tasks.py @Mairuis @Yeuoly /api/tasks/trigger_processing_tasks.py @Mairuis @Yeuoly
api/tasks/trigger_subscription_refresh_tasks.py @Mairuis @Yeuoly /api/tasks/trigger_subscription_refresh_tasks.py @Mairuis @Yeuoly
api/tasks/workflow_schedule_tasks.py @Mairuis @Yeuoly /api/tasks/workflow_schedule_tasks.py @Mairuis @Yeuoly
api/tasks/workflow_cfs_scheduler/ @Mairuis @Yeuoly /api/tasks/workflow_cfs_scheduler/ @Mairuis @Yeuoly
api/events/event_handlers/sync_plugin_trigger_when_app_created.py @Mairuis @Yeuoly /api/events/event_handlers/sync_plugin_trigger_when_app_created.py @Mairuis @Yeuoly
api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py @Mairuis @Yeuoly /api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py @Mairuis @Yeuoly
api/events/event_handlers/sync_workflow_schedule_when_app_published.py @Mairuis @Yeuoly /api/events/event_handlers/sync_workflow_schedule_when_app_published.py @Mairuis @Yeuoly
api/events/event_handlers/sync_webhook_when_app_created.py @Mairuis @Yeuoly /api/events/event_handlers/sync_webhook_when_app_created.py @Mairuis @Yeuoly
# Backend - Async Workflow # Backend - Async Workflow
api/services/async_workflow_service.py @Mairuis @Yeuoly /api/services/async_workflow_service.py @Mairuis @Yeuoly
api/tasks/async_workflow_tasks.py @Mairuis @Yeuoly /api/tasks/async_workflow_tasks.py @Mairuis @Yeuoly
# Backend - Billing # Backend - Billing
api/services/billing_service.py @hj24 @zyssyz123 /api/services/billing_service.py @hj24 @zyssyz123
api/controllers/console/billing/ @hj24 @zyssyz123 /api/controllers/console/billing/ @hj24 @zyssyz123
# Backend - Enterprise # Backend - Enterprise
api/configs/enterprise/ @GarfieldDai @GareArc /api/configs/enterprise/ @GarfieldDai @GareArc
api/services/enterprise/ @GarfieldDai @GareArc /api/services/enterprise/ @GarfieldDai @GareArc
api/services/feature_service.py @GarfieldDai @GareArc /api/services/feature_service.py @GarfieldDai @GareArc
api/controllers/console/feature.py @GarfieldDai @GareArc /api/controllers/console/feature.py @GarfieldDai @GareArc
api/controllers/web/feature.py @GarfieldDai @GareArc /api/controllers/web/feature.py @GarfieldDai @GareArc
# Backend - Database Migrations # Backend - Database Migrations
api/migrations/ @snakevash @laipz8200 /api/migrations/ @snakevash @laipz8200 @MRZHUH
# Backend - Vector DB Middleware
/api/configs/middleware/vdb/* @JohnJyong
# Frontend # Frontend
web/ @iamjoel /web/ @iamjoel
# Frontend - Web Tests
/.github/workflows/web-tests.yml @iamjoel
# Frontend - App - Orchestration # Frontend - App - Orchestration
web/app/components/workflow/ @iamjoel @zxhlyh /web/app/components/workflow/ @iamjoel @zxhlyh
web/app/components/workflow-app/ @iamjoel @zxhlyh /web/app/components/workflow-app/ @iamjoel @zxhlyh
web/app/components/app/configuration/ @iamjoel @zxhlyh /web/app/components/app/configuration/ @iamjoel @zxhlyh
web/app/components/app/app-publisher/ @iamjoel @zxhlyh /web/app/components/app/app-publisher/ @iamjoel @zxhlyh
# Frontend - WebApp - Chat # Frontend - WebApp - Chat
web/app/components/base/chat/ @iamjoel @zxhlyh /web/app/components/base/chat/ @iamjoel @zxhlyh
# Frontend - WebApp - Completion # Frontend - WebApp - Completion
web/app/components/share/text-generation/ @iamjoel @zxhlyh /web/app/components/share/text-generation/ @iamjoel @zxhlyh
# Frontend - App - List and Creation # Frontend - App - List and Creation
web/app/components/apps/ @JzoNgKVO @iamjoel /web/app/components/apps/ @JzoNgKVO @iamjoel
web/app/components/app/create-app-dialog/ @JzoNgKVO @iamjoel /web/app/components/app/create-app-dialog/ @JzoNgKVO @iamjoel
web/app/components/app/create-app-modal/ @JzoNgKVO @iamjoel /web/app/components/app/create-app-modal/ @JzoNgKVO @iamjoel
web/app/components/app/create-from-dsl-modal/ @JzoNgKVO @iamjoel /web/app/components/app/create-from-dsl-modal/ @JzoNgKVO @iamjoel
# Frontend - App - API Documentation # Frontend - App - API Documentation
web/app/components/develop/ @JzoNgKVO @iamjoel /web/app/components/develop/ @JzoNgKVO @iamjoel
# Frontend - App - Logs and Annotations # Frontend - App - Logs and Annotations
web/app/components/app/workflow-log/ @JzoNgKVO @iamjoel /web/app/components/app/workflow-log/ @JzoNgKVO @iamjoel
web/app/components/app/log/ @JzoNgKVO @iamjoel /web/app/components/app/log/ @JzoNgKVO @iamjoel
web/app/components/app/log-annotation/ @JzoNgKVO @iamjoel /web/app/components/app/log-annotation/ @JzoNgKVO @iamjoel
web/app/components/app/annotation/ @JzoNgKVO @iamjoel /web/app/components/app/annotation/ @JzoNgKVO @iamjoel
# Frontend - App - Monitoring # Frontend - App - Monitoring
web/app/(commonLayout)/app/(appDetailLayout)/\[appId\]/overview/ @JzoNgKVO @iamjoel /web/app/(commonLayout)/app/(appDetailLayout)/\[appId\]/overview/ @JzoNgKVO @iamjoel
web/app/components/app/overview/ @JzoNgKVO @iamjoel /web/app/components/app/overview/ @JzoNgKVO @iamjoel
# Frontend - App - Settings # Frontend - App - Settings
web/app/components/app-sidebar/ @JzoNgKVO @iamjoel /web/app/components/app-sidebar/ @JzoNgKVO @iamjoel
# Frontend - RAG - Hit Testing # Frontend - RAG - Hit Testing
web/app/components/datasets/hit-testing/ @JzoNgKVO @iamjoel /web/app/components/datasets/hit-testing/ @JzoNgKVO @iamjoel
# Frontend - RAG - List and Creation # Frontend - RAG - List and Creation
web/app/components/datasets/list/ @iamjoel @WTW0313 /web/app/components/datasets/list/ @iamjoel @WTW0313
web/app/components/datasets/create/ @iamjoel @WTW0313 /web/app/components/datasets/create/ @iamjoel @WTW0313
web/app/components/datasets/create-from-pipeline/ @iamjoel @WTW0313 /web/app/components/datasets/create-from-pipeline/ @iamjoel @WTW0313
web/app/components/datasets/external-knowledge-base/ @iamjoel @WTW0313 /web/app/components/datasets/external-knowledge-base/ @iamjoel @WTW0313
# Frontend - RAG - Orchestration (general rule first, specific rules below override) # Frontend - RAG - Orchestration (general rule first, specific rules below override)
web/app/components/rag-pipeline/ @iamjoel @WTW0313 /web/app/components/rag-pipeline/ @iamjoel @WTW0313
web/app/components/rag-pipeline/components/rag-pipeline-main.tsx @iamjoel @zxhlyh /web/app/components/rag-pipeline/components/rag-pipeline-main.tsx @iamjoel @zxhlyh
web/app/components/rag-pipeline/store/ @iamjoel @zxhlyh /web/app/components/rag-pipeline/store/ @iamjoel @zxhlyh
# Frontend - RAG - Documents List # Frontend - RAG - Documents List
web/app/components/datasets/documents/list.tsx @iamjoel @WTW0313 /web/app/components/datasets/documents/list.tsx @iamjoel @WTW0313
web/app/components/datasets/documents/create-from-pipeline/ @iamjoel @WTW0313 /web/app/components/datasets/documents/create-from-pipeline/ @iamjoel @WTW0313
# Frontend - RAG - Segments List # Frontend - RAG - Segments List
web/app/components/datasets/documents/detail/ @iamjoel @WTW0313 /web/app/components/datasets/documents/detail/ @iamjoel @WTW0313
# Frontend - RAG - Settings # Frontend - RAG - Settings
web/app/components/datasets/settings/ @iamjoel @WTW0313 /web/app/components/datasets/settings/ @iamjoel @WTW0313
# Frontend - Ecosystem - Plugins # Frontend - Ecosystem - Plugins
web/app/components/plugins/ @iamjoel @zhsama /web/app/components/plugins/ @iamjoel @zhsama
# Frontend - Ecosystem - Tools # Frontend - Ecosystem - Tools
web/app/components/tools/ @iamjoel @Yessenia-d /web/app/components/tools/ @iamjoel @Yessenia-d
# Frontend - Ecosystem - MarketPlace # Frontend - Ecosystem - MarketPlace
web/app/components/plugins/marketplace/ @iamjoel @Yessenia-d /web/app/components/plugins/marketplace/ @iamjoel @Yessenia-d
# Frontend - Login and Registration # Frontend - Login and Registration
web/app/signin/ @douxc @iamjoel /web/app/signin/ @douxc @iamjoel
web/app/signup/ @douxc @iamjoel /web/app/signup/ @douxc @iamjoel
web/app/reset-password/ @douxc @iamjoel /web/app/reset-password/ @douxc @iamjoel
web/app/install/ @douxc @iamjoel /web/app/install/ @douxc @iamjoel
web/app/init/ @douxc @iamjoel /web/app/init/ @douxc @iamjoel
web/app/forgot-password/ @douxc @iamjoel /web/app/forgot-password/ @douxc @iamjoel
web/app/account/ @douxc @iamjoel /web/app/account/ @douxc @iamjoel
# Frontend - Service Authentication # Frontend - Service Authentication
web/service/base.ts @douxc @iamjoel /web/service/base.ts @douxc @iamjoel
# Frontend - WebApp Authentication and Access Control # Frontend - WebApp Authentication and Access Control
web/app/(shareLayout)/components/ @douxc @iamjoel /web/app/(shareLayout)/components/ @douxc @iamjoel
web/app/(shareLayout)/webapp-signin/ @douxc @iamjoel /web/app/(shareLayout)/webapp-signin/ @douxc @iamjoel
web/app/(shareLayout)/webapp-reset-password/ @douxc @iamjoel /web/app/(shareLayout)/webapp-reset-password/ @douxc @iamjoel
web/app/components/app/app-access-control/ @douxc @iamjoel /web/app/components/app/app-access-control/ @douxc @iamjoel
# Frontend - Explore Page # Frontend - Explore Page
web/app/components/explore/ @CodingOnStar @iamjoel /web/app/components/explore/ @CodingOnStar @iamjoel
# Frontend - Personal Settings # Frontend - Personal Settings
web/app/components/header/account-setting/ @CodingOnStar @iamjoel /web/app/components/header/account-setting/ @CodingOnStar @iamjoel
web/app/components/header/account-dropdown/ @CodingOnStar @iamjoel /web/app/components/header/account-dropdown/ @CodingOnStar @iamjoel
# Frontend - Analytics # Frontend - Analytics
web/app/components/base/ga/ @CodingOnStar @iamjoel /web/app/components/base/ga/ @CodingOnStar @iamjoel
# Frontend - Base Components # Frontend - Base Components
web/app/components/base/ @iamjoel @zxhlyh /web/app/components/base/ @iamjoel @zxhlyh
# Frontend - Utils and Hooks # Frontend - Utils and Hooks
web/utils/classnames.ts @iamjoel @zxhlyh /web/utils/classnames.ts @iamjoel @zxhlyh
web/utils/time.ts @iamjoel @zxhlyh /web/utils/time.ts @iamjoel @zxhlyh
web/utils/format.ts @iamjoel @zxhlyh /web/utils/format.ts @iamjoel @zxhlyh
web/utils/clipboard.ts @iamjoel @zxhlyh /web/utils/clipboard.ts @iamjoel @zxhlyh
web/hooks/use-document-title.ts @iamjoel @zxhlyh /web/hooks/use-document-title.ts @iamjoel @zxhlyh
# Frontend - Billing and Education # Frontend - Billing and Education
web/app/components/billing/ @iamjoel @zxhlyh /web/app/components/billing/ @iamjoel @zxhlyh
web/app/education-apply/ @iamjoel @zxhlyh /web/app/education-apply/ @iamjoel @zxhlyh
# Frontend - Workspace # Frontend - Workspace
web/app/components/header/account-dropdown/workplace-selector/ @iamjoel @zxhlyh /web/app/components/header/account-dropdown/workplace-selector/ @iamjoel @zxhlyh
# Docker
/docker/* @laipz8200

View File

@ -22,12 +22,12 @@ jobs:
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v6
with: with:
persist-credentials: false persist-credentials: false
- name: Setup UV and Python - name: Setup UV and Python
uses: astral-sh/setup-uv@v6 uses: astral-sh/setup-uv@v7
with: with:
enable-cache: true enable-cache: true
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
@ -57,7 +57,7 @@ jobs:
run: sh .github/workflows/expose_service_ports.sh run: sh .github/workflows/expose_service_ports.sh
- name: Set up Sandbox - name: Set up Sandbox
uses: hoverkraft-tech/compose-action@v2.0.2 uses: hoverkraft-tech/compose-action@v2
with: with:
compose-file: | compose-file: |
docker/docker-compose.middleware.yaml docker/docker-compose.middleware.yaml

View File

@ -12,12 +12,28 @@ jobs:
if: github.repository == 'langgenius/dify' if: github.repository == 'langgenius/dify'
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v6
- name: Check Docker Compose inputs
id: docker-compose-changes
uses: tj-actions/changed-files@v46
with:
files: |
docker/generate_docker_compose
docker/.env.example
docker/docker-compose-template.yaml
docker/docker-compose.yaml
- uses: actions/setup-python@v5 - uses: actions/setup-python@v5
with: with:
python-version: "3.11" python-version: "3.11"
- uses: astral-sh/setup-uv@v6 - uses: astral-sh/setup-uv@v7
- name: Generate Docker Compose
if: steps.docker-compose-changes.outputs.any_changed == 'true'
run: |
cd docker
./generate_docker_compose
- run: | - run: |
cd api cd api
@ -66,27 +82,6 @@ jobs:
# mdformat breaks YAML front matter in markdown files. Add --exclude for directories containing YAML front matter. # mdformat breaks YAML front matter in markdown files. Add --exclude for directories containing YAML front matter.
- name: mdformat - name: mdformat
run: | run: |
uvx --python 3.13 mdformat . --exclude ".claude/skills/**" uvx --python 3.13 mdformat . --exclude ".claude/skills/**/SKILL.md"
- name: Install pnpm
uses: pnpm/action-setup@v4
with:
package_json_file: web/package.json
run_install: false
- name: Setup NodeJS
uses: actions/setup-node@v4
with:
node-version: 22
cache: pnpm
cache-dependency-path: ./web/package.json
- name: Web dependencies
working-directory: ./web
run: pnpm install --frozen-lockfile
- name: oxlint
working-directory: ./web
run: pnpm exec oxlint --config .oxlintrc.json --fix .
- uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27 - uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27

View File

@ -90,7 +90,7 @@ jobs:
touch "/tmp/digests/${sanitized_digest}" touch "/tmp/digests/${sanitized_digest}"
- name: Upload digest - name: Upload digest
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v6
with: with:
name: digests-${{ matrix.context }}-${{ env.PLATFORM_PAIR }} name: digests-${{ matrix.context }}-${{ env.PLATFORM_PAIR }}
path: /tmp/digests/* path: /tmp/digests/*

View File

@ -13,13 +13,13 @@ jobs:
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v6
with: with:
fetch-depth: 0 fetch-depth: 0
persist-credentials: false persist-credentials: false
- name: Setup UV and Python - name: Setup UV and Python
uses: astral-sh/setup-uv@v6 uses: astral-sh/setup-uv@v7
with: with:
enable-cache: true enable-cache: true
python-version: "3.12" python-version: "3.12"
@ -63,13 +63,13 @@ jobs:
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v6
with: with:
fetch-depth: 0 fetch-depth: 0
persist-credentials: false persist-credentials: false
- name: Setup UV and Python - name: Setup UV and Python
uses: astral-sh/setup-uv@v6 uses: astral-sh/setup-uv@v7
with: with:
enable-cache: true enable-cache: true
python-version: "3.12" python-version: "3.12"

View File

@ -27,7 +27,7 @@ jobs:
vdb-changed: ${{ steps.changes.outputs.vdb }} vdb-changed: ${{ steps.changes.outputs.vdb }}
migration-changed: ${{ steps.changes.outputs.migration }} migration-changed: ${{ steps.changes.outputs.migration }}
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v6
- uses: dorny/paths-filter@v3 - uses: dorny/paths-filter@v3
id: changes id: changes
with: with:
@ -38,6 +38,7 @@ jobs:
- '.github/workflows/api-tests.yml' - '.github/workflows/api-tests.yml'
web: web:
- 'web/**' - 'web/**'
- '.github/workflows/web-tests.yml'
vdb: vdb:
- 'api/core/rag/datasource/**' - 'api/core/rag/datasource/**'
- 'docker/**' - 'docker/**'

View File

@ -19,13 +19,13 @@ jobs:
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v6
with: with:
persist-credentials: false persist-credentials: false
- name: Check changed files - name: Check changed files
id: changed-files id: changed-files
uses: tj-actions/changed-files@v46 uses: tj-actions/changed-files@v47
with: with:
files: | files: |
api/** api/**
@ -33,7 +33,7 @@ jobs:
- name: Setup UV and Python - name: Setup UV and Python
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
uses: astral-sh/setup-uv@v6 uses: astral-sh/setup-uv@v7
with: with:
enable-cache: false enable-cache: false
python-version: "3.12" python-version: "3.12"
@ -68,15 +68,17 @@ jobs:
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v6
with: with:
persist-credentials: false persist-credentials: false
- name: Check changed files - name: Check changed files
id: changed-files id: changed-files
uses: tj-actions/changed-files@v46 uses: tj-actions/changed-files@v47
with: with:
files: web/** files: |
web/**
.github/workflows/style.yml
- name: Install pnpm - name: Install pnpm
uses: pnpm/action-setup@v4 uses: pnpm/action-setup@v4
@ -85,12 +87,12 @@ jobs:
run_install: false run_install: false
- name: Setup NodeJS - name: Setup NodeJS
uses: actions/setup-node@v4 uses: actions/setup-node@v6
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
with: with:
node-version: 22 node-version: 22
cache: pnpm cache: pnpm
cache-dependency-path: ./web/package.json cache-dependency-path: ./web/pnpm-lock.yaml
- name: Web dependencies - name: Web dependencies
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
@ -108,50 +110,20 @@ jobs:
working-directory: ./web working-directory: ./web
run: pnpm run type-check:tsgo run: pnpm run type-check:tsgo
docker-compose-template:
name: Docker Compose Template
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
persist-credentials: false
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v46
with:
files: |
docker/generate_docker_compose
docker/.env.example
docker/docker-compose-template.yaml
docker/docker-compose.yaml
- name: Generate Docker Compose
if: steps.changed-files.outputs.any_changed == 'true'
run: |
cd docker
./generate_docker_compose
- name: Check for changes
if: steps.changed-files.outputs.any_changed == 'true'
run: git diff --exit-code
superlinter: superlinter:
name: SuperLinter name: SuperLinter
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v6
with: with:
fetch-depth: 0 fetch-depth: 0
persist-credentials: false persist-credentials: false
- name: Check changed files - name: Check changed files
id: changed-files id: changed-files
uses: tj-actions/changed-files@v46 uses: tj-actions/changed-files@v47
with: with:
files: | files: |
**.sh **.sh

View File

@ -25,12 +25,12 @@ jobs:
working-directory: sdks/nodejs-client working-directory: sdks/nodejs-client
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v6
with: with:
persist-credentials: false persist-credentials: false
- name: Use Node.js ${{ matrix.node-version }} - name: Use Node.js ${{ matrix.node-version }}
uses: actions/setup-node@v4 uses: actions/setup-node@v6
with: with:
node-version: ${{ matrix.node-version }} node-version: ${{ matrix.node-version }}
cache: '' cache: ''

View File

@ -1,10 +1,10 @@
name: Check i18n Files and Create PR name: Translate i18n Files Based on English
on: on:
push: push:
branches: [main] branches: [main]
paths: paths:
- 'web/i18n/en-US/*.ts' - 'web/i18n/en-US/*.json'
permissions: permissions:
contents: write contents: write
@ -18,7 +18,7 @@ jobs:
run: run:
working-directory: web working-directory: web
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v6
with: with:
fetch-depth: 0 fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }} token: ${{ secrets.GITHUB_TOKEN }}
@ -28,13 +28,13 @@ jobs:
run: | run: |
git fetch origin "${{ github.event.before }}" || true git fetch origin "${{ github.event.before }}" || true
git fetch origin "${{ github.sha }}" || true git fetch origin "${{ github.sha }}" || true
changed_files=$(git diff --name-only "${{ github.event.before }}" "${{ github.sha }}" -- 'i18n/en-US/*.ts') changed_files=$(git diff --name-only "${{ github.event.before }}" "${{ github.sha }}" -- 'i18n/en-US/*.json')
echo "Changed files: $changed_files" echo "Changed files: $changed_files"
if [ -n "$changed_files" ]; then if [ -n "$changed_files" ]; then
echo "FILES_CHANGED=true" >> $GITHUB_ENV echo "FILES_CHANGED=true" >> $GITHUB_ENV
file_args="" file_args=""
for file in $changed_files; do for file in $changed_files; do
filename=$(basename "$file" .ts) filename=$(basename "$file" .json)
file_args="$file_args --file $filename" file_args="$file_args --file $filename"
done done
echo "FILE_ARGS=$file_args" >> $GITHUB_ENV echo "FILE_ARGS=$file_args" >> $GITHUB_ENV
@ -51,11 +51,11 @@ jobs:
- name: Set up Node.js - name: Set up Node.js
if: env.FILES_CHANGED == 'true' if: env.FILES_CHANGED == 'true'
uses: actions/setup-node@v4 uses: actions/setup-node@v6
with: with:
node-version: 'lts/*' node-version: 'lts/*'
cache: pnpm cache: pnpm
cache-dependency-path: ./web/package.json cache-dependency-path: ./web/pnpm-lock.yaml
- name: Install dependencies - name: Install dependencies
if: env.FILES_CHANGED == 'true' if: env.FILES_CHANGED == 'true'
@ -65,12 +65,7 @@ jobs:
- name: Generate i18n translations - name: Generate i18n translations
if: env.FILES_CHANGED == 'true' if: env.FILES_CHANGED == 'true'
working-directory: ./web working-directory: ./web
run: pnpm run auto-gen-i18n ${{ env.FILE_ARGS }} run: pnpm run i18n:gen ${{ env.FILE_ARGS }}
- name: Generate i18n type definitions
if: env.FILES_CHANGED == 'true'
working-directory: ./web
run: pnpm run gen:i18n-types
- name: Create Pull Request - name: Create Pull Request
if: env.FILES_CHANGED == 'true' if: env.FILES_CHANGED == 'true'
@ -78,14 +73,13 @@ jobs:
with: with:
token: ${{ secrets.GITHUB_TOKEN }} token: ${{ secrets.GITHUB_TOKEN }}
commit-message: 'chore(i18n): update translations based on en-US changes' commit-message: 'chore(i18n): update translations based on en-US changes'
title: 'chore(i18n): translate i18n files and update type definitions' title: 'chore(i18n): translate i18n files based on en-US changes'
body: | body: |
This PR was automatically created to update i18n files and TypeScript type definitions based on changes in en-US locale. This PR was automatically created to update i18n translation files based on changes in en-US locale.
**Triggered by:** ${{ github.sha }} **Triggered by:** ${{ github.sha }}
**Changes included:** **Changes included:**
- Updated translation files for all locales - Updated translation files for all locales
- Regenerated TypeScript type definitions for type safety
branch: chore/automated-i18n-updates-${{ github.sha }} branch: chore/automated-i18n-updates-${{ github.sha }}
delete-branch: true delete-branch: true

View File

@ -19,19 +19,19 @@ jobs:
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v6
with: with:
persist-credentials: false persist-credentials: false
- name: Free Disk Space - name: Free Disk Space
uses: endersonmenezes/free-disk-space@v2 uses: endersonmenezes/free-disk-space@v3
with: with:
remove_dotnet: true remove_dotnet: true
remove_haskell: true remove_haskell: true
remove_tool_cache: true remove_tool_cache: true
- name: Setup UV and Python - name: Setup UV and Python
uses: astral-sh/setup-uv@v6 uses: astral-sh/setup-uv@v7
with: with:
enable-cache: true enable-cache: true
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}

View File

@ -13,46 +13,356 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
defaults: defaults:
run: run:
shell: bash
working-directory: ./web working-directory: ./web
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v6
with: with:
persist-credentials: false persist-credentials: false
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v46
with:
files: web/**
- name: Install pnpm - name: Install pnpm
if: steps.changed-files.outputs.any_changed == 'true'
uses: pnpm/action-setup@v4 uses: pnpm/action-setup@v4
with: with:
package_json_file: web/package.json package_json_file: web/package.json
run_install: false run_install: false
- name: Setup Node.js - name: Setup Node.js
uses: actions/setup-node@v4 uses: actions/setup-node@v6
if: steps.changed-files.outputs.any_changed == 'true'
with: with:
node-version: 22 node-version: 22
cache: pnpm cache: pnpm
cache-dependency-path: ./web/package.json cache-dependency-path: ./web/pnpm-lock.yaml
- name: Install dependencies - name: Install dependencies
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm install --frozen-lockfile run: pnpm install --frozen-lockfile
- name: Check i18n types synchronization
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm run check:i18n-types
- name: Run tests - name: Run tests
if: steps.changed-files.outputs.any_changed == 'true' run: pnpm test:coverage
working-directory: ./web
run: pnpm test - name: Coverage Summary
if: always()
id: coverage-summary
run: |
set -eo pipefail
COVERAGE_FILE="coverage/coverage-final.json"
COVERAGE_SUMMARY_FILE="coverage/coverage-summary.json"
if [ ! -f "$COVERAGE_FILE" ] && [ ! -f "$COVERAGE_SUMMARY_FILE" ]; then
echo "has_coverage=false" >> "$GITHUB_OUTPUT"
echo "### 🚨 Test Coverage Report :test_tube:" >> "$GITHUB_STEP_SUMMARY"
echo "Coverage data not found. Ensure Vitest runs with coverage enabled." >> "$GITHUB_STEP_SUMMARY"
exit 0
fi
echo "has_coverage=true" >> "$GITHUB_OUTPUT"
node <<'NODE' >> "$GITHUB_STEP_SUMMARY"
const fs = require('fs');
const path = require('path');
let libCoverage = null;
try {
libCoverage = require('istanbul-lib-coverage');
} catch (error) {
libCoverage = null;
}
const summaryPath = path.join('coverage', 'coverage-summary.json');
const finalPath = path.join('coverage', 'coverage-final.json');
const hasSummary = fs.existsSync(summaryPath);
const hasFinal = fs.existsSync(finalPath);
if (!hasSummary && !hasFinal) {
console.log('### Test Coverage Summary :test_tube:');
console.log('');
console.log('No coverage data found.');
process.exit(0);
}
const summary = hasSummary
? JSON.parse(fs.readFileSync(summaryPath, 'utf8'))
: null;
const coverage = hasFinal
? JSON.parse(fs.readFileSync(finalPath, 'utf8'))
: null;
const getLineCoverageFromStatements = (statementMap, statementHits) => {
const lineHits = {};
if (!statementMap || !statementHits) {
return lineHits;
}
Object.entries(statementMap).forEach(([key, statement]) => {
const line = statement?.start?.line;
if (!line) {
return;
}
const hits = statementHits[key] ?? 0;
const previous = lineHits[line];
lineHits[line] = previous === undefined ? hits : Math.max(previous, hits);
});
return lineHits;
};
const getFileCoverage = (entry) => (
libCoverage ? libCoverage.createFileCoverage(entry) : null
);
const getLineHits = (entry, fileCoverage) => {
const lineHits = entry.l ?? {};
if (Object.keys(lineHits).length > 0) {
return lineHits;
}
if (fileCoverage) {
return fileCoverage.getLineCoverage();
}
return getLineCoverageFromStatements(entry.statementMap ?? {}, entry.s ?? {});
};
const getUncoveredLines = (entry, fileCoverage, lineHits) => {
if (lineHits && Object.keys(lineHits).length > 0) {
return Object.entries(lineHits)
.filter(([, count]) => count === 0)
.map(([line]) => Number(line))
.sort((a, b) => a - b);
}
if (fileCoverage) {
return fileCoverage.getUncoveredLines();
}
return [];
};
const totals = {
lines: { covered: 0, total: 0 },
statements: { covered: 0, total: 0 },
branches: { covered: 0, total: 0 },
functions: { covered: 0, total: 0 },
};
const fileSummaries = [];
if (summary) {
const totalEntry = summary.total ?? {};
['lines', 'statements', 'branches', 'functions'].forEach((key) => {
if (totalEntry[key]) {
totals[key].covered = totalEntry[key].covered ?? 0;
totals[key].total = totalEntry[key].total ?? 0;
}
});
Object.entries(summary)
.filter(([file]) => file !== 'total')
.forEach(([file, data]) => {
fileSummaries.push({
file,
pct: data.lines?.pct ?? data.statements?.pct ?? 0,
lines: {
covered: data.lines?.covered ?? 0,
total: data.lines?.total ?? 0,
},
});
});
} else if (coverage) {
Object.entries(coverage).forEach(([file, entry]) => {
const fileCoverage = getFileCoverage(entry);
const lineHits = getLineHits(entry, fileCoverage);
const statementHits = entry.s ?? {};
const branchHits = entry.b ?? {};
const functionHits = entry.f ?? {};
const lineTotal = Object.keys(lineHits).length;
const lineCovered = Object.values(lineHits).filter((n) => n > 0).length;
const statementTotal = Object.keys(statementHits).length;
const statementCovered = Object.values(statementHits).filter((n) => n > 0).length;
const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0);
const branchCovered = Object.values(branchHits).reduce(
(acc, branches) => acc + branches.filter((n) => n > 0).length,
0,
);
const functionTotal = Object.keys(functionHits).length;
const functionCovered = Object.values(functionHits).filter((n) => n > 0).length;
totals.lines.total += lineTotal;
totals.lines.covered += lineCovered;
totals.statements.total += statementTotal;
totals.statements.covered += statementCovered;
totals.branches.total += branchTotal;
totals.branches.covered += branchCovered;
totals.functions.total += functionTotal;
totals.functions.covered += functionCovered;
const pct = (covered, tot) => (tot > 0 ? (covered / tot) * 100 : 0);
fileSummaries.push({
file,
pct: pct(lineCovered || statementCovered, lineTotal || statementTotal),
lines: {
covered: lineCovered || statementCovered,
total: lineTotal || statementTotal,
},
});
});
}
const pct = (covered, tot) => (tot > 0 ? ((covered / tot) * 100).toFixed(2) : '0.00');
console.log('### Test Coverage Summary :test_tube:');
console.log('');
console.log('| Metric | Coverage | Covered / Total |');
console.log('|--------|----------|-----------------|');
console.log(`| Lines | ${pct(totals.lines.covered, totals.lines.total)}% | ${totals.lines.covered} / ${totals.lines.total} |`);
console.log(`| Statements | ${pct(totals.statements.covered, totals.statements.total)}% | ${totals.statements.covered} / ${totals.statements.total} |`);
console.log(`| Branches | ${pct(totals.branches.covered, totals.branches.total)}% | ${totals.branches.covered} / ${totals.branches.total} |`);
console.log(`| Functions | ${pct(totals.functions.covered, totals.functions.total)}% | ${totals.functions.covered} / ${totals.functions.total} |`);
console.log('');
console.log('<details><summary>File coverage (lowest lines first)</summary>');
console.log('');
console.log('```');
fileSummaries
.sort((a, b) => (a.pct - b.pct) || (b.lines.total - a.lines.total))
.slice(0, 25)
.forEach(({ file, pct, lines }) => {
console.log(`${pct.toFixed(2)}%\t${lines.covered}/${lines.total}\t${file}`);
});
console.log('```');
console.log('</details>');
if (coverage) {
const pctValue = (covered, tot) => {
if (tot === 0) {
return '0';
}
return ((covered / tot) * 100)
.toFixed(2)
.replace(/\.?0+$/, '');
};
const formatLineRanges = (lines) => {
if (lines.length === 0) {
return '';
}
const ranges = [];
let start = lines[0];
let end = lines[0];
for (let i = 1; i < lines.length; i += 1) {
const current = lines[i];
if (current === end + 1) {
end = current;
continue;
}
ranges.push(start === end ? `${start}` : `${start}-${end}`);
start = current;
end = current;
}
ranges.push(start === end ? `${start}` : `${start}-${end}`);
return ranges.join(',');
};
const tableTotals = {
statements: { covered: 0, total: 0 },
branches: { covered: 0, total: 0 },
functions: { covered: 0, total: 0 },
lines: { covered: 0, total: 0 },
};
const tableRows = Object.entries(coverage)
.map(([file, entry]) => {
const fileCoverage = getFileCoverage(entry);
const lineHits = getLineHits(entry, fileCoverage);
const statementHits = entry.s ?? {};
const branchHits = entry.b ?? {};
const functionHits = entry.f ?? {};
const lineTotal = Object.keys(lineHits).length;
const lineCovered = Object.values(lineHits).filter((n) => n > 0).length;
const statementTotal = Object.keys(statementHits).length;
const statementCovered = Object.values(statementHits).filter((n) => n > 0).length;
const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0);
const branchCovered = Object.values(branchHits).reduce(
(acc, branches) => acc + branches.filter((n) => n > 0).length,
0,
);
const functionTotal = Object.keys(functionHits).length;
const functionCovered = Object.values(functionHits).filter((n) => n > 0).length;
tableTotals.lines.total += lineTotal;
tableTotals.lines.covered += lineCovered;
tableTotals.statements.total += statementTotal;
tableTotals.statements.covered += statementCovered;
tableTotals.branches.total += branchTotal;
tableTotals.branches.covered += branchCovered;
tableTotals.functions.total += functionTotal;
tableTotals.functions.covered += functionCovered;
const uncoveredLines = getUncoveredLines(entry, fileCoverage, lineHits);
const filePath = entry.path ?? file;
const relativePath = path.isAbsolute(filePath)
? path.relative(process.cwd(), filePath)
: filePath;
return {
file: relativePath || file,
statements: pctValue(statementCovered, statementTotal),
branches: pctValue(branchCovered, branchTotal),
functions: pctValue(functionCovered, functionTotal),
lines: pctValue(lineCovered, lineTotal),
uncovered: formatLineRanges(uncoveredLines),
};
})
.sort((a, b) => a.file.localeCompare(b.file));
const columns = [
{ key: 'file', header: 'File', align: 'left' },
{ key: 'statements', header: '% Stmts', align: 'right' },
{ key: 'branches', header: '% Branch', align: 'right' },
{ key: 'functions', header: '% Funcs', align: 'right' },
{ key: 'lines', header: '% Lines', align: 'right' },
{ key: 'uncovered', header: 'Uncovered Line #s', align: 'left' },
];
const allFilesRow = {
file: 'All files',
statements: pctValue(tableTotals.statements.covered, tableTotals.statements.total),
branches: pctValue(tableTotals.branches.covered, tableTotals.branches.total),
functions: pctValue(tableTotals.functions.covered, tableTotals.functions.total),
lines: pctValue(tableTotals.lines.covered, tableTotals.lines.total),
uncovered: '',
};
const rowsForOutput = [allFilesRow, ...tableRows];
const formatRow = (row) => `| ${columns
.map(({ key }) => String(row[key] ?? ''))
.join(' | ')} |`;
const headerRow = `| ${columns.map(({ header }) => header).join(' | ')} |`;
const dividerRow = `| ${columns
.map(({ align }) => (align === 'right' ? '---:' : ':---'))
.join(' | ')} |`;
console.log('');
console.log('<details><summary>Vitest coverage table</summary>');
console.log('');
console.log(headerRow);
console.log(dividerRow);
rowsForOutput.forEach((row) => console.log(formatRow(row)));
console.log('</details>');
}
NODE
- name: Upload Coverage Artifact
if: steps.coverage-summary.outputs.has_coverage == 'true'
uses: actions/upload-artifact@v6
with:
name: web-coverage-report
path: web/coverage
retention-days: 30
if-no-files-found: error

12
.gitignore vendored
View File

@ -139,7 +139,6 @@ pyrightconfig.json
.idea/' .idea/'
.DS_Store .DS_Store
web/.vscode/settings.json
# Intellij IDEA Files # Intellij IDEA Files
.idea/* .idea/*
@ -196,6 +195,7 @@ docker/nginx/ssl/*
!docker/nginx/ssl/.gitkeep !docker/nginx/ssl/.gitkeep
docker/middleware.env docker/middleware.env
docker/docker-compose.override.yaml docker/docker-compose.override.yaml
docker/env-backup/*
sdks/python-client/build sdks/python-client/build
sdks/python-client/dist sdks/python-client/dist
@ -205,7 +205,6 @@ sdks/python-client/dify_client.egg-info
!.vscode/launch.json.template !.vscode/launch.json.template
!.vscode/README.md !.vscode/README.md
api/.vscode api/.vscode
web/.vscode
# vscode Code History Extension # vscode Code History Extension
.history .history
@ -220,15 +219,6 @@ plugins.jsonl
# mise # mise
mise.toml mise.toml
# Next.js build output
.next/
# PWA generated files
web/public/sw.js
web/public/sw.js.map
web/public/workbox-*.js
web/public/workbox-*.js.map
web/public/fallback-*.js
# AI Assistant # AI Assistant
.roo/ .roo/

View File

@ -1,34 +0,0 @@
{
"mcpServers": {
"context7": {
"type": "http",
"url": "https://mcp.context7.com/mcp"
},
"sequential-thinking": {
"type": "stdio",
"command": "npx",
"args": ["-y", "@modelcontextprotocol/server-sequential-thinking"],
"env": {}
},
"github": {
"type": "stdio",
"command": "npx",
"args": ["-y", "@modelcontextprotocol/server-github"],
"env": {
"GITHUB_PERSONAL_ACCESS_TOKEN": "${GITHUB_PERSONAL_ACCESS_TOKEN}"
}
},
"fetch": {
"type": "stdio",
"command": "uvx",
"args": ["mcp-server-fetch"],
"env": {}
},
"playwright": {
"type": "stdio",
"command": "npx",
"args": ["-y", "@playwright/mcp@latest"],
"env": {}
}
}
}

View File

@ -37,7 +37,7 @@
"-c", "-c",
"1", "1",
"-Q", "-Q",
"dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor", "dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention",
"--loglevel", "--loglevel",
"INFO" "INFO"
], ],

View File

@ -116,6 +116,7 @@ ALIYUN_OSS_AUTH_VERSION=v1
ALIYUN_OSS_REGION=your-region ALIYUN_OSS_REGION=your-region
# Don't start with '/'. OSS doesn't support leading slash in object names. # Don't start with '/'. OSS doesn't support leading slash in object names.
ALIYUN_OSS_PATH=your-path ALIYUN_OSS_PATH=your-path
ALIYUN_CLOUDBOX_ID=your-cloudbox-id
# Google Storage configuration # Google Storage configuration
GOOGLE_STORAGE_BUCKET_NAME=your-bucket-name GOOGLE_STORAGE_BUCKET_NAME=your-bucket-name
@ -127,12 +128,14 @@ TENCENT_COS_SECRET_KEY=your-secret-key
TENCENT_COS_SECRET_ID=your-secret-id TENCENT_COS_SECRET_ID=your-secret-id
TENCENT_COS_REGION=your-region TENCENT_COS_REGION=your-region
TENCENT_COS_SCHEME=your-scheme TENCENT_COS_SCHEME=your-scheme
TENCENT_COS_CUSTOM_DOMAIN=your-custom-domain
# Huawei OBS Storage Configuration # Huawei OBS Storage Configuration
HUAWEI_OBS_BUCKET_NAME=your-bucket-name HUAWEI_OBS_BUCKET_NAME=your-bucket-name
HUAWEI_OBS_SECRET_KEY=your-secret-key HUAWEI_OBS_SECRET_KEY=your-secret-key
HUAWEI_OBS_ACCESS_KEY=your-access-key HUAWEI_OBS_ACCESS_KEY=your-access-key
HUAWEI_OBS_SERVER=your-server-url HUAWEI_OBS_SERVER=your-server-url
HUAWEI_OBS_PATH_STYLE=false
# Baidu OBS Storage Configuration # Baidu OBS Storage Configuration
BAIDU_OBS_BUCKET_NAME=your-bucket-name BAIDU_OBS_BUCKET_NAME=your-bucket-name
@ -690,3 +693,7 @@ ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE=5
ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR=20 ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR=20
# Maximum number of concurrent annotation import tasks per tenant # Maximum number of concurrent annotation import tasks per tenant
ANNOTATION_IMPORT_MAX_CONCURRENT=5 ANNOTATION_IMPORT_MAX_CONCURRENT=5
# Sandbox expired records clean configuration
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30

View File

@ -84,7 +84,7 @@
1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service. 1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
```bash ```bash
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention
``` ```
Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service: Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service:

View File

@ -218,7 +218,7 @@ class PluginConfig(BaseSettings):
PLUGIN_DAEMON_TIMEOUT: PositiveFloat | None = Field( PLUGIN_DAEMON_TIMEOUT: PositiveFloat | None = Field(
description="Timeout in seconds for requests to the plugin daemon (set to None to disable)", description="Timeout in seconds for requests to the plugin daemon (set to None to disable)",
default=300.0, default=600.0,
) )
INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key") INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key")
@ -1270,6 +1270,21 @@ class TenantIsolatedTaskQueueConfig(BaseSettings):
) )
class SandboxExpiredRecordsCleanConfig(BaseSettings):
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD: NonNegativeInt = Field(
description="Graceful period in days for sandbox records clean after subscription expiration",
default=21,
)
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE: PositiveInt = Field(
description="Maximum number of records to process in each batch",
default=1000,
)
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS: PositiveInt = Field(
description="Retention days for sandbox expired workflow_run records and message records",
default=30,
)
class FeatureConfig( class FeatureConfig(
# place the configs in alphabet order # place the configs in alphabet order
AppExecutionConfig, AppExecutionConfig,
@ -1295,6 +1310,7 @@ class FeatureConfig(
PositionConfig, PositionConfig,
RagEtlConfig, RagEtlConfig,
RepositoryConfig, RepositoryConfig,
SandboxExpiredRecordsCleanConfig,
SecurityConfig, SecurityConfig,
TenantIsolatedTaskQueueConfig, TenantIsolatedTaskQueueConfig,
ToolConfig, ToolConfig,

View File

@ -41,3 +41,8 @@ class AliyunOSSStorageConfig(BaseSettings):
description="Base path within the bucket to store objects (e.g., 'my-app-data/')", description="Base path within the bucket to store objects (e.g., 'my-app-data/')",
default=None, default=None,
) )
ALIYUN_CLOUDBOX_ID: str | None = Field(
description="Cloudbox id for aliyun cloudbox service",
default=None,
)

View File

@ -26,3 +26,8 @@ class HuaweiCloudOBSStorageConfig(BaseSettings):
description="Endpoint URL for Huawei Cloud OBS (e.g., 'https://obs.cn-north-4.myhuaweicloud.com')", description="Endpoint URL for Huawei Cloud OBS (e.g., 'https://obs.cn-north-4.myhuaweicloud.com')",
default=None, default=None,
) )
HUAWEI_OBS_PATH_STYLE: bool = Field(
description="Flag to indicate whether to use path-style URLs for OBS requests",
default=False,
)

View File

@ -31,3 +31,8 @@ class TencentCloudCOSStorageConfig(BaseSettings):
description="Protocol scheme for COS requests: 'https' (recommended) or 'http'", description="Protocol scheme for COS requests: 'https' (recommended) or 'http'",
default=None, default=None,
) )
TENCENT_COS_CUSTOM_DOMAIN: str | None = Field(
description="Tencent Cloud COS custom domain setting",
default=None,
)

View File

@ -0,0 +1,57 @@
import os
from email.message import Message
from urllib.parse import quote
from flask import Response
# MIME types that browsers render as HTML inline (an XSS vector if user-uploaded
# content is served with one of these types).
HTML_MIME_TYPES = frozenset({"text/html", "application/xhtml+xml"})
# File extensions (compared without a leading dot, lowercased) treated as HTML.
HTML_EXTENSIONS = frozenset({"html", "htm"})
def _normalize_mime_type(mime_type: str | None) -> str:
if not mime_type:
return ""
message = Message()
message["Content-Type"] = mime_type
return message.get_content_type().strip().lower()
def _is_html_extension(extension: str | None) -> bool:
    """Report whether ``extension`` (leading dots ignored, case-insensitive)
    is one of the known HTML extensions. ``None``/empty is never HTML."""
    normalized = (extension or "").lstrip(".").lower()
    return normalized in HTML_EXTENSIONS
def is_html_content(mime_type: str | None, filename: str | None, extension: str | None = None) -> bool:
    """Decide whether the described content should be treated as HTML.

    Checks, in order: the declared MIME type, the explicitly supplied
    ``extension``, and finally the extension derived from ``filename``.
    Returns True on the first positive signal.
    """
    if _normalize_mime_type(mime_type) in HTML_MIME_TYPES:
        return True
    if _is_html_extension(extension):
        return True
    # Fall back to the filename's suffix only when a filename was given.
    return bool(filename) and _is_html_extension(os.path.splitext(filename)[1])
def enforce_download_for_html(
    response: Response,
    *,
    mime_type: str | None,
    filename: str | None,
    extension: str | None = None,
) -> bool:
    """Force HTML content to be downloaded instead of rendered inline.

    When the content is detected as HTML, the ``response`` headers are
    mutated in place (Content-Disposition, Content-Type, nosniff) and
    True is returned; otherwise the response is untouched and False is
    returned.
    """
    if is_html_content(mime_type, filename, extension):
        disposition = "attachment"
        if filename:
            # RFC 5987 percent-encoding keeps non-ASCII filenames intact.
            encoded_filename = quote(filename)
            disposition = f"attachment; filename*=UTF-8''{encoded_filename}"
        response.headers["Content-Disposition"] = disposition
        # Serve as an opaque byte stream and forbid MIME sniffing so the
        # browser cannot decide to render the payload as HTML anyway.
        response.headers["Content-Type"] = "application/octet-stream"
        response.headers["X-Content-Type-Options"] = "nosniff"
        return True
    return False

View File

@ -7,9 +7,9 @@ from controllers.console import console_ns
from controllers.console.error import AlreadyActivateError from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.helper import EmailStr, extract_remote_ip, timezone from libs.helper import EmailStr, timezone
from models import AccountStatus from models import AccountStatus
from services.account_service import AccountService, RegisterService from services.account_service import RegisterService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@ -93,7 +93,6 @@ class ActivateApi(Resource):
"ActivationResponse", "ActivationResponse",
{ {
"result": fields.String(description="Operation result"), "result": fields.String(description="Operation result"),
"data": fields.Raw(description="Login token data"),
}, },
), ),
) )
@ -117,6 +116,4 @@ class ActivateApi(Resource):
account.initialized_at = naive_utc_now() account.initialized_at = naive_utc_now()
db.session.commit() db.session.commit()
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request)) return {"result": "success"}
return {"result": "success", "data": token_pair.model_dump()}

View File

@ -1,8 +1,9 @@
import base64 import base64
from typing import Literal
from flask import request from flask import request
from flask_restx import Resource, fields from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field
from werkzeug.exceptions import BadRequest from werkzeug.exceptions import BadRequest
from controllers.console import console_ns from controllers.console import console_ns
@ -15,22 +16,8 @@ DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class SubscriptionQuery(BaseModel): class SubscriptionQuery(BaseModel):
plan: str = Field(..., description="Subscription plan") plan: Literal[CloudPlan.PROFESSIONAL, CloudPlan.TEAM] = Field(..., description="Subscription plan")
interval: str = Field(..., description="Billing interval") interval: Literal["month", "year"] = Field(..., description="Billing interval")
@field_validator("plan")
@classmethod
def validate_plan(cls, value: str) -> str:
if value not in [CloudPlan.PROFESSIONAL, CloudPlan.TEAM]:
raise ValueError("Invalid plan")
return value
@field_validator("interval")
@classmethod
def validate_interval(cls, value: str) -> str:
if value not in {"month", "year"}:
raise ValueError("Invalid interval")
return value
class PartnerTenantsPayload(BaseModel): class PartnerTenantsPayload(BaseModel):

View File

@ -146,7 +146,7 @@ class DatasetUpdatePayload(BaseModel):
embedding_model: str | None = None embedding_model: str | None = None
embedding_model_provider: str | None = None embedding_model_provider: str | None = None
retrieval_model: dict[str, Any] | None = None retrieval_model: dict[str, Any] | None = None
partial_member_list: list[str] | None = None partial_member_list: list[dict[str, str]] | None = None
external_retrieval_model: dict[str, Any] | None = None external_retrieval_model: dict[str, Any] | None = None
external_knowledge_id: str | None = None external_knowledge_id: str | None = None
external_knowledge_api_id: str | None = None external_knowledge_api_id: str | None = None

View File

@ -572,7 +572,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
datasource_type=DatasourceType.NOTION, datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate( notion_info=NotionInfo.model_validate(
{ {
"credential_id": data_source_info["credential_id"], "credential_id": data_source_info.get("credential_id"),
"notion_workspace_id": data_source_info["notion_workspace_id"], "notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"], "notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"], "notion_page_type": data_source_info["type"],

View File

@ -40,7 +40,7 @@ from .. import console_ns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class CompletionMessagePayload(BaseModel): class CompletionMessageExplorePayload(BaseModel):
inputs: dict[str, Any] inputs: dict[str, Any]
query: str = "" query: str = ""
files: list[dict[str, Any]] | None = None files: list[dict[str, Any]] | None = None
@ -71,7 +71,7 @@ class ChatMessagePayload(BaseModel):
raise ValueError("must be a valid UUID") from exc raise ValueError("must be a valid UUID") from exc
register_schema_models(console_ns, CompletionMessagePayload, ChatMessagePayload) register_schema_models(console_ns, CompletionMessageExplorePayload, ChatMessagePayload)
# define completion api for user # define completion api for user
@ -80,13 +80,13 @@ register_schema_models(console_ns, CompletionMessagePayload, ChatMessagePayload)
endpoint="installed_app_completion", endpoint="installed_app_completion",
) )
class CompletionApi(InstalledAppResource): class CompletionApi(InstalledAppResource):
@console_ns.expect(console_ns.models[CompletionMessagePayload.__name__]) @console_ns.expect(console_ns.models[CompletionMessageExplorePayload.__name__])
def post(self, installed_app): def post(self, installed_app):
app_model = installed_app.app app_model = installed_app.app
if app_model.mode != AppMode.COMPLETION: if app_model.mode != AppMode.COMPLETION:
raise NotCompletionAppError() raise NotCompletionAppError()
payload = CompletionMessagePayload.model_validate(console_ns.payload or {}) payload = CompletionMessageExplorePayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True) args = payload.model_dump(exclude_none=True)
streaming = payload.response_mode == "streaming" streaming = payload.response_mode == "streaming"

View File

@ -1,5 +1,4 @@
from typing import Any from typing import Any
from uuid import UUID
from flask import request from flask import request
from flask_restx import marshal_with from flask_restx import marshal_with
@ -13,6 +12,7 @@ from controllers.console.explore.wraps import InstalledAppResource
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db from extensions.ext_database import db
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import UUIDStrOrEmpty
from libs.login import current_user from libs.login import current_user
from models import Account from models import Account
from models.model import AppMode from models.model import AppMode
@ -24,7 +24,7 @@ from .. import console_ns
class ConversationListQuery(BaseModel): class ConversationListQuery(BaseModel):
last_id: UUID | None = None last_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100) limit: int = Field(default=20, ge=1, le=100)
pinned: bool | None = None pinned: bool | None = None

View File

@ -2,7 +2,8 @@ import logging
from typing import Any from typing import Any
from flask import request from flask import request
from flask_restx import Resource, inputs, marshal_with, reqparse from flask_restx import Resource, marshal_with
from pydantic import BaseModel
from sqlalchemy import and_, select from sqlalchemy import and_, select
from werkzeug.exceptions import BadRequest, Forbidden, NotFound from werkzeug.exceptions import BadRequest, Forbidden, NotFound
@ -18,6 +19,15 @@ from services.account_service import TenantService
from services.enterprise.enterprise_service import EnterpriseService from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService from services.feature_service import FeatureService
class InstalledAppCreatePayload(BaseModel):
app_id: str
class InstalledAppUpdatePayload(BaseModel):
is_pinned: bool | None = None
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -105,26 +115,25 @@ class InstalledAppsListApi(Resource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("apps") @cloud_edition_billing_resource_check("apps")
def post(self): def post(self):
parser = reqparse.RequestParser().add_argument("app_id", type=str, required=True, help="Invalid app_id") payload = InstalledAppCreatePayload.model_validate(console_ns.payload or {})
args = parser.parse_args()
recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == args["app_id"]).first() recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == payload.app_id).first()
if recommended_app is None: if recommended_app is None:
raise NotFound("App not found") raise NotFound("Recommended app not found")
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
app = db.session.query(App).where(App.id == args["app_id"]).first() app = db.session.query(App).where(App.id == payload.app_id).first()
if app is None: if app is None:
raise NotFound("App not found") raise NotFound("App entity not found")
if not app.is_public: if not app.is_public:
raise Forbidden("You can't install a non-public app") raise Forbidden("You can't install a non-public app")
installed_app = ( installed_app = (
db.session.query(InstalledApp) db.session.query(InstalledApp)
.where(and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id)) .where(and_(InstalledApp.app_id == payload.app_id, InstalledApp.tenant_id == current_tenant_id))
.first() .first()
) )
@ -133,7 +142,7 @@ class InstalledAppsListApi(Resource):
recommended_app.install_count += 1 recommended_app.install_count += 1
new_installed_app = InstalledApp( new_installed_app = InstalledApp(
app_id=args["app_id"], app_id=payload.app_id,
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
app_owner_tenant_id=app.tenant_id, app_owner_tenant_id=app.tenant_id,
is_pinned=False, is_pinned=False,
@ -163,12 +172,11 @@ class InstalledAppApi(InstalledAppResource):
return {"result": "success", "message": "App uninstalled successfully"}, 204 return {"result": "success", "message": "App uninstalled successfully"}, 204
def patch(self, installed_app): def patch(self, installed_app):
parser = reqparse.RequestParser().add_argument("is_pinned", type=inputs.boolean) payload = InstalledAppUpdatePayload.model_validate(console_ns.payload or {})
args = parser.parse_args()
commit_args = False commit_args = False
if "is_pinned" in args: if payload.is_pinned is not None:
installed_app.is_pinned = args["is_pinned"] installed_app.is_pinned = payload.is_pinned
commit_args = True commit_args = True
if commit_args: if commit_args:

View File

@ -1,6 +1,5 @@
import logging import logging
from typing import Literal from typing import Literal
from uuid import UUID
from flask import request from flask import request
from flask_restx import marshal_with from flask_restx import marshal_with
@ -26,6 +25,7 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from fields.message_fields import message_infinite_scroll_pagination_fields from fields.message_fields import message_infinite_scroll_pagination_fields
from libs import helper from libs import helper
from libs.helper import UUIDStrOrEmpty
from libs.login import current_account_with_tenant from libs.login import current_account_with_tenant
from models.model import AppMode from models.model import AppMode
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
@ -44,8 +44,8 @@ logger = logging.getLogger(__name__)
class MessageListQuery(BaseModel): class MessageListQuery(BaseModel):
conversation_id: UUID conversation_id: UUIDStrOrEmpty
first_id: UUID | None = None first_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100) limit: int = Field(default=20, ge=1, le=100)

View File

@ -1,5 +1,3 @@
from uuid import UUID
from flask import request from flask import request
from flask_restx import fields, marshal_with from flask_restx import fields, marshal_with
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -10,19 +8,19 @@ from controllers.console import console_ns
from controllers.console.explore.error import NotCompletionAppError from controllers.console.explore.error import NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource from controllers.console.explore.wraps import InstalledAppResource
from fields.conversation_fields import message_file_fields from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField from libs.helper import TimestampField, UUIDStrOrEmpty
from libs.login import current_account_with_tenant from libs.login import current_account_with_tenant
from services.errors.message import MessageNotExistsError from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService from services.saved_message_service import SavedMessageService
class SavedMessageListQuery(BaseModel): class SavedMessageListQuery(BaseModel):
last_id: UUID | None = None last_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100) limit: int = Field(default=20, ge=1, le=100)
class SavedMessageCreatePayload(BaseModel): class SavedMessageCreatePayload(BaseModel):
message_id: UUID message_id: UUIDStrOrEmpty
register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload) register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload)

View File

@ -1,14 +1,32 @@
from flask_restx import Resource, fields, marshal_with, reqparse from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from constants import HIDDEN_VALUE from constants import HIDDEN_VALUE
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from fields.api_based_extension_fields import api_based_extension_fields from fields.api_based_extension_fields import api_based_extension_fields
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from models.api_based_extension import APIBasedExtension from models.api_based_extension import APIBasedExtension
from services.api_based_extension_service import APIBasedExtensionService from services.api_based_extension_service import APIBasedExtensionService
from services.code_based_extension_service import CodeBasedExtensionService from services.code_based_extension_service import CodeBasedExtensionService
from ..common.schema import register_schema_models
from . import console_ns
from .wraps import account_initialization_required, setup_required
class CodeBasedExtensionQuery(BaseModel):
module: str
class APIBasedExtensionPayload(BaseModel):
name: str = Field(description="Extension name")
api_endpoint: str = Field(description="API endpoint URL")
api_key: str = Field(description="API key for authentication")
register_schema_models(console_ns, APIBasedExtensionPayload)
api_based_extension_model = console_ns.model("ApiBasedExtensionModel", api_based_extension_fields) api_based_extension_model = console_ns.model("ApiBasedExtensionModel", api_based_extension_fields)
api_based_extension_list_model = fields.List(fields.Nested(api_based_extension_model)) api_based_extension_list_model = fields.List(fields.Nested(api_based_extension_model))
@ -18,11 +36,7 @@ api_based_extension_list_model = fields.List(fields.Nested(api_based_extension_m
class CodeBasedExtensionAPI(Resource): class CodeBasedExtensionAPI(Resource):
@console_ns.doc("get_code_based_extension") @console_ns.doc("get_code_based_extension")
@console_ns.doc(description="Get code-based extension data by module name") @console_ns.doc(description="Get code-based extension data by module name")
@console_ns.expect( @console_ns.doc(params={"module": "Extension module name"})
console_ns.parser().add_argument(
"module", type=str, required=True, location="args", help="Extension module name"
)
)
@console_ns.response( @console_ns.response(
200, 200,
"Success", "Success",
@ -35,10 +49,9 @@ class CodeBasedExtensionAPI(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
parser = reqparse.RequestParser().add_argument("module", type=str, required=True, location="args") query = CodeBasedExtensionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = parser.parse_args()
return {"module": args["module"], "data": CodeBasedExtensionService.get_code_based_extension(args["module"])} return {"module": query.module, "data": CodeBasedExtensionService.get_code_based_extension(query.module)}
@console_ns.route("/api-based-extension") @console_ns.route("/api-based-extension")
@ -56,30 +69,21 @@ class APIBasedExtensionAPI(Resource):
@console_ns.doc("create_api_based_extension") @console_ns.doc("create_api_based_extension")
@console_ns.doc(description="Create a new API-based extension") @console_ns.doc(description="Create a new API-based extension")
@console_ns.expect( @console_ns.expect(console_ns.models[APIBasedExtensionPayload.__name__])
console_ns.model(
"CreateAPIBasedExtensionRequest",
{
"name": fields.String(required=True, description="Extension name"),
"api_endpoint": fields.String(required=True, description="API endpoint URL"),
"api_key": fields.String(required=True, description="API key for authentication"),
},
)
)
@console_ns.response(201, "Extension created successfully", api_based_extension_model) @console_ns.response(201, "Extension created successfully", api_based_extension_model)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(api_based_extension_model) @marshal_with(api_based_extension_model)
def post(self): def post(self):
args = console_ns.payload payload = APIBasedExtensionPayload.model_validate(console_ns.payload or {})
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
extension_data = APIBasedExtension( extension_data = APIBasedExtension(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
name=args["name"], name=payload.name,
api_endpoint=args["api_endpoint"], api_endpoint=payload.api_endpoint,
api_key=args["api_key"], api_key=payload.api_key,
) )
return APIBasedExtensionService.save(extension_data) return APIBasedExtensionService.save(extension_data)
@ -104,16 +108,7 @@ class APIBasedExtensionDetailAPI(Resource):
@console_ns.doc("update_api_based_extension") @console_ns.doc("update_api_based_extension")
@console_ns.doc(description="Update API-based extension") @console_ns.doc(description="Update API-based extension")
@console_ns.doc(params={"id": "Extension ID"}) @console_ns.doc(params={"id": "Extension ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[APIBasedExtensionPayload.__name__])
console_ns.model(
"UpdateAPIBasedExtensionRequest",
{
"name": fields.String(required=True, description="Extension name"),
"api_endpoint": fields.String(required=True, description="API endpoint URL"),
"api_key": fields.String(required=True, description="API key for authentication"),
},
)
)
@console_ns.response(200, "Extension updated successfully", api_based_extension_model) @console_ns.response(200, "Extension updated successfully", api_based_extension_model)
@setup_required @setup_required
@login_required @login_required
@ -125,13 +120,13 @@ class APIBasedExtensionDetailAPI(Resource):
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id) extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
args = console_ns.payload payload = APIBasedExtensionPayload.model_validate(console_ns.payload or {})
extension_data_from_db.name = args["name"] extension_data_from_db.name = payload.name
extension_data_from_db.api_endpoint = args["api_endpoint"] extension_data_from_db.api_endpoint = payload.api_endpoint
if args["api_key"] != HIDDEN_VALUE: if payload.api_key != HIDDEN_VALUE:
extension_data_from_db.api_key = args["api_key"] extension_data_from_db.api_key = payload.api_key
return APIBasedExtensionService.save(extension_data_from_db) return APIBasedExtensionService.save(extension_data_from_db)

View File

@ -1,31 +1,40 @@
from typing import Literal
from flask import request from flask import request
from flask_restx import Resource, marshal_with, reqparse from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from controllers.common.schema import register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from fields.tag_fields import dataset_tag_fields from fields.tag_fields import dataset_tag_fields
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from models.model import Tag
from services.tag_service import TagService from services.tag_service import TagService
def _validate_name(name): class TagBasePayload(BaseModel):
if not name or len(name) < 1 or len(name) > 50: name: str = Field(description="Tag name", min_length=1, max_length=50)
raise ValueError("Name must be between 1 to 50 characters.") type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
return name
parser_tags = ( class TagBindingPayload(BaseModel):
reqparse.RequestParser() tag_ids: list[str] = Field(description="Tag IDs to bind")
.add_argument( target_id: str = Field(description="Target ID to bind tags to")
"name", type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
nullable=False,
required=True,
help="Name must be between 1 to 50 characters.", class TagBindingRemovePayload(BaseModel):
type=_validate_name, tag_id: str = Field(description="Tag ID to remove")
) target_id: str = Field(description="Target ID to unbind tag from")
.add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.") type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
register_schema_models(
console_ns,
TagBasePayload,
TagBindingPayload,
TagBindingRemovePayload,
) )
@ -43,7 +52,7 @@ class TagListApi(Resource):
return tags, 200 return tags, 200
@console_ns.expect(parser_tags) @console_ns.expect(console_ns.models[TagBasePayload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -53,22 +62,17 @@ class TagListApi(Resource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor): if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden() raise Forbidden()
args = parser_tags.parse_args() payload = TagBasePayload.model_validate(console_ns.payload or {})
tag = TagService.save_tags(args) tag = TagService.save_tags(payload.model_dump())
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
return response, 200 return response, 200
parser_tag_id = reqparse.RequestParser().add_argument(
"name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name
)
@console_ns.route("/tags/<uuid:tag_id>") @console_ns.route("/tags/<uuid:tag_id>")
class TagUpdateDeleteApi(Resource): class TagUpdateDeleteApi(Resource):
@console_ns.expect(parser_tag_id) @console_ns.expect(console_ns.models[TagBasePayload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -79,8 +83,8 @@ class TagUpdateDeleteApi(Resource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor): if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden() raise Forbidden()
args = parser_tag_id.parse_args() payload = TagBasePayload.model_validate(console_ns.payload or {})
tag = TagService.update_tags(args, tag_id) tag = TagService.update_tags(payload.model_dump(), tag_id)
binding_count = TagService.get_tag_binding_count(tag_id) binding_count = TagService.get_tag_binding_count(tag_id)
@ -100,17 +104,9 @@ class TagUpdateDeleteApi(Resource):
return 204 return 204
parser_create = (
reqparse.RequestParser()
.add_argument("tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required.")
.add_argument("target_id", type=str, nullable=False, required=True, location="json", help="Target ID is required.")
.add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.")
)
@console_ns.route("/tag-bindings/create") @console_ns.route("/tag-bindings/create")
class TagBindingCreateApi(Resource): class TagBindingCreateApi(Resource):
@console_ns.expect(parser_create) @console_ns.expect(console_ns.models[TagBindingPayload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -120,23 +116,15 @@ class TagBindingCreateApi(Resource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor): if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden() raise Forbidden()
args = parser_create.parse_args() payload = TagBindingPayload.model_validate(console_ns.payload or {})
TagService.save_tag_binding(args) TagService.save_tag_binding(payload.model_dump())
return {"result": "success"}, 200 return {"result": "success"}, 200
parser_remove = (
reqparse.RequestParser()
.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
.add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.")
)
@console_ns.route("/tag-bindings/remove") @console_ns.route("/tag-bindings/remove")
class TagBindingDeleteApi(Resource): class TagBindingDeleteApi(Resource):
@console_ns.expect(parser_remove) @console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -146,7 +134,7 @@ class TagBindingDeleteApi(Resource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor): if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden() raise Forbidden()
args = parser_remove.parse_args() payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
TagService.delete_tag_binding(args) TagService.delete_tag_binding(payload.model_dump())
return {"result": "success"}, 200 return {"result": "success"}, 200

View File

@ -1,6 +1,8 @@
from flask_restx import Resource, reqparse from flask_restx import Resource
from pydantic import BaseModel
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from controllers.common.schema import register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
@ -10,10 +12,20 @@ from models import TenantAccountRole
from services.model_load_balancing_service import ModelLoadBalancingService from services.model_load_balancing_service import ModelLoadBalancingService
class LoadBalancingCredentialPayload(BaseModel):
model: str
model_type: ModelType
credentials: dict[str, object]
register_schema_models(console_ns, LoadBalancingCredentialPayload)
@console_ns.route( @console_ns.route(
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/credentials-validate" "/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/credentials-validate"
) )
class LoadBalancingCredentialsValidateApi(Resource): class LoadBalancingCredentialsValidateApi(Resource):
@console_ns.expect(console_ns.models[LoadBalancingCredentialPayload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -24,20 +36,7 @@ class LoadBalancingCredentialsValidateApi(Resource):
tenant_id = current_tenant_id tenant_id = current_tenant_id
parser = ( payload = LoadBalancingCredentialPayload.model_validate(console_ns.payload or {})
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
args = parser.parse_args()
# validate model load balancing credentials # validate model load balancing credentials
model_load_balancing_service = ModelLoadBalancingService() model_load_balancing_service = ModelLoadBalancingService()
@ -49,9 +48,9 @@ class LoadBalancingCredentialsValidateApi(Resource):
model_load_balancing_service.validate_load_balancing_credentials( model_load_balancing_service.validate_load_balancing_credentials(
tenant_id=tenant_id, tenant_id=tenant_id,
provider=provider, provider=provider,
model=args["model"], model=payload.model,
model_type=args["model_type"], model_type=payload.model_type,
credentials=args["credentials"], credentials=payload.credentials,
) )
except CredentialsValidateFailedError as ex: except CredentialsValidateFailedError as ex:
result = False result = False
@ -69,6 +68,7 @@ class LoadBalancingCredentialsValidateApi(Resource):
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate" "/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate"
) )
class LoadBalancingConfigCredentialsValidateApi(Resource): class LoadBalancingConfigCredentialsValidateApi(Resource):
@console_ns.expect(console_ns.models[LoadBalancingCredentialPayload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -79,20 +79,7 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
tenant_id = current_tenant_id tenant_id = current_tenant_id
parser = ( payload = LoadBalancingCredentialPayload.model_validate(console_ns.payload or {})
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
args = parser.parse_args()
# validate model load balancing config credentials # validate model load balancing config credentials
model_load_balancing_service = ModelLoadBalancingService() model_load_balancing_service = ModelLoadBalancingService()
@ -104,9 +91,9 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
model_load_balancing_service.validate_load_balancing_credentials( model_load_balancing_service.validate_load_balancing_credentials(
tenant_id=tenant_id, tenant_id=tenant_id,
provider=provider, provider=provider,
model=args["model"], model=payload.model,
model_type=args["model_type"], model_type=payload.model_type,
credentials=args["credentials"], credentials=payload.credentials,
config_id=config_id, config_id=config_id,
) )
except CredentialsValidateFailedError as ex: except CredentialsValidateFailedError as ex:

View File

@ -1,5 +1,6 @@
import io import io
from typing import Literal from collections.abc import Mapping
from typing import Any, Literal
from flask import request, send_file from flask import request, send_file
from flask_restx import Resource from flask_restx import Resource
@ -141,6 +142,15 @@ class ParserDynamicOptions(BaseModel):
provider_type: Literal["tool", "trigger"] provider_type: Literal["tool", "trigger"]
class ParserDynamicOptionsWithCredentials(BaseModel):
plugin_id: str
provider: str
action: str
parameter: str
credential_id: str
credentials: Mapping[str, Any]
class PluginPermissionSettingsPayload(BaseModel): class PluginPermissionSettingsPayload(BaseModel):
install_permission: TenantPluginPermission.InstallPermission = TenantPluginPermission.InstallPermission.EVERYONE install_permission: TenantPluginPermission.InstallPermission = TenantPluginPermission.InstallPermission.EVERYONE
debug_permission: TenantPluginPermission.DebugPermission = TenantPluginPermission.DebugPermission.EVERYONE debug_permission: TenantPluginPermission.DebugPermission = TenantPluginPermission.DebugPermission.EVERYONE
@ -183,6 +193,7 @@ reg(ParserGithubUpgrade)
reg(ParserUninstall) reg(ParserUninstall)
reg(ParserPermissionChange) reg(ParserPermissionChange)
reg(ParserDynamicOptions) reg(ParserDynamicOptions)
reg(ParserDynamicOptionsWithCredentials)
reg(ParserPreferencesChange) reg(ParserPreferencesChange)
reg(ParserExcludePlugin) reg(ParserExcludePlugin)
reg(ParserReadme) reg(ParserReadme)
@ -657,6 +668,37 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
return jsonable_encoder({"options": options}) return jsonable_encoder({"options": options})
@console_ns.route("/workspaces/current/plugin/parameters/dynamic-options-with-credentials")
class PluginFetchDynamicSelectOptionsWithCredentialsApi(Resource):
@console_ns.expect(console_ns.models[ParserDynamicOptionsWithCredentials.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
"""Fetch dynamic options using credentials directly (for edit mode)."""
current_user, tenant_id = current_account_with_tenant()
user_id = current_user.id
args = ParserDynamicOptionsWithCredentials.model_validate(console_ns.payload)
try:
options = PluginParameterService.get_dynamic_select_options_with_credentials(
tenant_id=tenant_id,
user_id=user_id,
plugin_id=args.plugin_id,
provider=args.provider,
action=args.action,
parameter=args.parameter,
credential_id=args.credential_id,
credentials=args.credentials,
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder({"options": options})
@console_ns.route("/workspaces/current/plugin/preferences/change") @console_ns.route("/workspaces/current/plugin/preferences/change")
class PluginChangePreferencesApi(Resource): class PluginChangePreferencesApi(Resource):
@console_ns.expect(console_ns.models[ParserPreferencesChange.__name__]) @console_ns.expect(console_ns.models[ParserPreferencesChange.__name__])

View File

@ -1,4 +1,5 @@
import io import io
import logging
from urllib.parse import urlparse from urllib.parse import urlparse
from flask import make_response, redirect, request, send_file from flask import make_response, redirect, request, send_file
@ -17,6 +18,7 @@ from controllers.console.wraps import (
is_admin_or_owner_required, is_admin_or_owner_required,
setup_required, setup_required,
) )
from core.db.session_factory import session_factory
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
from core.mcp.auth.auth_flow import auth, handle_callback from core.mcp.auth.auth_flow import auth, handle_callback
from core.mcp.error import MCPAuthError, MCPError, MCPRefreshTokenError from core.mcp.error import MCPAuthError, MCPError, MCPRefreshTokenError
@ -39,6 +41,8 @@ from services.tools.tools_manage_service import ToolCommonService
from services.tools.tools_transform_service import ToolTransformService from services.tools.tools_transform_service import ToolTransformService
from services.tools.workflow_tools_manage_service import WorkflowToolManageService from services.tools.workflow_tools_manage_service import WorkflowToolManageService
logger = logging.getLogger(__name__)
def is_valid_url(url: str) -> bool: def is_valid_url(url: str) -> bool:
if not url: if not url:
@ -944,8 +948,8 @@ class ToolProviderMCPApi(Resource):
configuration = MCPConfiguration.model_validate(args["configuration"]) configuration = MCPConfiguration.model_validate(args["configuration"])
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
# Create provider # 1) Create provider in a short transaction (no network I/O inside)
with Session(db.engine) as session, session.begin(): with session_factory.create_session() as session, session.begin():
service = MCPToolManageService(session=session) service = MCPToolManageService(session=session)
result = service.create_provider( result = service.create_provider(
tenant_id=tenant_id, tenant_id=tenant_id,
@ -960,7 +964,29 @@ class ToolProviderMCPApi(Resource):
configuration=configuration, configuration=configuration,
authentication=authentication, authentication=authentication,
) )
return jsonable_encoder(result)
# 2) Try to fetch tools immediately after creation so they appear without a second save.
# Perform network I/O outside any DB session to avoid holding locks.
try:
reconnect = MCPToolManageService.reconnect_with_url(
server_url=args["server_url"],
headers=args.get("headers") or {},
timeout=configuration.timeout,
sse_read_timeout=configuration.sse_read_timeout,
)
# Update just-created provider with authed/tools in a new short transaction
with session_factory.create_session() as session, session.begin():
service = MCPToolManageService(session=session)
db_provider = service.get_provider(provider_id=result.id, tenant_id=tenant_id)
db_provider.authed = reconnect.authed
db_provider.tools = reconnect.tools
result = ToolTransformService.mcp_provider_to_user_provider(db_provider, for_list=True)
except Exception:
# Best-effort: if initial fetch fails (e.g., auth required), return created provider as-is
logger.warning("Failed to fetch MCP tools after creation", exc_info=True)
return jsonable_encoder(result)
@console_ns.expect(parser_mcp_put) @console_ns.expect(parser_mcp_put)
@setup_required @setup_required
@ -972,17 +998,23 @@ class ToolProviderMCPApi(Resource):
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
# Step 1: Validate server URL change if needed (includes URL format validation and network operation) # Step 1: Get provider data for URL validation (short-lived session, no network I/O)
validation_result = None validation_data = None
with Session(db.engine) as session: with Session(db.engine) as session:
service = MCPToolManageService(session=session) service = MCPToolManageService(session=session)
validation_result = service.validate_server_url_change( validation_data = service.get_provider_for_url_validation(
tenant_id=current_tenant_id, provider_id=args["provider_id"], new_server_url=args["server_url"] tenant_id=current_tenant_id, provider_id=args["provider_id"]
) )
# No need to check for errors here, exceptions will be raised directly # Step 2: Perform URL validation with network I/O OUTSIDE of any database session
# This prevents holding database locks during potentially slow network operations
validation_result = MCPToolManageService.validate_server_url_standalone(
tenant_id=current_tenant_id,
new_server_url=args["server_url"],
validation_data=validation_data,
)
# Step 2: Perform database update in a transaction # Step 3: Perform database update in a transaction
with Session(db.engine) as session, session.begin(): with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session) service = MCPToolManageService(session=session)
service.update_provider( service.update_provider(
@ -999,7 +1031,8 @@ class ToolProviderMCPApi(Resource):
authentication=authentication, authentication=authentication,
validation_result=validation_result, validation_result=validation_result,
) )
return {"result": "success"}
return {"result": "success"}
@console_ns.expect(parser_mcp_delete) @console_ns.expect(parser_mcp_delete)
@setup_required @setup_required
@ -1012,7 +1045,8 @@ class ToolProviderMCPApi(Resource):
with Session(db.engine) as session, session.begin(): with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session) service = MCPToolManageService(session=session)
service.delete_provider(tenant_id=current_tenant_id, provider_id=args["provider_id"]) service.delete_provider(tenant_id=current_tenant_id, provider_id=args["provider_id"])
return {"result": "success"}
return {"result": "success"}
parser_auth = ( parser_auth = (

View File

@ -1,11 +1,15 @@
import logging import logging
from collections.abc import Mapping
from typing import Any
from flask import make_response, redirect, request from flask import make_response, redirect, request
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden from werkzeug.exceptions import BadRequest, Forbidden
from configs import dify_config from configs import dify_config
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from controllers.web.error import NotFoundError from controllers.web.error import NotFoundError
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.entities.plugin_daemon import CredentialType
@ -32,6 +36,32 @@ from ..wraps import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Console-facing request schemas for the trigger-subscription endpoints below.
class TriggerSubscriptionUpdateRequest(BaseModel):
    """Request payload for updating a trigger subscription"""

    # All fields optional: a payload carrying only `name` means "rename in place".
    name: str | None = Field(default=None, description="The name for the subscription")
    credentials: Mapping[str, Any] | None = Field(default=None, description="The credentials for the subscription")
    parameters: Mapping[str, Any] | None = Field(default=None, description="The parameters for the subscription")
    properties: Mapping[str, Any] | None = Field(default=None, description="The properties for the subscription")


class TriggerSubscriptionVerifyRequest(BaseModel):
    """Request payload for verifying subscription credentials."""

    credentials: Mapping[str, Any] = Field(description="The credentials to verify")


# Register the JSON schemas with the console namespace so endpoints can
# reference them via @console_ns.expect(console_ns.models[...]).
console_ns.schema_model(
    TriggerSubscriptionUpdateRequest.__name__,
    TriggerSubscriptionUpdateRequest.model_json_schema(ref_template="#/definitions/{model}"),
)
console_ns.schema_model(
    TriggerSubscriptionVerifyRequest.__name__,
    TriggerSubscriptionVerifyRequest.model_json_schema(ref_template="#/definitions/{model}"),
)
@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/icon") @console_ns.route("/workspaces/current/trigger-provider/<path:provider>/icon")
class TriggerProviderIconApi(Resource): class TriggerProviderIconApi(Resource):
@setup_required @setup_required
@ -155,16 +185,16 @@ parser_api = (
@console_ns.route( @console_ns.route(
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/verify/<path:subscription_builder_id>", "/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/verify-and-update/<path:subscription_builder_id>",
) )
class TriggerSubscriptionBuilderVerifyApi(Resource): class TriggerSubscriptionBuilderVerifyAndUpdateApi(Resource):
@console_ns.expect(parser_api) @console_ns.expect(parser_api)
@setup_required @setup_required
@login_required @login_required
@edit_permission_required @edit_permission_required
@account_initialization_required @account_initialization_required
def post(self, provider, subscription_builder_id): def post(self, provider, subscription_builder_id):
"""Verify a subscription instance for a trigger provider""" """Verify and update a subscription instance for a trigger provider"""
user = current_user user = current_user
assert user.current_tenant_id is not None assert user.current_tenant_id is not None
@ -289,6 +319,83 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
raise ValueError(str(e)) from e raise ValueError(str(e)) from e
@console_ns.route(
"/workspaces/current/trigger-provider/<path:subscription_id>/subscriptions/update",
)
class TriggerSubscriptionUpdateApi(Resource):
@console_ns.expect(console_ns.models[TriggerSubscriptionUpdateRequest.__name__])
@setup_required
@login_required
@edit_permission_required
@account_initialization_required
def post(self, subscription_id: str):
"""Update a subscription instance"""
user = current_user
assert user.current_tenant_id is not None
args = TriggerSubscriptionUpdateRequest.model_validate(console_ns.payload)
subscription = TriggerProviderService.get_subscription_by_id(
tenant_id=user.current_tenant_id,
subscription_id=subscription_id,
)
if not subscription:
raise NotFoundError(f"Subscription {subscription_id} not found")
provider_id = TriggerProviderID(subscription.provider_id)
try:
# rename only
if (
args.name is not None
and args.credentials is None
and args.parameters is None
and args.properties is None
):
TriggerProviderService.update_trigger_subscription(
tenant_id=user.current_tenant_id,
subscription_id=subscription_id,
name=args.name,
)
return 200
# rebuild for create automatically by the provider
match subscription.credential_type:
case CredentialType.UNAUTHORIZED:
TriggerProviderService.update_trigger_subscription(
tenant_id=user.current_tenant_id,
subscription_id=subscription_id,
name=args.name,
properties=args.properties,
)
return 200
case CredentialType.API_KEY | CredentialType.OAUTH2:
if args.credentials:
new_credentials: dict[str, Any] = {
key: value if value != HIDDEN_VALUE else subscription.credentials.get(key, UNKNOWN_VALUE)
for key, value in args.credentials.items()
}
else:
new_credentials = subscription.credentials
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=user.current_tenant_id,
name=args.name,
provider_id=provider_id,
subscription_id=subscription_id,
credentials=new_credentials,
parameters=args.parameters or subscription.parameters,
)
return 200
case _:
raise BadRequest("Invalid credential type")
except ValueError as e:
raise BadRequest(str(e))
except Exception as e:
logger.exception("Error updating subscription", exc_info=e)
raise
@console_ns.route( @console_ns.route(
"/workspaces/current/trigger-provider/<path:subscription_id>/subscriptions/delete", "/workspaces/current/trigger-provider/<path:subscription_id>/subscriptions/delete",
) )
@ -576,3 +683,38 @@ class TriggerOAuthClientManageApi(Resource):
except Exception as e: except Exception as e:
logger.exception("Error removing OAuth client", exc_info=e) logger.exception("Error removing OAuth client", exc_info=e)
raise raise
@console_ns.route(
    "/workspaces/current/trigger-provider/<path:provider>/subscriptions/verify/<path:subscription_id>",
)
class TriggerSubscriptionVerifyApi(Resource):
    @console_ns.expect(console_ns.models[TriggerSubscriptionVerifyRequest.__name__])
    @setup_required
    @login_required
    @edit_permission_required
    @account_initialization_required
    def post(self, provider, subscription_id):
        """Verify credentials for an existing subscription (edit mode only)"""
        account = current_user
        assert account.current_tenant_id is not None

        payload = TriggerSubscriptionVerifyRequest.model_validate(console_ns.payload)
        try:
            # The service result is returned to the caller as-is.
            return TriggerProviderService.verify_subscription_credentials(
                tenant_id=account.current_tenant_id,
                user_id=account.id,
                provider_id=TriggerProviderID(provider),
                subscription_id=subscription_id,
                credentials=payload.credentials,
            )
        except ValueError as e:
            # Expected validation failures: log at warning level, report as 400.
            logger.warning("Credential verification failed", exc_info=e)
            raise BadRequest(str(e)) from e
        except Exception as e:
            # Unexpected failures: log with traceback, still surface as 400.
            logger.exception("Error verifying subscription credentials", exc_info=e)
            raise BadRequest(str(e)) from e

View File

@ -7,6 +7,7 @@ from werkzeug.exceptions import NotFound
import services import services
from controllers.common.errors import UnsupportedFileTypeError from controllers.common.errors import UnsupportedFileTypeError
from controllers.common.file_response import enforce_download_for_html
from controllers.files import files_ns from controllers.files import files_ns
from extensions.ext_database import db from extensions.ext_database import db
from services.account_service import TenantService from services.account_service import TenantService
@ -138,6 +139,13 @@ class FilePreviewApi(Resource):
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}" response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
response.headers["Content-Type"] = "application/octet-stream" response.headers["Content-Type"] = "application/octet-stream"
enforce_download_for_html(
response,
mime_type=upload_file.mime_type,
filename=upload_file.name,
extension=upload_file.extension,
)
return response return response

View File

@ -6,6 +6,7 @@ from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
from controllers.common.errors import UnsupportedFileTypeError from controllers.common.errors import UnsupportedFileTypeError
from controllers.common.file_response import enforce_download_for_html
from controllers.files import files_ns from controllers.files import files_ns
from core.tools.signature import verify_tool_file_signature from core.tools.signature import verify_tool_file_signature
from core.tools.tool_file_manager import ToolFileManager from core.tools.tool_file_manager import ToolFileManager
@ -78,4 +79,11 @@ class ToolFileApi(Resource):
encoded_filename = quote(tool_file.name) encoded_filename = quote(tool_file.name)
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}" response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
enforce_download_for_html(
response,
mime_type=tool_file.mimetype,
filename=tool_file.name,
extension=extension,
)
return response return response

View File

@ -4,7 +4,7 @@ from uuid import UUID
from flask import request from flask import request
from flask_restx import Resource from flask_restx import Resource
from flask_restx._http import HTTPStatus from flask_restx._http import HTTPStatus
from pydantic import BaseModel, Field, model_validator from pydantic import BaseModel, Field, field_validator, model_validator
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, NotFound from werkzeug.exceptions import BadRequest, NotFound
@ -51,6 +51,32 @@ class ConversationRenamePayload(BaseModel):
class ConversationVariablesQuery(BaseModel): class ConversationVariablesQuery(BaseModel):
last_id: UUID | None = Field(default=None, description="Last variable ID for pagination") last_id: UUID | None = Field(default=None, description="Last variable ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return") limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return")
variable_name: str | None = Field(
default=None, description="Filter variables by name", min_length=1, max_length=255
)
@field_validator("variable_name", mode="before")
@classmethod
def validate_variable_name(cls, v: str | None) -> str | None:
"""
Validate variable_name to prevent injection attacks.
"""
if v is None:
return v
# Only allow safe characters: alphanumeric, underscore, hyphen, period
if not v.replace("-", "").replace("_", "").replace(".", "").isalnum():
raise ValueError(
"Variable name can only contain letters, numbers, hyphens (-), underscores (_), and periods (.)"
)
# Prevent SQL injection patterns
dangerous_patterns = ["'", '"', ";", "--", "/*", "*/", "xp_", "sp_"]
for pattern in dangerous_patterns:
if pattern in v.lower():
raise ValueError(f"Variable name contains invalid characters: {pattern}")
return v
class ConversationVariableUpdatePayload(BaseModel): class ConversationVariableUpdatePayload(BaseModel):
@ -199,7 +225,7 @@ class ConversationVariablesApi(Resource):
try: try:
return ConversationService.get_conversational_variable( return ConversationService.get_conversational_variable(
app_model, conversation_id, end_user, query_args.limit, last_id app_model, conversation_id, end_user, query_args.limit, last_id, query_args.variable_name
) )
except services.errors.conversation.ConversationNotExistsError: except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")

View File

@ -5,6 +5,7 @@ from flask import Response, request
from flask_restx import Resource from flask_restx import Resource
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from controllers.common.file_response import enforce_download_for_html
from controllers.common.schema import register_schema_model from controllers.common.schema import register_schema_model
from controllers.service_api import service_api_ns from controllers.service_api import service_api_ns
from controllers.service_api.app.error import ( from controllers.service_api.app.error import (
@ -183,6 +184,13 @@ class FilePreviewApi(Resource):
# Override content-type for downloads to force download # Override content-type for downloads to force download
response.headers["Content-Type"] = "application/octet-stream" response.headers["Content-Type"] = "application/octet-stream"
enforce_download_for_html(
response,
mime_type=upload_file.mime_type,
filename=upload_file.name,
extension=upload_file.extension,
)
# Add caching headers for performance # Add caching headers for performance
response.headers["Cache-Control"] = "public, max-age=3600" # Cache for 1 hour response.headers["Cache-Control"] = "public, max-age=3600" # Cache for 1 hour

View File

@ -13,7 +13,6 @@ from controllers.service_api.dataset.error import DatasetInUseError, DatasetName
from controllers.service_api.wraps import ( from controllers.service_api.wraps import (
DatasetApiResource, DatasetApiResource,
cloud_edition_billing_rate_limit_check, cloud_edition_billing_rate_limit_check,
validate_dataset_token,
) )
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager from core.provider_manager import ProviderManager
@ -49,7 +48,7 @@ class DatasetUpdatePayload(BaseModel):
embedding_model: str | None = None embedding_model: str | None = None
embedding_model_provider: str | None = None embedding_model_provider: str | None = None
retrieval_model: RetrievalModel | None = None retrieval_model: RetrievalModel | None = None
partial_member_list: list[str] | None = None partial_member_list: list[dict[str, str]] | None = None
external_retrieval_model: dict[str, Any] | None = None external_retrieval_model: dict[str, Any] | None = None
external_knowledge_id: str | None = None external_knowledge_id: str | None = None
external_knowledge_api_id: str | None = None external_knowledge_api_id: str | None = None
@ -460,9 +459,8 @@ class DatasetTagsApi(DatasetApiResource):
401: "Unauthorized - invalid API token", 401: "Unauthorized - invalid API token",
} }
) )
@validate_dataset_token
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns)) @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
def get(self, _, dataset_id): def get(self, _):
"""Get all knowledge type tags.""" """Get all knowledge type tags."""
assert isinstance(current_user, Account) assert isinstance(current_user, Account)
cid = current_user.current_tenant_id cid = current_user.current_tenant_id
@ -482,8 +480,7 @@ class DatasetTagsApi(DatasetApiResource):
} }
) )
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns)) @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
@validate_dataset_token def post(self, _):
def post(self, _, dataset_id):
"""Add a knowledge type tag.""" """Add a knowledge type tag."""
assert isinstance(current_user, Account) assert isinstance(current_user, Account)
if not (current_user.has_edit_permission or current_user.is_dataset_editor): if not (current_user.has_edit_permission or current_user.is_dataset_editor):
@ -506,8 +503,7 @@ class DatasetTagsApi(DatasetApiResource):
} }
) )
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns)) @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
@validate_dataset_token def patch(self, _):
def patch(self, _, dataset_id):
assert isinstance(current_user, Account) assert isinstance(current_user, Account)
if not (current_user.has_edit_permission or current_user.is_dataset_editor): if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden() raise Forbidden()
@ -533,9 +529,8 @@ class DatasetTagsApi(DatasetApiResource):
403: "Forbidden - insufficient permissions", 403: "Forbidden - insufficient permissions",
} }
) )
@validate_dataset_token
@edit_permission_required @edit_permission_required
def delete(self, _, dataset_id): def delete(self, _):
"""Delete a knowledge type tag.""" """Delete a knowledge type tag."""
payload = TagDeletePayload.model_validate(service_api_ns.payload or {}) payload = TagDeletePayload.model_validate(service_api_ns.payload or {})
TagService.delete_tag(payload.tag_id) TagService.delete_tag(payload.tag_id)
@ -555,8 +550,7 @@ class DatasetTagBindingApi(DatasetApiResource):
403: "Forbidden - insufficient permissions", 403: "Forbidden - insufficient permissions",
} }
) )
@validate_dataset_token def post(self, _):
def post(self, _, dataset_id):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
assert isinstance(current_user, Account) assert isinstance(current_user, Account)
if not (current_user.has_edit_permission or current_user.is_dataset_editor): if not (current_user.has_edit_permission or current_user.is_dataset_editor):
@ -580,8 +574,7 @@ class DatasetTagUnbindingApi(DatasetApiResource):
403: "Forbidden - insufficient permissions", 403: "Forbidden - insufficient permissions",
} }
) )
@validate_dataset_token def post(self, _):
def post(self, _, dataset_id):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
assert isinstance(current_user, Account) assert isinstance(current_user, Account)
if not (current_user.has_edit_permission or current_user.is_dataset_editor): if not (current_user.has_edit_permission or current_user.is_dataset_editor):
@ -604,7 +597,6 @@ class DatasetTagsBindingStatusApi(DatasetApiResource):
401: "Unauthorized - invalid API token", 401: "Unauthorized - invalid API token",
} }
) )
@validate_dataset_token
def get(self, _, *args, **kwargs): def get(self, _, *args, **kwargs):
"""Get all knowledge type tags.""" """Get all knowledge type tags."""
dataset_id = kwargs.get("dataset_id") dataset_id = kwargs.get("dataset_id")

View File

@ -1,14 +1,13 @@
import logging import logging
from flask import request from flask import request
from flask_restx import Resource, marshal_with, reqparse from flask_restx import Resource, marshal_with
from pydantic import BaseModel, ConfigDict, Field
from werkzeug.exceptions import Unauthorized from werkzeug.exceptions import Unauthorized
from constants import HEADER_NAME_APP_CODE from constants import HEADER_NAME_APP_CODE
from controllers.common import fields from controllers.common import fields
from controllers.web import web_ns from controllers.common.schema import register_schema_models
from controllers.web.error import AppUnavailableError
from controllers.web.wraps import WebApiResource
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from libs.passport import PassportService from libs.passport import PassportService
from libs.token import extract_webapp_passport from libs.token import extract_webapp_passport
@ -18,9 +17,23 @@ from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService from services.feature_service import FeatureService
from services.webapp_auth_service import WebAppAuthService from services.webapp_auth_service import WebAppAuthService
from . import web_ns
from .error import AppUnavailableError
from .wraps import WebApiResource
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Query-string parameters for the app access-mode endpoint.
class AppAccessModeQuery(BaseModel):
    # populate_by_name lets internal code use snake_case attribute names while
    # the HTTP layer supplies the camelCase aliases (appId / appCode).
    model_config = ConfigDict(populate_by_name=True)

    app_id: str | None = Field(default=None, alias="appId", description="Application ID")
    app_code: str | None = Field(default=None, alias="appCode", description="Application code")


register_schema_models(web_ns, AppAccessModeQuery)
@web_ns.route("/parameters") @web_ns.route("/parameters")
class AppParameterApi(WebApiResource): class AppParameterApi(WebApiResource):
"""Resource for app variables.""" """Resource for app variables."""
@ -96,21 +109,16 @@ class AppAccessMode(Resource):
} }
) )
def get(self): def get(self):
parser = ( raw_args = request.args.to_dict()
reqparse.RequestParser() args = AppAccessModeQuery.model_validate(raw_args)
.add_argument("appId", type=str, required=False, location="args")
.add_argument("appCode", type=str, required=False, location="args")
)
args = parser.parse_args()
features = FeatureService.get_system_features() features = FeatureService.get_system_features()
if not features.webapp_auth.enabled: if not features.webapp_auth.enabled:
return {"accessMode": "public"} return {"accessMode": "public"}
app_id = args.get("appId") app_id = args.app_id
if args.get("appCode"): if args.app_code:
app_code = args["appCode"] app_id = AppService.get_app_id_by_code(args.app_code)
app_id = AppService.get_app_id_by_code(app_code)
if not app_id: if not app_id:
raise ValueError("appId or appCode must be provided") raise ValueError("appId or appCode must be provided")

View File

@ -1,7 +1,8 @@
import logging import logging
from flask import request from flask import request
from flask_restx import fields, marshal_with, reqparse from flask_restx import fields, marshal_with
from pydantic import BaseModel, field_validator
from werkzeug.exceptions import InternalServerError from werkzeug.exceptions import InternalServerError
import services import services
@ -20,6 +21,7 @@ from controllers.web.error import (
from controllers.web.wraps import WebApiResource from controllers.web.wraps import WebApiResource
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value
from models.model import App from models.model import App
from services.audio_service import AudioService from services.audio_service import AudioService
from services.errors.audio import ( from services.errors.audio import (
@ -29,6 +31,25 @@ from services.errors.audio import (
UnsupportedAudioTypeServiceError, UnsupportedAudioTypeServiceError,
) )
from ..common.schema import register_schema_models
# Request body for the web text-to-audio (TTS) endpoint; every field optional.
class TextToAudioPayload(BaseModel):
    message_id: str | None = None
    voice: str | None = None
    text: str | None = None
    streaming: bool | None = None

    @field_validator("message_id")
    @classmethod
    def validate_message_id(cls, value: str | None) -> str | None:
        # A missing message id is fine; anything else must parse as a UUID.
        return value if value is None else uuid_value(value)


register_schema_models(web_ns, TextToAudioPayload)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -88,6 +109,7 @@ class AudioApi(WebApiResource):
@web_ns.route("/text-to-audio") @web_ns.route("/text-to-audio")
class TextApi(WebApiResource): class TextApi(WebApiResource):
@web_ns.expect(web_ns.models[TextToAudioPayload.__name__])
@web_ns.doc("Text to Audio") @web_ns.doc("Text to Audio")
@web_ns.doc(description="Convert text to audio using text-to-speech service.") @web_ns.doc(description="Convert text to audio using text-to-speech service.")
@web_ns.doc( @web_ns.doc(
@ -102,18 +124,11 @@ class TextApi(WebApiResource):
def post(self, app_model: App, end_user): def post(self, app_model: App, end_user):
"""Convert text to audio""" """Convert text to audio"""
try: try:
parser = ( payload = TextToAudioPayload.model_validate(web_ns.payload or {})
reqparse.RequestParser()
.add_argument("message_id", type=str, required=False, location="json")
.add_argument("voice", type=str, location="json")
.add_argument("text", type=str, location="json")
.add_argument("streaming", type=bool, location="json")
)
args = parser.parse_args()
message_id = args.get("message_id", None) message_id = payload.message_id
text = args.get("text", None) text = payload.text
voice = args.get("voice", None) voice = payload.voice
response = AudioService.transcript_tts( response = AudioService.transcript_tts(
app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id
) )

View File

@ -1,9 +1,11 @@
import logging import logging
from typing import Any, Literal
from flask_restx import reqparse from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import InternalServerError, NotFound from werkzeug.exceptions import InternalServerError, NotFound
import services import services
from controllers.common.schema import register_schema_models
from controllers.web import web_ns from controllers.web import web_ns
from controllers.web.error import ( from controllers.web.error import (
AppUnavailableError, AppUnavailableError,
@ -34,25 +36,44 @@ from services.errors.llm import InvokeRateLimitError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Request body for POST /completion-messages (text-generation apps).
class CompletionMessagePayload(BaseModel):
    inputs: dict[str, Any] = Field(description="Input variables for the completion")
    query: str = Field(default="", description="Query text for completion")
    files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed")
    response_mode: Literal["blocking", "streaming"] | None = Field(
        default=None, description="Response mode: blocking or streaming"
    )
    retriever_from: str = Field(default="web_app", description="Source of retriever")


# Request body for POST /chat-messages (chat / agent-chat / advanced-chat apps).
class ChatMessagePayload(BaseModel):
    inputs: dict[str, Any] = Field(description="Input variables for the chat")
    query: str = Field(description="User query/message")
    files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed")
    response_mode: Literal["blocking", "streaming"] | None = Field(
        default=None, description="Response mode: blocking or streaming"
    )
    conversation_id: str | None = Field(default=None, description="Conversation ID")
    parent_message_id: str | None = Field(default=None, description="Parent message ID")
    retriever_from: str = Field(default="web_app", description="Source of retriever")

    @field_validator("conversation_id", "parent_message_id")
    @classmethod
    def validate_uuid(cls, value: str | None) -> str | None:
        # Both ids are optional: pass None through, otherwise require a UUID.
        return value if value is None else uuid_value(value)


register_schema_models(web_ns, CompletionMessagePayload, ChatMessagePayload)
# define completion api for user # define completion api for user
@web_ns.route("/completion-messages") @web_ns.route("/completion-messages")
class CompletionApi(WebApiResource): class CompletionApi(WebApiResource):
@web_ns.doc("Create Completion Message") @web_ns.doc("Create Completion Message")
@web_ns.doc(description="Create a completion message for text generation applications.") @web_ns.doc(description="Create a completion message for text generation applications.")
@web_ns.doc( @web_ns.expect(web_ns.models[CompletionMessagePayload.__name__])
params={
"inputs": {"description": "Input variables for the completion", "type": "object", "required": True},
"query": {"description": "Query text for completion", "type": "string", "required": False},
"files": {"description": "Files to be processed", "type": "array", "required": False},
"response_mode": {
"description": "Response mode: blocking or streaming",
"type": "string",
"enum": ["blocking", "streaming"],
"required": False,
},
"retriever_from": {"description": "Source of retriever", "type": "string", "required": False},
}
)
@web_ns.doc( @web_ns.doc(
responses={ responses={
200: "Success", 200: "Success",
@ -67,18 +88,10 @@ class CompletionApi(WebApiResource):
if app_model.mode != AppMode.COMPLETION: if app_model.mode != AppMode.COMPLETION:
raise NotCompletionAppError() raise NotCompletionAppError()
parser = ( payload = CompletionMessagePayload.model_validate(web_ns.payload or {})
reqparse.RequestParser() args = payload.model_dump(exclude_none=True)
.add_argument("inputs", type=dict, required=True, location="json")
.add_argument("query", type=str, location="json", default="")
.add_argument("files", type=list, required=False, location="json")
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
.add_argument("retriever_from", type=str, required=False, default="web_app", location="json")
)
args = parser.parse_args() streaming = payload.response_mode == "streaming"
streaming = args["response_mode"] == "streaming"
args["auto_generate_name"] = False args["auto_generate_name"] = False
try: try:
@ -142,22 +155,7 @@ class CompletionStopApi(WebApiResource):
class ChatApi(WebApiResource): class ChatApi(WebApiResource):
@web_ns.doc("Create Chat Message") @web_ns.doc("Create Chat Message")
@web_ns.doc(description="Create a chat message for conversational applications.") @web_ns.doc(description="Create a chat message for conversational applications.")
@web_ns.doc( @web_ns.expect(web_ns.models[ChatMessagePayload.__name__])
params={
"inputs": {"description": "Input variables for the chat", "type": "object", "required": True},
"query": {"description": "User query/message", "type": "string", "required": True},
"files": {"description": "Files to be processed", "type": "array", "required": False},
"response_mode": {
"description": "Response mode: blocking or streaming",
"type": "string",
"enum": ["blocking", "streaming"],
"required": False,
},
"conversation_id": {"description": "Conversation UUID", "type": "string", "required": False},
"parent_message_id": {"description": "Parent message UUID", "type": "string", "required": False},
"retriever_from": {"description": "Source of retriever", "type": "string", "required": False},
}
)
@web_ns.doc( @web_ns.doc(
responses={ responses={
200: "Success", 200: "Success",
@ -173,20 +171,10 @@ class ChatApi(WebApiResource):
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
parser = ( payload = ChatMessagePayload.model_validate(web_ns.payload or {})
reqparse.RequestParser() args = payload.model_dump(exclude_none=True)
.add_argument("inputs", type=dict, required=True, location="json")
.add_argument("query", type=str, required=True, location="json")
.add_argument("files", type=list, required=False, location="json")
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
.add_argument("conversation_id", type=uuid_value, location="json")
.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
.add_argument("retriever_from", type=str, required=False, default="web_app", location="json")
)
args = parser.parse_args() streaming = payload.response_mode == "streaming"
streaming = args["response_mode"] == "streaming"
args["auto_generate_name"] = False args["auto_generate_name"] = False
try: try:

View File

@ -2,10 +2,12 @@ import base64
import secrets import secrets
from flask import request from flask import request
from flask_restx import Resource, reqparse from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_models
from controllers.console.auth.error import ( from controllers.console.auth.error import (
AuthenticationFailedError, AuthenticationFailedError,
EmailCodeError, EmailCodeError,
@ -18,14 +20,40 @@ from controllers.console.error import EmailSendIpLimitError
from controllers.console.wraps import email_password_login_enabled, only_edition_enterprise, setup_required from controllers.console.wraps import email_password_login_enabled, only_edition_enterprise, setup_required
from controllers.web import web_ns from controllers.web import web_ns
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import email, extract_remote_ip from libs.helper import EmailStr, extract_remote_ip
from libs.password import hash_password, valid_password from libs.password import hash_password, valid_password
from models import Account from models import Account
from services.account_service import AccountService from services.account_service import AccountService
class ForgotPasswordSendPayload(BaseModel):
email: EmailStr
language: str | None = None
class ForgotPasswordCheckPayload(BaseModel):
email: EmailStr
code: str
token: str = Field(min_length=1)
class ForgotPasswordResetPayload(BaseModel):
token: str = Field(min_length=1)
new_password: str
password_confirm: str
@field_validator("new_password", "password_confirm")
@classmethod
def validate_password(cls, value: str) -> str:
return valid_password(value)
register_schema_models(web_ns, ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload)
@web_ns.route("/forgot-password") @web_ns.route("/forgot-password")
class ForgotPasswordSendEmailApi(Resource): class ForgotPasswordSendEmailApi(Resource):
@web_ns.expect(web_ns.models[ForgotPasswordSendPayload.__name__])
@only_edition_enterprise @only_edition_enterprise
@setup_required @setup_required
@email_password_login_enabled @email_password_login_enabled
@ -40,35 +68,31 @@ class ForgotPasswordSendEmailApi(Resource):
} }
) )
def post(self): def post(self):
parser = ( payload = ForgotPasswordSendPayload.model_validate(web_ns.payload or {})
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args()
ip_address = extract_remote_ip(request) ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address): if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError() raise EmailSendIpLimitError()
if args["language"] is not None and args["language"] == "zh-Hans": if payload.language == "zh-Hans":
language = "zh-Hans" language = "zh-Hans"
else: else:
language = "en-US" language = "en-US"
with Session(db.engine) as session: with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none() account = session.execute(select(Account).filter_by(email=payload.email)).scalar_one_or_none()
token = None token = None
if account is None: if account is None:
raise AuthenticationFailedError() raise AuthenticationFailedError()
else: else:
token = AccountService.send_reset_password_email(account=account, email=args["email"], language=language) token = AccountService.send_reset_password_email(account=account, email=payload.email, language=language)
return {"result": "success", "data": token} return {"result": "success", "data": token}
@web_ns.route("/forgot-password/validity") @web_ns.route("/forgot-password/validity")
class ForgotPasswordCheckApi(Resource): class ForgotPasswordCheckApi(Resource):
@web_ns.expect(web_ns.models[ForgotPasswordCheckPayload.__name__])
@only_edition_enterprise @only_edition_enterprise
@setup_required @setup_required
@email_password_login_enabled @email_password_login_enabled
@ -78,45 +102,40 @@ class ForgotPasswordCheckApi(Resource):
responses={200: "Token is valid", 400: "Bad request - invalid token format", 401: "Invalid or expired token"} responses={200: "Token is valid", 400: "Bad request - invalid token format", 401: "Invalid or expired token"}
) )
def post(self): def post(self):
parser = ( payload = ForgotPasswordCheckPayload.model_validate(web_ns.payload or {})
reqparse.RequestParser()
.add_argument("email", type=str, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
user_email = args["email"] user_email = payload.email
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args["email"]) is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(payload.email)
if is_forgot_password_error_rate_limit: if is_forgot_password_error_rate_limit:
raise EmailPasswordResetLimitError() raise EmailPasswordResetLimitError()
token_data = AccountService.get_reset_password_data(args["token"]) token_data = AccountService.get_reset_password_data(payload.token)
if token_data is None: if token_data is None:
raise InvalidTokenError() raise InvalidTokenError()
if user_email != token_data.get("email"): if user_email != token_data.get("email"):
raise InvalidEmailError() raise InvalidEmailError()
if args["code"] != token_data.get("code"): if payload.code != token_data.get("code"):
AccountService.add_forgot_password_error_rate_limit(args["email"]) AccountService.add_forgot_password_error_rate_limit(payload.email)
raise EmailCodeError() raise EmailCodeError()
# Verified, revoke the first token # Verified, revoke the first token
AccountService.revoke_reset_password_token(args["token"]) AccountService.revoke_reset_password_token(payload.token)
# Refresh token data by generating a new token # Refresh token data by generating a new token
_, new_token = AccountService.generate_reset_password_token( _, new_token = AccountService.generate_reset_password_token(
user_email, code=args["code"], additional_data={"phase": "reset"} user_email, code=payload.code, additional_data={"phase": "reset"}
) )
AccountService.reset_forgot_password_error_rate_limit(args["email"]) AccountService.reset_forgot_password_error_rate_limit(payload.email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token} return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
@web_ns.route("/forgot-password/resets") @web_ns.route("/forgot-password/resets")
class ForgotPasswordResetApi(Resource): class ForgotPasswordResetApi(Resource):
@web_ns.expect(web_ns.models[ForgotPasswordResetPayload.__name__])
@only_edition_enterprise @only_edition_enterprise
@setup_required @setup_required
@email_password_login_enabled @email_password_login_enabled
@ -131,20 +150,14 @@ class ForgotPasswordResetApi(Resource):
} }
) )
def post(self): def post(self):
parser = ( payload = ForgotPasswordResetPayload.model_validate(web_ns.payload or {})
reqparse.RequestParser()
.add_argument("token", type=str, required=True, nullable=False, location="json")
.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
)
args = parser.parse_args()
# Validate passwords match # Validate passwords match
if args["new_password"] != args["password_confirm"]: if payload.new_password != payload.password_confirm:
raise PasswordMismatchError() raise PasswordMismatchError()
# Validate token and get reset data # Validate token and get reset data
reset_data = AccountService.get_reset_password_data(args["token"]) reset_data = AccountService.get_reset_password_data(payload.token)
if not reset_data: if not reset_data:
raise InvalidTokenError() raise InvalidTokenError()
# Must use token in reset phase # Must use token in reset phase
@ -152,11 +165,11 @@ class ForgotPasswordResetApi(Resource):
raise InvalidTokenError() raise InvalidTokenError()
# Revoke token to prevent reuse # Revoke token to prevent reuse
AccountService.revoke_reset_password_token(args["token"]) AccountService.revoke_reset_password_token(payload.token)
# Generate secure salt and hash password # Generate secure salt and hash password
salt = secrets.token_bytes(16) salt = secrets.token_bytes(16)
password_hashed = hash_password(args["new_password"], salt) password_hashed = hash_password(payload.new_password, salt)
email = reset_data.get("email", "") email = reset_data.get("email", "")
@ -170,7 +183,7 @@ class ForgotPasswordResetApi(Resource):
return {"result": "success"} return {"result": "success"}
def _update_existing_account(self, account, password_hashed, salt, session): def _update_existing_account(self, account: Account, password_hashed, salt, session):
# Update existing account credentials # Update existing account credentials
account.password = base64.b64encode(password_hashed).decode() account.password = base64.b64encode(password_hashed).decode()
account.password_salt = base64.b64encode(salt).decode() account.password_salt = base64.b64encode(salt).decode()

View File

@ -1,9 +1,12 @@
import logging import logging
from typing import Literal
from flask_restx import fields, marshal_with, reqparse from flask import request
from flask_restx.inputs import int_range from flask_restx import fields, marshal_with
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import InternalServerError, NotFound from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.schema import register_schema_models
from controllers.web import web_ns from controllers.web import web_ns
from controllers.web.error import ( from controllers.web.error import (
AppMoreLikeThisDisabledError, AppMoreLikeThisDisabledError,
@ -38,6 +41,33 @@ from services.message_service import MessageService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class MessageListQuery(BaseModel):
conversation_id: str = Field(description="Conversation UUID")
first_id: str | None = Field(default=None, description="First message ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)")
@field_validator("conversation_id", "first_id")
@classmethod
def validate_uuid(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
class MessageFeedbackPayload(BaseModel):
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
content: str | None = Field(default=None, description="Feedback content")
class MessageMoreLikeThisQuery(BaseModel):
response_mode: Literal["blocking", "streaming"] = Field(
description="Response mode",
)
register_schema_models(web_ns, MessageListQuery, MessageFeedbackPayload, MessageMoreLikeThisQuery)
@web_ns.route("/messages") @web_ns.route("/messages")
class MessageListApi(WebApiResource): class MessageListApi(WebApiResource):
message_fields = { message_fields = {
@ -69,7 +99,11 @@ class MessageListApi(WebApiResource):
@web_ns.doc( @web_ns.doc(
params={ params={
"conversation_id": {"description": "Conversation UUID", "type": "string", "required": True}, "conversation_id": {"description": "Conversation UUID", "type": "string", "required": True},
"first_id": {"description": "First message ID for pagination", "type": "string", "required": False}, "first_id": {
"description": "First message ID for pagination",
"type": "string",
"required": False,
},
"limit": { "limit": {
"description": "Number of messages to return (1-100)", "description": "Number of messages to return (1-100)",
"type": "integer", "type": "integer",
@ -94,17 +128,12 @@ class MessageListApi(WebApiResource):
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
parser = ( raw_args = request.args.to_dict()
reqparse.RequestParser() query = MessageListQuery.model_validate(raw_args)
.add_argument("conversation_id", required=True, type=uuid_value, location="args")
.add_argument("first_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
)
args = parser.parse_args()
try: try:
return MessageService.pagination_by_first_id( return MessageService.pagination_by_first_id(
app_model, end_user, args["conversation_id"], args["first_id"], args["limit"] app_model, end_user, query.conversation_id, query.first_id, query.limit
) )
except ConversationNotExistsError: except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")
@ -129,7 +158,7 @@ class MessageFeedbackApi(WebApiResource):
"enum": ["like", "dislike"], "enum": ["like", "dislike"],
"required": False, "required": False,
}, },
"content": {"description": "Feedback content/comment", "type": "string", "required": False}, "content": {"description": "Feedback content", "type": "string", "required": False},
} }
) )
@web_ns.doc( @web_ns.doc(
@ -146,20 +175,15 @@ class MessageFeedbackApi(WebApiResource):
def post(self, app_model, end_user, message_id): def post(self, app_model, end_user, message_id):
message_id = str(message_id) message_id = str(message_id)
parser = ( payload = MessageFeedbackPayload.model_validate(web_ns.payload or {})
reqparse.RequestParser()
.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
.add_argument("content", type=str, location="json", default=None)
)
args = parser.parse_args()
try: try:
MessageService.create_feedback( MessageService.create_feedback(
app_model=app_model, app_model=app_model,
message_id=message_id, message_id=message_id,
user=end_user, user=end_user,
rating=args.get("rating"), rating=payload.rating,
content=args.get("content"), content=payload.content,
) )
except MessageNotExistsError: except MessageNotExistsError:
raise NotFound("Message Not Exists.") raise NotFound("Message Not Exists.")
@ -171,17 +195,7 @@ class MessageFeedbackApi(WebApiResource):
class MessageMoreLikeThisApi(WebApiResource): class MessageMoreLikeThisApi(WebApiResource):
@web_ns.doc("Generate More Like This") @web_ns.doc("Generate More Like This")
@web_ns.doc(description="Generate a new completion similar to an existing message (completion apps only).") @web_ns.doc(description="Generate a new completion similar to an existing message (completion apps only).")
@web_ns.doc( @web_ns.expect(web_ns.models[MessageMoreLikeThisQuery.__name__])
params={
"message_id": {"description": "Message UUID", "type": "string", "required": True},
"response_mode": {
"description": "Response mode",
"type": "string",
"enum": ["blocking", "streaming"],
"required": True,
},
}
)
@web_ns.doc( @web_ns.doc(
responses={ responses={
200: "Success", 200: "Success",
@ -198,12 +212,10 @@ class MessageMoreLikeThisApi(WebApiResource):
message_id = str(message_id) message_id = str(message_id)
parser = reqparse.RequestParser().add_argument( raw_args = request.args.to_dict()
"response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args" query = MessageMoreLikeThisQuery.model_validate(raw_args)
)
args = parser.parse_args()
streaming = args["response_mode"] == "streaming" streaming = query.response_mode == "streaming"
try: try:
response = AppGenerateService.generate_more_like_this( response = AppGenerateService.generate_more_like_this(

View File

@ -1,7 +1,8 @@
import urllib.parse import urllib.parse
import httpx import httpx
from flask_restx import marshal_with, reqparse from flask_restx import marshal_with
from pydantic import BaseModel, Field, HttpUrl
import services import services
from controllers.common import helpers from controllers.common import helpers
@ -10,14 +11,23 @@ from controllers.common.errors import (
RemoteFileUploadError, RemoteFileUploadError,
UnsupportedFileTypeError, UnsupportedFileTypeError,
) )
from controllers.web import web_ns
from controllers.web.wraps import WebApiResource
from core.file import helpers as file_helpers from core.file import helpers as file_helpers
from core.helper import ssrf_proxy from core.helper import ssrf_proxy
from extensions.ext_database import db from extensions.ext_database import db
from fields.file_fields import build_file_with_signed_url_model, build_remote_file_info_model from fields.file_fields import build_file_with_signed_url_model, build_remote_file_info_model
from services.file_service import FileService from services.file_service import FileService
from ..common.schema import register_schema_models
from . import web_ns
from .wraps import WebApiResource
class RemoteFileUploadPayload(BaseModel):
url: HttpUrl = Field(description="Remote file URL")
register_schema_models(web_ns, RemoteFileUploadPayload)
@web_ns.route("/remote-files/<path:url>") @web_ns.route("/remote-files/<path:url>")
class RemoteFileInfoApi(WebApiResource): class RemoteFileInfoApi(WebApiResource):
@ -97,10 +107,8 @@ class RemoteFileUploadApi(WebApiResource):
FileTooLargeError: File exceeds size limit FileTooLargeError: File exceeds size limit
UnsupportedFileTypeError: File type not supported UnsupportedFileTypeError: File type not supported
""" """
parser = reqparse.RequestParser().add_argument("url", type=str, required=True, help="URL is required") payload = RemoteFileUploadPayload.model_validate(web_ns.payload or {})
args = parser.parse_args() url = str(payload.url)
url = args["url"]
try: try:
resp = ssrf_proxy.head(url=url) resp = ssrf_proxy.head(url=url)

View File

@ -1,3 +1,4 @@
import json
from collections.abc import Sequence from collections.abc import Sequence
from enum import StrEnum, auto from enum import StrEnum, auto
from typing import Any, Literal from typing import Any, Literal
@ -120,7 +121,7 @@ class VariableEntity(BaseModel):
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list) allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
allowed_file_extensions: Sequence[str] | None = Field(default_factory=list) allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list) allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
json_schema: dict[str, Any] | None = Field(default=None) json_schema: str | None = Field(default=None)
@field_validator("description", mode="before") @field_validator("description", mode="before")
@classmethod @classmethod
@ -134,11 +135,17 @@ class VariableEntity(BaseModel):
@field_validator("json_schema") @field_validator("json_schema")
@classmethod @classmethod
def validate_json_schema(cls, schema: dict[str, Any] | None) -> dict[str, Any] | None: def validate_json_schema(cls, schema: str | None) -> str | None:
if schema is None: if schema is None:
return None return None
try: try:
Draft7Validator.check_schema(schema) json_schema = json.loads(schema)
except json.JSONDecodeError:
raise ValueError(f"invalid json_schema value {schema}")
try:
Draft7Validator.check_schema(json_schema)
except SchemaError as e: except SchemaError as e:
raise ValueError(f"Invalid JSON schema: {e.message}") raise ValueError(f"Invalid JSON schema: {e.message}")
return schema return schema

View File

@ -1,3 +1,4 @@
import json
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Union, final from typing import TYPE_CHECKING, Any, Union, final
@ -104,8 +105,9 @@ class BaseAppGenerator:
variable_entity.type in {VariableEntityType.FILE, VariableEntityType.FILE_LIST} variable_entity.type in {VariableEntityType.FILE, VariableEntityType.FILE_LIST}
and not variable_entity.required and not variable_entity.required
): ):
# Treat empty string (frontend default) or empty list as unset # Treat empty string (frontend default) as unset
if not value and isinstance(value, (str, list)): # For FILE_LIST, allow empty list [] to pass through
if isinstance(value, str) and not value:
return None return None
if variable_entity.type in { if variable_entity.type in {
@ -175,6 +177,13 @@ class BaseAppGenerator:
value = True value = True
elif value == 0: elif value == 0:
value = False value = False
case VariableEntityType.JSON_OBJECT:
if not isinstance(value, str):
raise ValueError(f"{variable_entity.variable} in input form must be a string")
try:
json.loads(value)
except json.JSONDecodeError:
raise ValueError(f"{variable_entity.variable} in input form must be a valid JSON object")
case _: case _:
raise AssertionError("this statement should be unreachable.") raise AssertionError("this statement should be unreachable.")

View File

@ -90,6 +90,7 @@ class AppQueueManager:
""" """
self._clear_task_belong_cache() self._clear_task_belong_cache()
self._q.put(None) self._q.put(None)
self._graph_runtime_state = None # Release reference to allow GC to reclaim memory
def _clear_task_belong_cache(self) -> None: def _clear_task_belong_cache(self) -> None:
""" """

View File

@ -345,9 +345,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
self._task_state.llm_result.message.content = current_content self._task_state.llm_result.message.content = current_content
if isinstance(event, QueueLLMChunkEvent): if isinstance(event, QueueLLMChunkEvent):
event_type = self._message_cycle_manager.get_message_event_type(message_id=self._message_id)
yield self._message_cycle_manager.message_to_stream_response( yield self._message_cycle_manager.message_to_stream_response(
answer=cast(str, delta_text), answer=cast(str, delta_text),
message_id=self._message_id, message_id=self._message_id,
event_type=event_type,
) )
else: else:
yield self._agent_message_to_stream_response( yield self._agent_message_to_stream_response(

View File

@ -5,7 +5,7 @@ from threading import Thread
from typing import Union from typing import Union
from flask import Flask, current_app from flask import Flask, current_app
from sqlalchemy import select from sqlalchemy import exists, select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from configs import dify_config from configs import dify_config
@ -54,6 +54,20 @@ class MessageCycleManager:
): ):
self._application_generate_entity = application_generate_entity self._application_generate_entity = application_generate_entity
self._task_state = task_state self._task_state = task_state
self._message_has_file: set[str] = set()
def get_message_event_type(self, message_id: str) -> StreamEvent:
if message_id in self._message_has_file:
return StreamEvent.MESSAGE_FILE
with Session(db.engine, expire_on_commit=False) as session:
has_file = session.query(exists().where(MessageFile.message_id == message_id)).scalar()
if has_file:
self._message_has_file.add(message_id)
return StreamEvent.MESSAGE_FILE
return StreamEvent.MESSAGE
def generate_conversation_name(self, *, conversation_id: str, query: str) -> Thread | None: def generate_conversation_name(self, *, conversation_id: str, query: str) -> Thread | None:
""" """
@ -224,6 +238,7 @@ class MessageCycleManager:
tool_arguments: str | None = None, tool_arguments: str | None = None,
tool_files: list[str] | None = None, tool_files: list[str] | None = None,
tool_error: str | None = None, tool_error: str | None = None,
event_type: StreamEvent | None = None,
) -> MessageStreamResponse: ) -> MessageStreamResponse:
""" """
Message to stream response. Message to stream response.
@ -238,22 +253,18 @@ class MessageCycleManager:
:param tool_error: error message if tool failed :param tool_error: error message if tool failed
:return: :return:
""" """
with Session(db.engine, expire_on_commit=False) as session:
message_file = session.scalar(select(MessageFile).where(MessageFile.id == message_id))
event_type = StreamEvent.MESSAGE_FILE if message_file else StreamEvent.MESSAGE
return MessageStreamResponse( return MessageStreamResponse(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
id=message_id, id=message_id,
answer=answer, answer=answer,
from_variable_selector=from_variable_selector, from_variable_selector=from_variable_selector,
event=event_type,
chunk_type=chunk_type, chunk_type=chunk_type,
tool_call_id=tool_call_id, tool_call_id=tool_call_id,
tool_name=tool_name, tool_name=tool_name,
tool_arguments=tool_arguments, tool_arguments=tool_arguments,
tool_files=tool_files, tool_files=tool_files,
tool_error=tool_error, tool_error=tool_error,
event=event_type or StreamEvent.MESSAGE,
) )
def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse: def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:

View File

@ -1,9 +1,14 @@
from collections.abc import Mapping
from textwrap import dedent from textwrap import dedent
from typing import Any
from core.helper.code_executor.template_transformer import TemplateTransformer from core.helper.code_executor.template_transformer import TemplateTransformer
class Jinja2TemplateTransformer(TemplateTransformer): class Jinja2TemplateTransformer(TemplateTransformer):
# Use separate placeholder for base64-encoded template to avoid confusion
_template_b64_placeholder: str = "{{template_b64}}"
@classmethod @classmethod
def transform_response(cls, response: str): def transform_response(cls, response: str):
""" """
@ -13,18 +18,35 @@ class Jinja2TemplateTransformer(TemplateTransformer):
""" """
return {"result": cls.extract_result_str_from_response(response)} return {"result": cls.extract_result_str_from_response(response)}
@classmethod
def assemble_runner_script(cls, code: str, inputs: Mapping[str, Any]) -> str:
"""
Override base class to use base64 encoding for template code.
This prevents issues with special characters (quotes, newlines) in templates
breaking the generated Python script. Fixes #26818.
"""
script = cls.get_runner_script()
# Encode template as base64 to safely embed any content including quotes
code_b64 = cls.serialize_code(code)
script = script.replace(cls._template_b64_placeholder, code_b64)
inputs_str = cls.serialize_inputs(inputs)
script = script.replace(cls._inputs_placeholder, inputs_str)
return script
@classmethod @classmethod
def get_runner_script(cls) -> str: def get_runner_script(cls) -> str:
runner_script = dedent(f""" runner_script = dedent(f"""
# declare main function import jinja2
def main(**inputs):
import jinja2
template = jinja2.Template('''{cls._code_placeholder}''')
return template.render(**inputs)
import json import json
from base64 import b64decode from base64 import b64decode
# declare main function
def main(**inputs):
# Decode base64-encoded template to handle special characters safely
template_code = b64decode('{cls._template_b64_placeholder}').decode('utf-8')
template = jinja2.Template(template_code)
return template.render(**inputs)
# decode and prepare input dict # decode and prepare input dict
inputs_obj = json.loads(b64decode('{cls._inputs_placeholder}').decode('utf-8')) inputs_obj = json.loads(b64decode('{cls._inputs_placeholder}').decode('utf-8'))

View File

@ -13,6 +13,15 @@ class TemplateTransformer(ABC):
_inputs_placeholder: str = "{{inputs}}" _inputs_placeholder: str = "{{inputs}}"
_result_tag: str = "<<RESULT>>" _result_tag: str = "<<RESULT>>"
@classmethod
def serialize_code(cls, code: str) -> str:
"""
Serialize template code to base64 to safely embed in generated script.
This prevents issues with special characters like quotes breaking the script.
"""
code_bytes = code.encode("utf-8")
return b64encode(code_bytes).decode("utf-8")
@classmethod @classmethod
def transform_caller(cls, code: str, inputs: Mapping[str, Any]) -> tuple[str, str]: def transform_caller(cls, code: str, inputs: Mapping[str, Any]) -> tuple[str, str]:
""" """

View File

@ -72,6 +72,22 @@ def _get_ssrf_client(ssl_verify_enabled: bool) -> httpx.Client:
) )
def _get_user_provided_host_header(headers: dict | None) -> str | None:
"""
Extract the user-provided Host header from the headers dict.
This is needed because when using a forward proxy, httpx may override the Host header.
We preserve the user's explicit Host header to support virtual hosting and other use cases.
"""
if not headers:
return None
# Case-insensitive lookup for Host header
for key, value in headers.items():
if key.lower() == "host":
return value
return None
def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
if "allow_redirects" in kwargs: if "allow_redirects" in kwargs:
allow_redirects = kwargs.pop("allow_redirects") allow_redirects = kwargs.pop("allow_redirects")
@ -90,10 +106,24 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
verify_option = kwargs.pop("ssl_verify", dify_config.HTTP_REQUEST_NODE_SSL_VERIFY) verify_option = kwargs.pop("ssl_verify", dify_config.HTTP_REQUEST_NODE_SSL_VERIFY)
client = _get_ssrf_client(verify_option) client = _get_ssrf_client(verify_option)
# Preserve user-provided Host header
# When using a forward proxy, httpx may override the Host header based on the URL.
# We extract and preserve any explicitly set Host header to support virtual hosting.
headers = kwargs.get("headers", {})
user_provided_host = _get_user_provided_host_header(headers)
retries = 0 retries = 0
while retries <= max_retries: while retries <= max_retries:
try: try:
# Build the request manually to preserve the Host header
# httpx may override the Host header when using a proxy, so we use
# the request API to explicitly set headers before sending
headers = {k: v for k, v in headers.items() if k.lower() != "host"}
if user_provided_host is not None:
headers["host"] = user_provided_host
kwargs["headers"] = headers
response = client.request(method=method, url=url, **kwargs) response = client.request(method=method, url=url, **kwargs)
# Check for SSRF protection by Squid proxy # Check for SSRF protection by Squid proxy
if response.status_code in (401, 403): if response.status_code in (401, 403):
# Check if this is a Squid SSRF rejection # Check if this is a Squid SSRF rejection

View File

@ -1,56 +0,0 @@
import json
import logging
from typing import Any
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
from extensions.ext_redis import redis_client, redis_fallback
logger = logging.getLogger(__name__)
class ToolProviderListCache:
"""Cache for tool provider lists"""
CACHE_TTL = 300 # 5 minutes
@staticmethod
def _generate_cache_key(tenant_id: str, typ: ToolProviderTypeApiLiteral = None) -> str:
"""Generate cache key for tool providers list"""
type_filter = typ or "all"
return f"tool_providers:tenant_id:{tenant_id}:type:{type_filter}"
@staticmethod
@redis_fallback(default_return=None)
def get_cached_providers(tenant_id: str, typ: ToolProviderTypeApiLiteral = None) -> list[dict[str, Any]] | None:
"""Get cached tool providers"""
cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
cached_data = redis_client.get(cache_key)
if cached_data:
try:
return json.loads(cached_data.decode("utf-8"))
except (json.JSONDecodeError, UnicodeDecodeError):
logger.warning("Failed to decode cached tool providers data")
return None
return None
@staticmethod
@redis_fallback()
def set_cached_providers(tenant_id: str, typ: ToolProviderTypeApiLiteral, providers: list[dict[str, Any]]):
"""Cache tool providers"""
cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
redis_client.setex(cache_key, ToolProviderListCache.CACHE_TTL, json.dumps(providers))
@staticmethod
@redis_fallback()
def invalidate_cache(tenant_id: str, typ: ToolProviderTypeApiLiteral = None):
"""Invalidate cache for tool providers"""
if typ:
# Invalidate specific type cache
cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
redis_client.delete(cache_key)
else:
# Invalidate all caches for this tenant
pattern = f"tool_providers:tenant_id:{tenant_id}:*"
keys = list(redis_client.scan_iter(pattern))
if keys:
redis_client.delete(*keys)

View File

@ -396,7 +396,7 @@ class IndexingRunner:
datasource_type=DatasourceType.NOTION, datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate( notion_info=NotionInfo.model_validate(
{ {
"credential_id": data_source_info["credential_id"], "credential_id": data_source_info.get("credential_id"),
"notion_workspace_id": data_source_info["notion_workspace_id"], "notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"], "notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"], "notion_page_type": data_source_info["type"],

View File

@ -47,7 +47,11 @@ def build_protected_resource_metadata_discovery_urls(
""" """
Build a list of URLs to try for Protected Resource Metadata discovery. Build a list of URLs to try for Protected Resource Metadata discovery.
Per SEP-985, supports fallback when discovery fails at one URL. Per RFC 9728 Section 5.1, supports fallback when discovery fails at one URL.
Priority order:
1. URL from WWW-Authenticate header (if provided)
2. Well-known URI with path: https://example.com/.well-known/oauth-protected-resource/public/mcp
3. Well-known URI at root: https://example.com/.well-known/oauth-protected-resource
""" """
urls = [] urls = []
@ -58,9 +62,18 @@ def build_protected_resource_metadata_discovery_urls(
# Fallback: construct from server URL # Fallback: construct from server URL
parsed = urlparse(server_url) parsed = urlparse(server_url)
base_url = f"{parsed.scheme}://{parsed.netloc}" base_url = f"{parsed.scheme}://{parsed.netloc}"
fallback_url = urljoin(base_url, "/.well-known/oauth-protected-resource") path = parsed.path.rstrip("/")
if fallback_url not in urls:
urls.append(fallback_url) # Priority 2: With path insertion (e.g., /.well-known/oauth-protected-resource/public/mcp)
if path:
path_url = f"{base_url}/.well-known/oauth-protected-resource{path}"
if path_url not in urls:
urls.append(path_url)
# Priority 3: At root (e.g., /.well-known/oauth-protected-resource)
root_url = f"{base_url}/.well-known/oauth-protected-resource"
if root_url not in urls:
urls.append(root_url)
return urls return urls
@ -71,30 +84,34 @@ def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: st
Supports both OAuth 2.0 (RFC 8414) and OpenID Connect discovery. Supports both OAuth 2.0 (RFC 8414) and OpenID Connect discovery.
Per RFC 8414 section 3: Per RFC 8414 section 3.1 and section 5, try all possible endpoints:
- If issuer has no path: https://example.com/.well-known/oauth-authorization-server - OAuth 2.0 with path insertion: https://example.com/.well-known/oauth-authorization-server/tenant1
- If issuer has path: https://example.com/.well-known/oauth-authorization-server{path} - OpenID Connect with path insertion: https://example.com/.well-known/openid-configuration/tenant1
- OpenID Connect path appending: https://example.com/tenant1/.well-known/openid-configuration
Example: - OAuth 2.0 at root: https://example.com/.well-known/oauth-authorization-server
- issuer: https://example.com/oauth - OpenID Connect at root: https://example.com/.well-known/openid-configuration
- metadata: https://example.com/.well-known/oauth-authorization-server/oauth
""" """
urls = [] urls = []
base_url = auth_server_url or server_url base_url = auth_server_url or server_url
parsed = urlparse(base_url) parsed = urlparse(base_url)
base = f"{parsed.scheme}://{parsed.netloc}" base = f"{parsed.scheme}://{parsed.netloc}"
path = parsed.path.rstrip("/") # Remove trailing slash path = parsed.path.rstrip("/")
# OAuth 2.0 Authorization Server Metadata at root (MCP-03-26)
urls.append(f"{base}/.well-known/oauth-authorization-server")
# Try OpenID Connect discovery first (more common) # OpenID Connect Discovery at root
urls.append(urljoin(base + "/", ".well-known/openid-configuration")) urls.append(f"{base}/.well-known/openid-configuration")
# OAuth 2.0 Authorization Server Metadata (RFC 8414)
# Include the path component if present in the issuer URL
if path: if path:
urls.append(urljoin(base, f".well-known/oauth-authorization-server{path}")) # OpenID Connect Discovery with path insertion
else: urls.append(f"{base}/.well-known/openid-configuration{path}")
urls.append(urljoin(base, ".well-known/oauth-authorization-server"))
# OpenID Connect Discovery path appending
urls.append(f"{base}{path}/.well-known/openid-configuration")
# OAuth 2.0 Authorization Server Metadata with path insertion
urls.append(f"{base}/.well-known/oauth-authorization-server{path}")
return urls return urls

View File

@ -61,6 +61,7 @@ class SSETransport:
self.timeout = timeout self.timeout = timeout
self.sse_read_timeout = sse_read_timeout self.sse_read_timeout = sse_read_timeout
self.endpoint_url: str | None = None self.endpoint_url: str | None = None
self.event_source: EventSource | None = None
def _validate_endpoint_url(self, endpoint_url: str) -> bool: def _validate_endpoint_url(self, endpoint_url: str) -> bool:
"""Validate that the endpoint URL matches the connection origin. """Validate that the endpoint URL matches the connection origin.
@ -237,6 +238,9 @@ class SSETransport:
write_queue: WriteQueue = queue.Queue() write_queue: WriteQueue = queue.Queue()
status_queue: StatusQueue = queue.Queue() status_queue: StatusQueue = queue.Queue()
# Store event_source for graceful shutdown
self.event_source = event_source
# Start SSE reader thread # Start SSE reader thread
executor.submit(self.sse_reader, event_source, read_queue, status_queue) executor.submit(self.sse_reader, event_source, read_queue, status_queue)
@ -296,6 +300,13 @@ def sse_client(
logger.exception("Error connecting to SSE endpoint") logger.exception("Error connecting to SSE endpoint")
raise raise
finally: finally:
# Close the SSE connection to unblock the reader thread
if transport.event_source is not None:
try:
transport.event_source.response.close()
except RuntimeError:
pass
# Clean up queues # Clean up queues
if read_queue: if read_queue:
read_queue.put(None) read_queue.put(None)

View File

@ -8,6 +8,7 @@ and session management.
import logging import logging
import queue import queue
import threading
from collections.abc import Callable, Generator from collections.abc import Callable, Generator
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager from contextlib import contextmanager
@ -103,6 +104,9 @@ class StreamableHTTPTransport:
CONTENT_TYPE: JSON, CONTENT_TYPE: JSON,
**self.headers, **self.headers,
} }
self.stop_event = threading.Event()
self._active_responses: list[httpx.Response] = []
self._lock = threading.Lock()
def _update_headers_with_session(self, base_headers: dict[str, str]) -> dict[str, str]: def _update_headers_with_session(self, base_headers: dict[str, str]) -> dict[str, str]:
"""Update headers with session ID if available.""" """Update headers with session ID if available."""
@ -111,6 +115,30 @@ class StreamableHTTPTransport:
headers[MCP_SESSION_ID] = self.session_id headers[MCP_SESSION_ID] = self.session_id
return headers return headers
def _register_response(self, response: httpx.Response):
"""Register a response for cleanup on shutdown."""
with self._lock:
self._active_responses.append(response)
def _unregister_response(self, response: httpx.Response):
"""Unregister a response after it's closed."""
with self._lock:
try:
self._active_responses.remove(response)
except ValueError as e:
logger.debug("Ignoring error during response unregister: %s", e)
def close_active_responses(self):
"""Close all active SSE connections to unblock threads."""
with self._lock:
responses_to_close = list(self._active_responses)
self._active_responses.clear()
for response in responses_to_close:
try:
response.close()
except RuntimeError as e:
logger.debug("Ignoring error during active response close: %s", e)
def _is_initialization_request(self, message: JSONRPCMessage) -> bool: def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
"""Check if the message is an initialization request.""" """Check if the message is an initialization request."""
return isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize" return isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize"
@ -195,11 +223,21 @@ class StreamableHTTPTransport:
event_source.response.raise_for_status() event_source.response.raise_for_status()
logger.debug("GET SSE connection established") logger.debug("GET SSE connection established")
for sse in event_source.iter_sse(): # Register response for cleanup
self._handle_sse_event(sse, server_to_client_queue) self._register_response(event_source.response)
try:
for sse in event_source.iter_sse():
if self.stop_event.is_set():
logger.debug("GET stream received stop signal")
break
self._handle_sse_event(sse, server_to_client_queue)
finally:
self._unregister_response(event_source.response)
except Exception as exc: except Exception as exc:
logger.debug("GET stream error (non-fatal): %s", exc) if not self.stop_event.is_set():
logger.debug("GET stream error (non-fatal): %s", exc)
def _handle_resumption_request(self, ctx: RequestContext): def _handle_resumption_request(self, ctx: RequestContext):
"""Handle a resumption request using GET with SSE.""" """Handle a resumption request using GET with SSE."""
@ -224,15 +262,24 @@ class StreamableHTTPTransport:
event_source.response.raise_for_status() event_source.response.raise_for_status()
logger.debug("Resumption GET SSE connection established") logger.debug("Resumption GET SSE connection established")
for sse in event_source.iter_sse(): # Register response for cleanup
is_complete = self._handle_sse_event( self._register_response(event_source.response)
sse,
ctx.server_to_client_queue, try:
original_request_id, for sse in event_source.iter_sse():
ctx.metadata.on_resumption_token_update if ctx.metadata else None, if self.stop_event.is_set():
) logger.debug("Resumption stream received stop signal")
if is_complete: break
break is_complete = self._handle_sse_event(
sse,
ctx.server_to_client_queue,
original_request_id,
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
)
if is_complete:
break
finally:
self._unregister_response(event_source.response)
def _handle_post_request(self, ctx: RequestContext): def _handle_post_request(self, ctx: RequestContext):
"""Handle a POST request with response processing.""" """Handle a POST request with response processing."""
@ -266,17 +313,20 @@ class StreamableHTTPTransport:
if is_initialization: if is_initialization:
self._maybe_extract_session_id_from_response(response) self._maybe_extract_session_id_from_response(response)
content_type = cast(str, response.headers.get(CONTENT_TYPE, "").lower()) # Per https://modelcontextprotocol.io/specification/2025-06-18/basic#notifications:
# The server MUST NOT send a response to notifications.
if isinstance(message.root, JSONRPCRequest):
content_type = cast(str, response.headers.get(CONTENT_TYPE, "").lower())
if content_type.startswith(JSON): if content_type.startswith(JSON):
self._handle_json_response(response, ctx.server_to_client_queue) self._handle_json_response(response, ctx.server_to_client_queue)
elif content_type.startswith(SSE): elif content_type.startswith(SSE):
self._handle_sse_response(response, ctx) self._handle_sse_response(response, ctx)
else: else:
self._handle_unexpected_content_type( self._handle_unexpected_content_type(
content_type, content_type,
ctx.server_to_client_queue, ctx.server_to_client_queue,
) )
def _handle_json_response( def _handle_json_response(
self, self,
@ -295,17 +345,27 @@ class StreamableHTTPTransport:
def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext): def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext):
"""Handle SSE response from the server.""" """Handle SSE response from the server."""
try: try:
# Register response for cleanup
self._register_response(response)
event_source = EventSource(response) event_source = EventSource(response)
for sse in event_source.iter_sse(): try:
is_complete = self._handle_sse_event( for sse in event_source.iter_sse():
sse, if self.stop_event.is_set():
ctx.server_to_client_queue, logger.debug("SSE response stream received stop signal")
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None), break
) is_complete = self._handle_sse_event(
if is_complete: sse,
break ctx.server_to_client_queue,
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
)
if is_complete:
break
finally:
self._unregister_response(response)
except Exception as e: except Exception as e:
ctx.server_to_client_queue.put(e) if not self.stop_event.is_set():
ctx.server_to_client_queue.put(e)
def _handle_unexpected_content_type( def _handle_unexpected_content_type(
self, self,
@ -345,6 +405,11 @@ class StreamableHTTPTransport:
""" """
while True: while True:
try: try:
# Check if we should stop
if self.stop_event.is_set():
logger.debug("Post writer received stop signal")
break
# Read message from client queue with timeout to check stop_event periodically # Read message from client queue with timeout to check stop_event periodically
session_message = client_to_server_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT) session_message = client_to_server_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
if session_message is None: if session_message is None:
@ -381,7 +446,8 @@ class StreamableHTTPTransport:
except queue.Empty: except queue.Empty:
continue continue
except Exception as exc: except Exception as exc:
server_to_client_queue.put(exc) if not self.stop_event.is_set():
server_to_client_queue.put(exc)
def terminate_session(self, client: httpx.Client): def terminate_session(self, client: httpx.Client):
"""Terminate the session by sending a DELETE request.""" """Terminate the session by sending a DELETE request."""
@ -465,6 +531,12 @@ def streamablehttp_client(
transport.get_session_id, transport.get_session_id,
) )
finally: finally:
# Set stop event to signal all threads to stop
transport.stop_event.set()
# Close all active SSE connections to unblock threads
transport.close_active_responses()
if transport.session_id and terminate_on_close: if transport.session_id and terminate_on_close:
transport.terminate_session(client) transport.terminate_session(client)

View File

@ -59,7 +59,7 @@ class MCPClient:
try: try:
logger.debug("Not supported method %s found in URL path, trying default 'mcp' method.", method_name) logger.debug("Not supported method %s found in URL path, trying default 'mcp' method.", method_name)
self.connect_server(sse_client, "sse") self.connect_server(sse_client, "sse")
except MCPConnectionError: except (MCPConnectionError, ValueError):
logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.") logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.")
self.connect_server(streamablehttp_client, "mcp") self.connect_server(streamablehttp_client, "mcp")

View File

@ -18,34 +18,20 @@ This module provides the interface for invoking and authenticating various model
- Model provider display - Model provider display
![image-20231210143654461](./docs/en_US/images/index/image-20231210143654461.png) Displays a list of all supported providers, including provider names, icons, supported model types list, predefined model list, configuration method, and credentials form rules, etc.
Displays a list of all supported providers, including provider names, icons, supported model types list, predefined model list, configuration method, and credentials form rules, etc. For detailed rule design, see: [Schema](./docs/en_US/schema.md).
- Selectable model list display - Selectable model list display
![image-20231210144229650](./docs/en_US/images/index/image-20231210144229650.png)
After configuring provider/model credentials, the dropdown (application orchestration interface/default model) allows viewing of the available LLM list. Greyed out items represent predefined model lists from providers without configured credentials, facilitating user review of supported models. After configuring provider/model credentials, the dropdown (application orchestration interface/default model) allows viewing of the available LLM list. Greyed out items represent predefined model lists from providers without configured credentials, facilitating user review of supported models.
In addition, this list also returns configurable parameter information and rules for LLM, as shown below: In addition, this list also returns configurable parameter information and rules for LLM. These parameters are all defined in the backend, allowing different settings for various parameters supported by different models.
![image-20231210144814617](./docs/en_US/images/index/image-20231210144814617.png)
These parameters are all defined in the backend, allowing different settings for various parameters supported by different models, as detailed in: [Schema](./docs/en_US/schema.md#ParameterRule).
- Provider/model credential authentication - Provider/model credential authentication
![image-20231210151548521](./docs/en_US/images/index/image-20231210151548521.png) The provider list returns configuration information for the credentials form, which can be authenticated through Runtime's interface.
![image-20231210151628992](./docs/en_US/images/index/image-20231210151628992.png)
The provider list returns configuration information for the credentials form, which can be authenticated through Runtime's interface. The first image above is a provider credential DEMO, and the second is a model credential DEMO.
## Structure ## Structure
![](./docs/en_US/images/index/image-20231210165243632.png)
Model Runtime is divided into three layers: Model Runtime is divided into three layers:
- The outermost layer is the factory method - The outermost layer is the factory method
@ -60,9 +46,6 @@ Model Runtime is divided into three layers:
It offers direct invocation of various model types, predefined model configuration information, getting predefined/remote model lists, model credential authentication methods. Different models provide additional special methods, like LLM's pre-computed tokens method, cost information obtaining method, etc., **allowing horizontal expansion** for different models under the same provider (within supported model types). It offers direct invocation of various model types, predefined model configuration information, getting predefined/remote model lists, model credential authentication methods. Different models provide additional special methods, like LLM's pre-computed tokens method, cost information obtaining method, etc., **allowing horizontal expansion** for different models under the same provider (within supported model types).
## Next Steps ## Documentation
- Add new provider configuration: [Link](./docs/en_US/provider_scale_out.md) For detailed documentation on how to add new providers or models, please refer to the [Dify documentation](https://docs.dify.ai/).
- Add new models for existing providers: [Link](./docs/en_US/provider_scale_out.md#AddModel)
- View YAML configuration rules: [Link](./docs/en_US/schema.md)
- Implement interface methods: [Link](./docs/en_US/interfaces.md)

View File

@ -18,34 +18,20 @@
- 模型供应商展示 - 模型供应商展示
![image-20231210143654461](./docs/zh_Hans/images/index/image-20231210143654461.png) 展示所有已支持的供应商列表,除了返回供应商名称、图标之外,还提供了支持的模型类型列表,预定义模型列表、配置方式以及配置凭据的表单规则等等。
展示所有已支持的供应商列表,除了返回供应商名称、图标之外,还提供了支持的模型类型列表,预定义模型列表、配置方式以及配置凭据的表单规则等等,规则设计详见:[Schema](./docs/zh_Hans/schema.md)。
- 可选择的模型列表展示 - 可选择的模型列表展示
![image-20231210144229650](./docs/zh_Hans/images/index/image-20231210144229650.png) 配置供应商/模型凭据后,可在此下拉(应用编排界面/默认模型)查看可用的 LLM 列表,其中灰色的为未配置凭据供应商的预定义模型列表,方便用户查看已支持的模型。
配置供应商/模型凭据后,可在此下拉(应用编排界面/默认模型)查看可用的 LLM 列表,其中灰色的为未配置凭据供应商的预定义模型列表,方便用户查看已支持的模型。 除此之外,该列表还返回了 LLM 可配置的参数信息和规则。这里的参数均为后端定义,相比之前只有 5 种固定参数,这里可为不同模型设置所支持的各种参数。
除此之外,该列表还返回了 LLM 可配置的参数信息和规则,如下图:
![image-20231210144814617](./docs/zh_Hans/images/index/image-20231210144814617.png)
这里的参数均为后端定义,相比之前只有 5 种固定参数,这里可为不同模型设置所支持的各种参数,详见:[Schema](./docs/zh_Hans/schema.md#ParameterRule)。
- 供应商/模型凭据鉴权 - 供应商/模型凭据鉴权
![image-20231210151548521](./docs/zh_Hans/images/index/image-20231210151548521.png) 供应商列表返回了凭据表单的配置信息,可通过 Runtime 提供的接口对凭据进行鉴权。
![image-20231210151628992](./docs/zh_Hans/images/index/image-20231210151628992.png)
供应商列表返回了凭据表单的配置信息,可通过 Runtime 提供的接口对凭据进行鉴权,上图 1 为供应商凭据 DEMO上图 2 为模型凭据 DEMO。
## 结构 ## 结构
![](./docs/zh_Hans/images/index/image-20231210165243632.png)
Model Runtime 分三层: Model Runtime 分三层:
- 最外层为工厂方法 - 最外层为工厂方法
@ -59,8 +45,7 @@ Model Runtime 分三层:
对于供应商/模型凭据,有两种情况 对于供应商/模型凭据,有两种情况
- 如 OpenAI 这类中心化供应商,需要定义如**api_key**这类的鉴权凭据 - 如 OpenAI 这类中心化供应商,需要定义如**api_key**这类的鉴权凭据
- 如[**Xinference**](https://github.com/xorbitsai/inference)这类本地部署的供应商,需要定义如**server_url**这类的地址凭据,有时候还需要定义**model_uid**之类的模型类型凭据,就像下面这样,当在供应商层定义了这些凭据后,就可以在前端页面上直接展示,无需修改前端逻辑。 - 如[**Xinference**](https://github.com/xorbitsai/inference)这类本地部署的供应商,需要定义如**server_url**这类的地址凭据,有时候还需要定义**model_uid**之类的模型类型凭据。当在供应商层定义了这些凭据后,就可以在前端页面上直接展示,无需修改前端逻辑。
![Alt text](docs/zh_Hans/images/index/image.png)
当配置好凭据后,就可以通过 DifyRuntime 的外部接口直接获取到对应供应商所需要的**Schema**(凭据表单规则),从而在可以在不修改前端逻辑的情况下,提供新的供应商/模型的支持。 当配置好凭据后,就可以通过 DifyRuntime 的外部接口直接获取到对应供应商所需要的**Schema**(凭据表单规则),从而在可以在不修改前端逻辑的情况下,提供新的供应商/模型的支持。
@ -74,20 +59,6 @@ Model Runtime 分三层:
- 模型凭据 (**在供应商层定义**):这是一类不经常变动,一般在配置好后就不会再变动的参数,如 **api_key**、**server_url** 等。在 DifyRuntime 中,他们的参数名一般为**credentials: dict[str, any]**Provider 层的 credentials 会直接被传递到这一层,不需要再单独定义。 - 模型凭据 (**在供应商层定义**):这是一类不经常变动,一般在配置好后就不会再变动的参数,如 **api_key**、**server_url** 等。在 DifyRuntime 中,他们的参数名一般为**credentials: dict[str, any]**Provider 层的 credentials 会直接被传递到这一层,不需要再单独定义。
## 下一步 ## 文档
### [增加新的供应商配置 👈🏻](./docs/zh_Hans/provider_scale_out.md) 有关如何添加新供应商或模型的详细文档,请参阅 [Dify 文档](https://docs.dify.ai/)。
当添加后,这里将会出现一个新的供应商
![Alt text](docs/zh_Hans/images/index/image-1.png)
### [为已存在的供应商新增模型 👈🏻](./docs/zh_Hans/provider_scale_out.md#%E5%A2%9E%E5%8A%A0%E6%A8%A1%E5%9E%8B)
当添加后,对应供应商的模型列表中将会出现一个新的预定义模型供用户选择,如 GPT-3.5 GPT-4 ChatGLM3-6b 等,而对于支持自定义模型的供应商,则不需要新增模型。
![Alt text](docs/zh_Hans/images/index/image-2.png)
### [接口的具体实现 👈🏻](./docs/zh_Hans/interfaces.md)
你可以在这里找到你想要查看的接口的具体实现,以及接口的参数和返回值的具体含义。

View File

@ -54,7 +54,7 @@ def generate_dotted_order(run_id: str, start_time: Union[str, datetime], parent_
generate dotted_order for langsmith generate dotted_order for langsmith
""" """
start_time = datetime.fromisoformat(start_time) if isinstance(start_time, str) else start_time start_time = datetime.fromisoformat(start_time) if isinstance(start_time, str) else start_time
timestamp = start_time.strftime("%Y%m%dT%H%M%S%f")[:-3] + "Z" timestamp = start_time.strftime("%Y%m%dT%H%M%S%f") + "Z"
current_segment = f"{timestamp}{run_id}" current_segment = f"{timestamp}{run_id}"
if parent_dotted_order is None: if parent_dotted_order is None:

View File

@ -76,7 +76,7 @@ class PluginParameter(BaseModel):
auto_generate: PluginParameterAutoGenerate | None = None auto_generate: PluginParameterAutoGenerate | None = None
template: PluginParameterTemplate | None = None template: PluginParameterTemplate | None = None
required: bool = False required: bool = False
default: Union[float, int, str, bool] | None = None default: Union[float, int, str, bool, list, dict] | None = None
min: Union[float, int] | None = None min: Union[float, int] | None = None
max: Union[float, int] | None = None max: Union[float, int] | None = None
precision: int | None = None precision: int | None = None

View File

@ -39,7 +39,7 @@ from core.trigger.errors import (
plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL)) plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL))
_plugin_daemon_timeout_config = cast( _plugin_daemon_timeout_config = cast(
float | httpx.Timeout | None, float | httpx.Timeout | None,
getattr(dify_config, "PLUGIN_DAEMON_TIMEOUT", 300.0), getattr(dify_config, "PLUGIN_DAEMON_TIMEOUT", 600.0),
) )
plugin_daemon_request_timeout: httpx.Timeout | None plugin_daemon_request_timeout: httpx.Timeout | None
if _plugin_daemon_timeout_config is None: if _plugin_daemon_timeout_config is None:

View File

@ -90,13 +90,17 @@ class Jieba(BaseKeyword):
sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table or {}, query, k) sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table or {}, query, k)
documents = [] documents = []
segment_query_stmt = db.session.query(DocumentSegment).where(
DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id.in_(sorted_chunk_indices)
)
if document_ids_filter:
segment_query_stmt = segment_query_stmt.where(DocumentSegment.document_id.in_(document_ids_filter))
segments = db.session.execute(segment_query_stmt).scalars().all()
segment_map = {segment.index_node_id: segment for segment in segments}
for chunk_index in sorted_chunk_indices: for chunk_index in sorted_chunk_indices:
segment_query = db.session.query(DocumentSegment).where( segment = segment_map.get(chunk_index)
DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index
)
if document_ids_filter:
segment_query = segment_query.where(DocumentSegment.document_id.in_(document_ids_filter))
segment = segment_query.first()
if segment: if segment:
documents.append( documents.append(

View File

@ -1,4 +1,5 @@
import concurrent.futures import concurrent.futures
import logging
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from typing import Any from typing import Any
@ -7,12 +8,13 @@ from sqlalchemy import select
from sqlalchemy.orm import Session, load_only from sqlalchemy.orm import Session, load_only
from configs import dify_config from configs import dify_config
from core.db.session_factory import session_factory
from core.model_manager import ModelManager from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.embedding.retrieval import RetrievalSegments from core.rag.embedding.retrieval import RetrievalChildChunk, RetrievalSegments
from core.rag.entities.metadata_entities import MetadataCondition from core.rag.entities.metadata_entities import MetadataCondition
from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.constant.index_type import IndexStructureType
@ -35,6 +37,8 @@ default_retrieval_model = {
"score_threshold_enabled": False, "score_threshold_enabled": False,
} }
logger = logging.getLogger(__name__)
class RetrievalService: class RetrievalService:
# Cache precompiled regular expressions to avoid repeated compilation # Cache precompiled regular expressions to avoid repeated compilation
@ -105,7 +109,12 @@ class RetrievalService:
) )
) )
concurrent.futures.wait(futures, timeout=3600, return_when=concurrent.futures.ALL_COMPLETED) if futures:
for future in concurrent.futures.as_completed(futures, timeout=3600):
if exceptions:
for f in futures:
f.cancel()
break
if exceptions: if exceptions:
raise ValueError(";\n".join(exceptions)) raise ValueError(";\n".join(exceptions))
@ -138,37 +147,47 @@ class RetrievalService:
@classmethod @classmethod
def _deduplicate_documents(cls, documents: list[Document]) -> list[Document]: def _deduplicate_documents(cls, documents: list[Document]) -> list[Document]:
"""Deduplicate documents based on doc_id to avoid duplicate chunks in hybrid search.""" """Deduplicate documents in O(n) while preserving first-seen order.
Rules:
- For provider == "dify" and metadata["doc_id"] exists: keep the doc with the highest
metadata["score"] among duplicates; if a later duplicate has no score, ignore it.
- For non-dify documents (or dify without doc_id): deduplicate by content key
(provider, page_content), keeping the first occurrence.
"""
if not documents: if not documents:
return documents return documents
unique_documents = [] # Map of dedup key -> chosen Document
seen_doc_ids = set() chosen: dict[tuple, Document] = {}
# Preserve the order of first appearance of each dedup key
order: list[tuple] = []
for document in documents: for doc in documents:
# For dify provider documents, use doc_id for deduplication is_dify = doc.provider == "dify"
if document.provider == "dify" and document.metadata is not None and "doc_id" in document.metadata: doc_id = (doc.metadata or {}).get("doc_id") if is_dify else None
doc_id = document.metadata["doc_id"]
if doc_id not in seen_doc_ids: if is_dify and doc_id:
seen_doc_ids.add(doc_id) key = ("dify", doc_id)
unique_documents.append(document) if key not in chosen:
# If duplicate, keep the one with higher score chosen[key] = doc
elif "score" in document.metadata: order.append(key)
# Find existing document with same doc_id and compare scores else:
for i, existing_doc in enumerate(unique_documents): # Only replace if the new one has a score and it's strictly higher
if ( if "score" in doc.metadata:
existing_doc.metadata new_score = float(doc.metadata.get("score", 0.0))
and existing_doc.metadata.get("doc_id") == doc_id old_score = float(chosen[key].metadata.get("score", 0.0)) if chosen[key].metadata else 0.0
and existing_doc.metadata.get("score", 0) < document.metadata.get("score", 0) if new_score > old_score:
): chosen[key] = doc
unique_documents[i] = document
break
else: else:
# For non-dify documents, use content-based deduplication # Content-based dedup for non-dify or dify without doc_id
if document not in unique_documents: content_key = (doc.provider or "dify", doc.page_content)
unique_documents.append(document) if content_key not in chosen:
chosen[content_key] = doc
order.append(content_key)
# If duplicate content appears, we keep the first occurrence (no score comparison)
return unique_documents return [chosen[k] for k in order]
@classmethod @classmethod
def _get_dataset(cls, dataset_id: str) -> Dataset | None: def _get_dataset(cls, dataset_id: str) -> Dataset | None:
@ -199,6 +218,7 @@ class RetrievalService:
) )
all_documents.extend(documents) all_documents.extend(documents)
except Exception as e: except Exception as e:
logger.error(e, exc_info=True)
exceptions.append(str(e)) exceptions.append(str(e))
@classmethod @classmethod
@ -292,6 +312,7 @@ class RetrievalService:
else: else:
all_documents.extend(documents) all_documents.extend(documents)
except Exception as e: except Exception as e:
logger.error(e, exc_info=True)
exceptions.append(str(e)) exceptions.append(str(e))
@classmethod @classmethod
@ -340,6 +361,7 @@ class RetrievalService:
else: else:
all_documents.extend(documents) all_documents.extend(documents)
except Exception as e: except Exception as e:
logger.error(e, exc_info=True)
exceptions.append(str(e)) exceptions.append(str(e))
@staticmethod @staticmethod
@ -370,171 +392,176 @@ class RetrievalService:
records = [] records = []
include_segment_ids = set() include_segment_ids = set()
segment_child_map = {} segment_child_map = {}
segment_file_map = {}
with Session(bind=db.engine, expire_on_commit=False) as session:
# Process documents
for document in documents:
segment_id = None
attachment_info = None
child_chunk = None
document_id = document.metadata.get("document_id")
if document_id not in dataset_documents:
continue
dataset_document = dataset_documents[document_id] valid_dataset_documents = {}
if not dataset_document: image_doc_ids: list[Any] = []
continue child_index_node_ids = []
index_node_ids = []
doc_to_document_map = {}
for document in documents:
document_id = document.metadata.get("document_id")
if document_id not in dataset_documents:
continue
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: dataset_document = dataset_documents[document_id]
# Handle parent-child documents if not dataset_document:
if document.metadata.get("doc_type") == DocType.IMAGE: continue
attachment_info_dict = cls.get_segment_attachment_info( valid_dataset_documents[document_id] = dataset_document
dataset_document.dataset_id,
dataset_document.tenant_id,
document.metadata.get("doc_id") or "",
session,
)
if attachment_info_dict:
attachment_info = attachment_info_dict["attachment_info"]
segment_id = attachment_info_dict["segment_id"]
else:
child_index_node_id = document.metadata.get("doc_id")
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id)
child_chunk = session.scalar(child_chunk_stmt)
if not child_chunk: if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
continue doc_id = document.metadata.get("doc_id") or ""
segment_id = child_chunk.segment_id doc_to_document_map[doc_id] = document
if document.metadata.get("doc_type") == DocType.IMAGE:
if not segment_id: image_doc_ids.append(doc_id)
continue
segment = (
session.query(DocumentSegment)
.where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.id == segment_id,
)
.first()
)
if not segment:
continue
if segment.id not in include_segment_ids:
include_segment_ids.add(segment.id)
if child_chunk:
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0),
}
map_detail = {
"max_score": document.metadata.get("score", 0.0),
"child_chunks": [child_chunk_detail],
}
segment_child_map[segment.id] = map_detail
record = {
"segment": segment,
}
if attachment_info:
segment_file_map[segment.id] = [attachment_info]
records.append(record)
else:
if child_chunk:
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0),
}
if segment.id in segment_child_map:
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
segment_child_map[segment.id]["max_score"] = max(
segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
)
else:
segment_child_map[segment.id] = {
"max_score": document.metadata.get("score", 0.0),
"child_chunks": [child_chunk_detail],
}
if attachment_info:
if segment.id in segment_file_map:
segment_file_map[segment.id].append(attachment_info)
else:
segment_file_map[segment.id] = [attachment_info]
else: else:
# Handle normal documents child_index_node_ids.append(doc_id)
segment = None else:
if document.metadata.get("doc_type") == DocType.IMAGE: doc_id = document.metadata.get("doc_id") or ""
attachment_info_dict = cls.get_segment_attachment_info( doc_to_document_map[doc_id] = document
dataset_document.dataset_id, if document.metadata.get("doc_type") == DocType.IMAGE:
dataset_document.tenant_id, image_doc_ids.append(doc_id)
document.metadata.get("doc_id") or "", else:
session, index_node_ids.append(doc_id)
)
if attachment_info_dict:
attachment_info = attachment_info_dict["attachment_info"]
segment_id = attachment_info_dict["segment_id"]
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.id == segment_id,
)
segment = session.scalar(document_segment_stmt)
if segment:
segment_file_map[segment.id] = [attachment_info]
else:
index_node_id = document.metadata.get("doc_id")
if not index_node_id:
continue
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
)
segment = session.scalar(document_segment_stmt)
if not segment: image_doc_ids = [i for i in image_doc_ids if i]
continue child_index_node_ids = [i for i in child_index_node_ids if i]
if segment.id not in include_segment_ids: index_node_ids = [i for i in index_node_ids if i]
include_segment_ids.add(segment.id)
record = { segment_ids: list[str] = []
"segment": segment, index_node_segments: list[DocumentSegment] = []
"score": document.metadata.get("score"), # type: ignore segments: list[DocumentSegment] = []
attachment_map: dict[str, list[dict[str, Any]]] = {}
child_chunk_map: dict[str, list[ChildChunk]] = {}
doc_segment_map: dict[str, list[str]] = {}
with session_factory.create_session() as session:
attachments = cls.get_segment_attachment_infos(image_doc_ids, session)
for attachment in attachments:
segment_ids.append(attachment["segment_id"])
if attachment["segment_id"] in attachment_map:
attachment_map[attachment["segment_id"]].append(attachment["attachment_info"])
else:
attachment_map[attachment["segment_id"]] = [attachment["attachment_info"]]
if attachment["segment_id"] in doc_segment_map:
doc_segment_map[attachment["segment_id"]].append(attachment["attachment_id"])
else:
doc_segment_map[attachment["segment_id"]] = [attachment["attachment_id"]]
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id.in_(child_index_node_ids))
child_index_nodes = session.execute(child_chunk_stmt).scalars().all()
for i in child_index_nodes:
segment_ids.append(i.segment_id)
if i.segment_id in child_chunk_map:
child_chunk_map[i.segment_id].append(i)
else:
child_chunk_map[i.segment_id] = [i]
if i.segment_id in doc_segment_map:
doc_segment_map[i.segment_id].append(i.index_node_id)
else:
doc_segment_map[i.segment_id] = [i.index_node_id]
if index_node_ids:
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id.in_(index_node_ids),
)
index_node_segments = session.execute(document_segment_stmt).scalars().all() # type: ignore
for index_node_segment in index_node_segments:
doc_segment_map[index_node_segment.id] = [index_node_segment.index_node_id]
if segment_ids:
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.id.in_(segment_ids),
)
segments = session.execute(document_segment_stmt).scalars().all() # type: ignore
if index_node_segments:
segments.extend(index_node_segments)
for segment in segments:
child_chunks: list[ChildChunk] = child_chunk_map.get(segment.id, [])
attachment_infos: list[dict[str, Any]] = attachment_map.get(segment.id, [])
ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get(segment.document_id)
if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
if segment.id not in include_segment_ids:
include_segment_ids.add(segment.id)
if child_chunks or attachment_infos:
child_chunk_details = []
max_score = 0.0
for child_chunk in child_chunks:
document = doc_to_document_map[child_chunk.index_node_id]
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0) if document else 0.0,
}
child_chunk_details.append(child_chunk_detail)
max_score = max(max_score, document.metadata.get("score", 0.0) if document else 0.0)
for attachment_info in attachment_infos:
file_document = doc_to_document_map[attachment_info["id"]]
max_score = max(
max_score, file_document.metadata.get("score", 0.0) if file_document else 0.0
)
map_detail = {
"max_score": max_score,
"child_chunks": child_chunk_details,
} }
if attachment_info: segment_child_map[segment.id] = map_detail
segment_file_map[segment.id] = [attachment_info] record: dict[str, Any] = {
records.append(record) "segment": segment,
else: }
if attachment_info: records.append(record)
attachment_infos = segment_file_map.get(segment.id, []) else:
if attachment_info not in attachment_infos: if segment.id not in include_segment_ids:
attachment_infos.append(attachment_info) include_segment_ids.add(segment.id)
segment_file_map[segment.id] = attachment_infos max_score = 0.0
segment_document = doc_to_document_map.get(segment.index_node_id)
if segment_document:
max_score = max(max_score, segment_document.metadata.get("score", 0.0))
for attachment_info in attachment_infos:
file_doc = doc_to_document_map.get(attachment_info["id"])
if file_doc:
max_score = max(max_score, file_doc.metadata.get("score", 0.0))
record = {
"segment": segment,
"score": max_score,
}
records.append(record)
# Add child chunks information to records # Add child chunks information to records
for record in records: for record in records:
if record["segment"].id in segment_child_map: if record["segment"].id in segment_child_map:
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore
record["score"] = segment_child_map[record["segment"].id]["max_score"] record["score"] = segment_child_map[record["segment"].id]["max_score"] # type: ignore
if record["segment"].id in segment_file_map: if record["segment"].id in attachment_map:
record["files"] = segment_file_map[record["segment"].id] # type: ignore[assignment] record["files"] = attachment_map[record["segment"].id] # type: ignore[assignment]
result = [] result: list[RetrievalSegments] = []
for record in records: for record in records:
# Extract segment # Extract segment
segment = record["segment"] segment = record["segment"]
# Extract child_chunks, ensuring it's a list or None # Extract child_chunks, ensuring it's a list or None
child_chunks = record.get("child_chunks") raw_child_chunks = record.get("child_chunks")
if not isinstance(child_chunks, list): child_chunks_list: list[RetrievalChildChunk] | None = None
child_chunks = None if isinstance(raw_child_chunks, list):
# Sort by score descending
sorted_chunks = sorted(raw_child_chunks, key=lambda x: x.get("score", 0.0), reverse=True)
child_chunks_list = [
RetrievalChildChunk(
id=chunk["id"],
content=chunk["content"],
score=chunk.get("score", 0.0),
position=chunk["position"],
)
for chunk in sorted_chunks
]
# Extract files, ensuring it's a list or None # Extract files, ensuring it's a list or None
files = record.get("files") files = record.get("files")
@ -551,11 +578,11 @@ class RetrievalService:
# Create RetrievalSegments object # Create RetrievalSegments object
retrieval_segment = RetrievalSegments( retrieval_segment = RetrievalSegments(
segment=segment, child_chunks=child_chunks, score=score, files=files segment=segment, child_chunks=child_chunks_list, score=score, files=files
) )
result.append(retrieval_segment) result.append(retrieval_segment)
return result return sorted(result, key=lambda x: x.score if x.score is not None else 0.0, reverse=True)
except Exception as e: except Exception as e:
db.session.rollback() db.session.rollback()
raise e raise e
@ -565,6 +592,8 @@ class RetrievalService:
flask_app: Flask, flask_app: Flask,
retrieval_method: RetrievalMethod, retrieval_method: RetrievalMethod,
dataset: Dataset, dataset: Dataset,
all_documents: list[Document],
exceptions: list[str],
query: str | None = None, query: str | None = None,
top_k: int = 4, top_k: int = 4,
score_threshold: float | None = 0.0, score_threshold: float | None = 0.0,
@ -573,8 +602,6 @@ class RetrievalService:
weights: dict | None = None, weights: dict | None = None,
document_ids_filter: list[str] | None = None, document_ids_filter: list[str] | None = None,
attachment_id: str | None = None, attachment_id: str | None = None,
all_documents: list[Document] = [],
exceptions: list[str] = [],
): ):
if not query and not attachment_id: if not query and not attachment_id:
return return
@ -647,7 +674,14 @@ class RetrievalService:
document_ids_filter=document_ids_filter, document_ids_filter=document_ids_filter,
) )
) )
concurrent.futures.wait(futures, timeout=300, return_when=concurrent.futures.ALL_COMPLETED) # Use as_completed for early error propagation - cancel remaining futures on first error
if futures:
for future in concurrent.futures.as_completed(futures, timeout=300):
if future.exception():
# Cancel remaining futures to avoid unnecessary waiting
for f in futures:
f.cancel()
break
if exceptions: if exceptions:
raise ValueError(";\n".join(exceptions)) raise ValueError(";\n".join(exceptions))
@ -696,3 +730,37 @@ class RetrievalService:
} }
return {"attachment_info": attachment_info, "segment_id": attachment_binding.segment_id} return {"attachment_info": attachment_info, "segment_id": attachment_binding.segment_id}
return None return None
@classmethod
def get_segment_attachment_infos(cls, attachment_ids: list[str], session: Session) -> list[dict[str, Any]]:
attachment_infos = []
upload_files = session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all()
if upload_files:
upload_file_ids = [upload_file.id for upload_file in upload_files]
attachment_bindings = (
session.query(SegmentAttachmentBinding)
.where(SegmentAttachmentBinding.attachment_id.in_(upload_file_ids))
.all()
)
attachment_binding_map = {binding.attachment_id: binding for binding in attachment_bindings}
if attachment_bindings:
for upload_file in upload_files:
attachment_binding = attachment_binding_map.get(upload_file.id)
attachment_info = {
"id": upload_file.id,
"name": upload_file.name,
"extension": "." + upload_file.extension,
"mime_type": upload_file.mime_type,
"source_url": sign_upload_file(upload_file.id, upload_file.extension),
"size": upload_file.size,
}
if attachment_binding:
attachment_infos.append(
{
"attachment_id": attachment_binding.attachment_id,
"attachment_info": attachment_info,
"segment_id": attachment_binding.segment_id,
}
)
return attachment_infos

View File

@ -289,7 +289,8 @@ class OracleVector(BaseVector):
words = pseg.cut(query) words = pseg.cut(query)
current_entity = "" current_entity = ""
for word, pos in words: for word, pos in words:
if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名ns: 地名nt: 机构名 # `nr`: Person, `ns`: Location, `nt`: Organization
if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}:
current_entity += word current_entity += word
else: else:
if current_entity: if current_entity:

View File

@ -255,7 +255,10 @@ class PGVector(BaseVector):
return return
with self._get_cursor() as cur: with self._get_cursor() as cur:
cur.execute("CREATE EXTENSION IF NOT EXISTS vector") cur.execute("SELECT 1 FROM pg_extension WHERE extname = 'vector'")
if not cur.fetchone():
cur.execute("CREATE EXTENSION vector")
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension)) cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
# PG hnsw index only support 2000 dimension or less # PG hnsw index only support 2000 dimension or less
# ref: https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing # ref: https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing

View File

@ -213,7 +213,7 @@ class VastbaseVector(BaseVector):
with self._get_cursor() as cur: with self._get_cursor() as cur:
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension)) cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
# Vastbase 支持的向量维度取值范围为 [1,16000] # Vastbase supports vector dimensions in the range [1, 16,000]
if dimension <= 16000: if dimension <= 16000:
cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name)) cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
redis_client.set(collection_exist_cache_key, 1, ex=3600) redis_client.set(collection_exist_cache_key, 1, ex=3600)

View File

@ -25,7 +25,7 @@ class FirecrawlApp:
} }
if params: if params:
json_data.update(params) json_data.update(params)
response = self._post_request(f"{self.base_url}/v2/scrape", json_data, headers) response = self._post_request(self._build_url("v2/scrape"), json_data, headers)
if response.status_code == 200: if response.status_code == 200:
response_data = response.json() response_data = response.json()
data = response_data["data"] data = response_data["data"]
@ -42,7 +42,7 @@ class FirecrawlApp:
json_data = {"url": url} json_data = {"url": url}
if params: if params:
json_data.update(params) json_data.update(params)
response = self._post_request(f"{self.base_url}/v2/crawl", json_data, headers) response = self._post_request(self._build_url("v2/crawl"), json_data, headers)
if response.status_code == 200: if response.status_code == 200:
# There's also another two fields in the response: "success" (bool) and "url" (str) # There's also another two fields in the response: "success" (bool) and "url" (str)
job_id = response.json().get("id") job_id = response.json().get("id")
@ -58,7 +58,7 @@ class FirecrawlApp:
if params: if params:
# Pass through provided params, including optional "sitemap": "only" | "include" | "skip" # Pass through provided params, including optional "sitemap": "only" | "include" | "skip"
json_data.update(params) json_data.update(params)
response = self._post_request(f"{self.base_url}/v2/map", json_data, headers) response = self._post_request(self._build_url("v2/map"), json_data, headers)
if response.status_code == 200: if response.status_code == 200:
return cast(dict[str, Any], response.json()) return cast(dict[str, Any], response.json())
elif response.status_code in {402, 409, 500, 429, 408}: elif response.status_code in {402, 409, 500, 429, 408}:
@ -69,7 +69,7 @@ class FirecrawlApp:
def check_crawl_status(self, job_id) -> dict[str, Any]: def check_crawl_status(self, job_id) -> dict[str, Any]:
headers = self._prepare_headers() headers = self._prepare_headers()
response = self._get_request(f"{self.base_url}/v2/crawl/{job_id}", headers) response = self._get_request(self._build_url(f"v2/crawl/{job_id}"), headers)
if response.status_code == 200: if response.status_code == 200:
crawl_status_response = response.json() crawl_status_response = response.json()
if crawl_status_response.get("status") == "completed": if crawl_status_response.get("status") == "completed":
@ -120,6 +120,10 @@ class FirecrawlApp:
def _prepare_headers(self) -> dict[str, Any]: def _prepare_headers(self) -> dict[str, Any]:
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
def _build_url(self, path: str) -> str:
# ensure exactly one slash between base and path, regardless of user-provided base_url
return f"{self.base_url.rstrip('/')}/{path.lstrip('/')}"
def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5) -> httpx.Response: def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5) -> httpx.Response:
for attempt in range(retries): for attempt in range(retries):
response = httpx.post(url, headers=headers, json=data) response = httpx.post(url, headers=headers, json=data)
@ -139,7 +143,11 @@ class FirecrawlApp:
return response return response
def _handle_error(self, response, action): def _handle_error(self, response, action):
error_message = response.json().get("error", "Unknown error occurred") try:
payload = response.json()
error_message = payload.get("error") or payload.get("message") or response.text or "Unknown error occurred"
except json.JSONDecodeError:
error_message = response.text or "Unknown error occurred"
raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}") # type: ignore[return] raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}") # type: ignore[return]
def search(self, query: str, params: dict[str, Any] | None = None) -> dict[str, Any]: def search(self, query: str, params: dict[str, Any] | None = None) -> dict[str, Any]:
@ -160,7 +168,7 @@ class FirecrawlApp:
} }
if params: if params:
json_data.update(params) json_data.update(params)
response = self._post_request(f"{self.base_url}/v2/search", json_data, headers) response = self._post_request(self._build_url("v2/search"), json_data, headers)
if response.status_code == 200: if response.status_code == 200:
response_data = response.json() response_data = response.json()
if not response_data.get("success"): if not response_data.get("success"):

View File

@ -48,13 +48,21 @@ class NotionExtractor(BaseExtractor):
if notion_access_token: if notion_access_token:
self._notion_access_token = notion_access_token self._notion_access_token = notion_access_token
else: else:
self._notion_access_token = self._get_access_token(tenant_id, self._credential_id) try:
if not self._notion_access_token: self._notion_access_token = self._get_access_token(tenant_id, self._credential_id)
except Exception as e:
logger.warning(
(
"Failed to get Notion access token from datasource credentials: %s, "
"falling back to environment variable NOTION_INTEGRATION_TOKEN"
),
e,
)
integration_token = dify_config.NOTION_INTEGRATION_TOKEN integration_token = dify_config.NOTION_INTEGRATION_TOKEN
if integration_token is None: if integration_token is None:
raise ValueError( raise ValueError(
"Must specify `integration_token` or set environment variable `NOTION_INTEGRATION_TOKEN`." "Must specify `integration_token` or set environment variable `NOTION_INTEGRATION_TOKEN`."
) ) from e
self._notion_access_token = integration_token self._notion_access_token = integration_token

View File

@ -83,6 +83,7 @@ class WordExtractor(BaseExtractor):
def _extract_images_from_docx(self, doc): def _extract_images_from_docx(self, doc):
image_count = 0 image_count = 0
image_map = {} image_map = {}
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
for r_id, rel in doc.part.rels.items(): for r_id, rel in doc.part.rels.items():
if "image" in rel.target_ref: if "image" in rel.target_ref:
@ -121,8 +122,7 @@ class WordExtractor(BaseExtractor):
used_at=naive_utc_now(), used_at=naive_utc_now(),
) )
db.session.add(upload_file) db.session.add(upload_file)
# Use r_id as key for external images since target_part is undefined image_map[r_id] = f"![image]({base_url}/files/{upload_file.id}/file-preview)"
image_map[r_id] = f"![image]({dify_config.FILES_URL}/files/{upload_file.id}/file-preview)"
else: else:
image_ext = rel.target_ref.split(".")[-1] image_ext = rel.target_ref.split(".")[-1]
if image_ext is None: if image_ext is None:
@ -150,10 +150,7 @@ class WordExtractor(BaseExtractor):
used_at=naive_utc_now(), used_at=naive_utc_now(),
) )
db.session.add(upload_file) db.session.add(upload_file)
# Use target_part as key for internal images image_map[rel.target_part] = f"![image]({base_url}/files/{upload_file.id}/file-preview)"
image_map[rel.target_part] = (
f"![image]({dify_config.FILES_URL}/files/{upload_file.id}/file-preview)"
)
db.session.commit() db.session.commit()
return image_map return image_map

View File

@ -231,7 +231,7 @@ class BaseIndexProcessor(ABC):
if not filename: if not filename:
parsed_url = urlparse(image_url) parsed_url = urlparse(image_url)
# unquote 处理 URL 中的中文 # Decode percent-encoded characters in the URL path.
path = unquote(parsed_url.path) path = unquote(parsed_url.path)
filename = os.path.basename(path) filename = os.path.basename(path)

View File

@ -7,7 +7,7 @@ from collections.abc import Generator, Mapping
from typing import Any, Union, cast from typing import Any, Union, cast
from flask import Flask, current_app from flask import Flask, current_app
from sqlalchemy import and_, or_, select from sqlalchemy import and_, literal, or_, select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from core.app.app_config.entities import ( from core.app.app_config.entities import (
@ -151,20 +151,14 @@ class DatasetRetrieval:
if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features: if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features:
planning_strategy = PlanningStrategy.ROUTER planning_strategy = PlanningStrategy.ROUTER
available_datasets = [] available_datasets = []
for dataset_id in dataset_ids:
# get dataset from dataset id
dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
dataset = db.session.scalar(dataset_stmt)
# pass if dataset is not available dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id.in_(dataset_ids))
if not dataset: datasets: list[Dataset] = db.session.execute(dataset_stmt).scalars().all() # type: ignore
for dataset in datasets:
if dataset.available_document_count == 0 and dataset.provider != "external":
continue continue
# pass if dataset is not available
if dataset and dataset.available_document_count == 0 and dataset.provider != "external":
continue
available_datasets.append(dataset) available_datasets.append(dataset)
if inputs: if inputs:
inputs = {key: str(value) for key, value in inputs.items()} inputs = {key: str(value) for key, value in inputs.items()}
else: else:
@ -282,26 +276,35 @@ class DatasetRetrieval:
) )
context_files.append(attachment_info) context_files.append(attachment_info)
if show_retrieve_source: if show_retrieve_source:
dataset_ids = [record.segment.dataset_id for record in records]
document_ids = [record.segment.document_id for record in records]
dataset_document_stmt = select(DatasetDocument).where(
DatasetDocument.id.in_(document_ids),
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
documents = db.session.execute(dataset_document_stmt).scalars().all() # type: ignore
dataset_stmt = select(Dataset).where(
Dataset.id.in_(dataset_ids),
)
datasets = db.session.execute(dataset_stmt).scalars().all() # type: ignore
dataset_map = {i.id: i for i in datasets}
document_map = {i.id: i for i in documents}
for record in records: for record in records:
segment = record.segment segment = record.segment
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() dataset_item = dataset_map.get(segment.dataset_id)
dataset_document_stmt = select(DatasetDocument).where( document_item = document_map.get(segment.document_id)
DatasetDocument.id == segment.document_id, if dataset_item and document_item:
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
document = db.session.scalar(dataset_document_stmt)
if dataset and document:
source = RetrievalSourceMetadata( source = RetrievalSourceMetadata(
dataset_id=dataset.id, dataset_id=dataset_item.id,
dataset_name=dataset.name, dataset_name=dataset_item.name,
document_id=document.id, document_id=document_item.id,
document_name=document.name, document_name=document_item.name,
data_source_type=document.data_source_type, data_source_type=document_item.data_source_type,
segment_id=segment.id, segment_id=segment.id,
retriever_from=invoke_from.to_source(), retriever_from=invoke_from.to_source(),
score=record.score or 0.0, score=record.score or 0.0,
doc_metadata=document.doc_metadata, doc_metadata=document_item.doc_metadata,
) )
if invoke_from.to_source() == "dev": if invoke_from.to_source() == "dev":
@ -513,6 +516,9 @@ class DatasetRetrieval:
].embedding_model_provider ].embedding_model_provider
weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
with measure_time() as timer: with measure_time() as timer:
cancel_event = threading.Event()
thread_exceptions: list[Exception] = []
if query: if query:
query_thread = threading.Thread( query_thread = threading.Thread(
target=self._multiple_retrieve_thread, target=self._multiple_retrieve_thread,
@ -531,6 +537,8 @@ class DatasetRetrieval:
"score_threshold": score_threshold, "score_threshold": score_threshold,
"query": query, "query": query,
"attachment_id": None, "attachment_id": None,
"cancel_event": cancel_event,
"thread_exceptions": thread_exceptions,
}, },
) )
all_threads.append(query_thread) all_threads.append(query_thread)
@ -554,12 +562,25 @@ class DatasetRetrieval:
"score_threshold": score_threshold, "score_threshold": score_threshold,
"query": None, "query": None,
"attachment_id": attachment_id, "attachment_id": attachment_id,
"cancel_event": cancel_event,
"thread_exceptions": thread_exceptions,
}, },
) )
all_threads.append(attachment_thread) all_threads.append(attachment_thread)
attachment_thread.start() attachment_thread.start()
for thread in all_threads:
thread.join() # Poll threads with short timeout to detect errors quickly (fail-fast)
while any(t.is_alive() for t in all_threads):
for thread in all_threads:
thread.join(timeout=0.1)
if thread_exceptions:
cancel_event.set()
break
if thread_exceptions:
break
if thread_exceptions:
raise thread_exceptions[0]
self._on_query(query, attachment_ids, dataset_ids, app_id, user_from, user_id) self._on_query(query, attachment_ids, dataset_ids, app_id, user_from, user_id)
if all_documents: if all_documents:
@ -1033,7 +1054,7 @@ class DatasetRetrieval:
if automatic_metadata_filters: if automatic_metadata_filters:
conditions = [] conditions = []
for sequence, filter in enumerate(automatic_metadata_filters): for sequence, filter in enumerate(automatic_metadata_filters):
self._process_metadata_filter_func( self.process_metadata_filter_func(
sequence, sequence,
filter.get("condition"), # type: ignore filter.get("condition"), # type: ignore
filter.get("metadata_name"), # type: ignore filter.get("metadata_name"), # type: ignore
@ -1069,7 +1090,7 @@ class DatasetRetrieval:
value=expected_value, value=expected_value,
) )
) )
filters = self._process_metadata_filter_func( filters = self.process_metadata_filter_func(
sequence, sequence,
condition.comparison_operator, condition.comparison_operator,
metadata_name, metadata_name,
@ -1165,8 +1186,9 @@ class DatasetRetrieval:
return None return None
return automatic_metadata_filters return automatic_metadata_filters
def _process_metadata_filter_func( @classmethod
self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list def process_metadata_filter_func(
cls, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
): ):
if value is None and condition not in ("empty", "not empty"): if value is None and condition not in ("empty", "not empty"):
return filters return filters
@ -1215,6 +1237,20 @@ class DatasetRetrieval:
case "" | ">=": case "" | ">=":
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() >= value) filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() >= value)
case "in" | "not in":
if isinstance(value, str):
value_list = [v.strip() for v in value.split(",") if v.strip()]
elif isinstance(value, (list, tuple)):
value_list = [str(v) for v in value if v is not None]
else:
value_list = [str(value)] if value is not None else []
if not value_list:
# `field in []` is False, `field not in []` is True
filters.append(literal(condition == "not in"))
else:
op = json_field.in_ if condition == "in" else json_field.notin_
filters.append(op(value_list))
case _: case _:
pass pass
@ -1386,40 +1422,53 @@ class DatasetRetrieval:
score_threshold: float, score_threshold: float,
query: str | None, query: str | None,
attachment_id: str | None, attachment_id: str | None,
cancel_event: threading.Event | None = None,
thread_exceptions: list[Exception] | None = None,
): ):
with flask_app.app_context(): try:
threads = [] with flask_app.app_context():
all_documents_item: list[Document] = [] threads = []
index_type = None all_documents_item: list[Document] = []
for dataset in available_datasets: index_type = None
index_type = dataset.indexing_technique for dataset in available_datasets:
document_ids_filter = None # Check for cancellation signal
if dataset.provider != "external": if cancel_event and cancel_event.is_set():
if metadata_condition and not metadata_filter_document_ids: break
continue index_type = dataset.indexing_technique
if metadata_filter_document_ids: document_ids_filter = None
document_ids = metadata_filter_document_ids.get(dataset.id, []) if dataset.provider != "external":
if document_ids: if metadata_condition and not metadata_filter_document_ids:
document_ids_filter = document_ids
else:
continue continue
retrieval_thread = threading.Thread( if metadata_filter_document_ids:
target=self._retriever, document_ids = metadata_filter_document_ids.get(dataset.id, [])
kwargs={ if document_ids:
"flask_app": flask_app, document_ids_filter = document_ids
"dataset_id": dataset.id, else:
"query": query, continue
"top_k": top_k, retrieval_thread = threading.Thread(
"all_documents": all_documents_item, target=self._retriever,
"document_ids_filter": document_ids_filter, kwargs={
"metadata_condition": metadata_condition, "flask_app": flask_app,
"attachment_ids": [attachment_id] if attachment_id else None, "dataset_id": dataset.id,
}, "query": query,
) "top_k": top_k,
threads.append(retrieval_thread) "all_documents": all_documents_item,
retrieval_thread.start() "document_ids_filter": document_ids_filter,
for thread in threads: "metadata_condition": metadata_condition,
thread.join() "attachment_ids": [attachment_id] if attachment_id else None,
},
)
threads.append(retrieval_thread)
retrieval_thread.start()
# Poll threads with short timeout to respond quickly to cancellation
while any(t.is_alive() for t in threads):
for thread in threads:
thread.join(timeout=0.1)
if cancel_event and cancel_event.is_set():
break
if cancel_event and cancel_event.is_set():
break
if reranking_enable: if reranking_enable:
# do rerank for searched documents # do rerank for searched documents
@ -1452,3 +1501,8 @@ class DatasetRetrieval:
all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item
if all_documents_item: if all_documents_item:
all_documents.extend(all_documents_item) all_documents.extend(all_documents_item)
except Exception as e:
if cancel_event:
cancel_event.set()
if thread_exceptions is not None:
thread_exceptions.append(e)

View File

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import codecs
import re import re
from typing import Any from typing import Any
@ -52,7 +53,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
def __init__(self, fixed_separator: str = "\n\n", separators: list[str] | None = None, **kwargs: Any): def __init__(self, fixed_separator: str = "\n\n", separators: list[str] | None = None, **kwargs: Any):
"""Create a new TextSplitter.""" """Create a new TextSplitter."""
super().__init__(**kwargs) super().__init__(**kwargs)
self._fixed_separator = fixed_separator self._fixed_separator = codecs.decode(fixed_separator, "unicode_escape")
self._separators = separators or ["\n\n", "\n", "", ". ", " ", ""] self._separators = separators or ["\n\n", "\n", "", ". ", " ", ""]
def split_text(self, text: str) -> list[str]: def split_text(self, text: str) -> list[str]:
@ -94,7 +95,8 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
splits = re.split(r" +", text) splits = re.split(r" +", text)
else: else:
splits = text.split(separator) splits = text.split(separator)
splits = [item + separator if i < len(splits) else item for i, item in enumerate(splits)] if self._keep_separator:
splits = [s + separator for s in splits[:-1]] + splits[-1:]
else: else:
splits = list(text) splits = list(text)
if separator == "\n": if separator == "\n":
@ -103,7 +105,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
splits = [s for s in splits if (s not in {"", "\n"})] splits = [s for s in splits if (s not in {"", "\n"})]
_good_splits = [] _good_splits = []
_good_splits_lengths = [] # cache the lengths of the splits _good_splits_lengths = [] # cache the lengths of the splits
_separator = separator if self._keep_separator else "" _separator = "" if self._keep_separator else separator
s_lens = self._length_function(splits) s_lens = self._length_function(splits)
if separator != "": if separator != "":
for s, s_len in zip(splits, s_lens): for s, s_len in zip(splits, s_lens):

View File

@ -153,11 +153,11 @@ class ToolInvokeMessage(BaseModel):
@classmethod @classmethod
def transform_variable_value(cls, values): def transform_variable_value(cls, values):
""" """
Only basic types and lists are allowed. Only basic types, lists, and None are allowed.
""" """
value = values.get("variable_value") value = values.get("variable_value")
if not isinstance(value, dict | list | str | int | float | bool): if value is not None and not isinstance(value, dict | list | str | int | float | bool):
raise ValueError("Only basic types and lists are allowed.") raise ValueError("Only basic types, lists, and None are allowed.")
# if stream is true, the value must be a string # if stream is true, the value must be a string
if values.get("stream"): if values.get("stream"):

View File

@ -6,7 +6,15 @@ from typing import Any
from core.mcp.auth_client import MCPClientWithAuthRetry from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPConnectionError from core.mcp.error import MCPConnectionError
from core.mcp.types import AudioContent, CallToolResult, ImageContent, TextContent from core.mcp.types import (
AudioContent,
BlobResourceContents,
CallToolResult,
EmbeddedResource,
ImageContent,
TextContent,
TextResourceContents,
)
from core.tools.__base.tool import Tool from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
@ -53,10 +61,19 @@ class MCPTool(Tool):
for content in result.content: for content in result.content:
if isinstance(content, TextContent): if isinstance(content, TextContent):
yield from self._process_text_content(content) yield from self._process_text_content(content)
elif isinstance(content, ImageContent): elif isinstance(content, ImageContent | AudioContent):
yield self._process_image_content(content) yield self.create_blob_message(
elif isinstance(content, AudioContent): blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType}
yield self._process_audio_content(content) )
elif isinstance(content, EmbeddedResource):
resource = content.resource
if isinstance(resource, TextResourceContents):
yield self.create_text_message(resource.text)
elif isinstance(resource, BlobResourceContents):
mime_type = resource.mimeType or "application/octet-stream"
yield self.create_blob_message(blob=base64.b64decode(resource.blob), meta={"mime_type": mime_type})
else:
raise ToolInvokeError(f"Unsupported embedded resource type: {type(resource)}")
else: else:
logger.warning("Unsupported content type=%s", type(content)) logger.warning("Unsupported content type=%s", type(content))
@ -101,14 +118,6 @@ class MCPTool(Tool):
for item in json_list: for item in json_list:
yield self.create_json_message(item) yield self.create_json_message(item)
def _process_image_content(self, content: ImageContent) -> ToolInvokeMessage:
"""Process image content and return a blob message."""
return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType})
def _process_audio_content(self, content: AudioContent) -> ToolInvokeMessage:
"""Process audio content and return a blob message."""
return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType})
def fork_tool_runtime(self, runtime: ToolRuntime) -> "MCPTool": def fork_tool_runtime(self, runtime: ToolRuntime) -> "MCPTool":
return MCPTool( return MCPTool(
entity=self.entity, entity=self.entity,

View File

@ -5,6 +5,7 @@ from sqlalchemy.orm import Session
from core.app.app_config.entities import VariableEntity, VariableEntityType from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.db.session_factory import session_factory
from core.plugin.entities.parameters import PluginParameterOption from core.plugin.entities.parameters import PluginParameterOption
from core.tools.__base.tool_provider import ToolProviderController from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime from core.tools.__base.tool_runtime import ToolRuntime
@ -47,33 +48,30 @@ class WorkflowToolProviderController(ToolProviderController):
@classmethod @classmethod
def from_db(cls, db_provider: WorkflowToolProvider) -> "WorkflowToolProviderController": def from_db(cls, db_provider: WorkflowToolProvider) -> "WorkflowToolProviderController":
with Session(db.engine, expire_on_commit=False) as session, session.begin(): with session_factory.create_session() as session, session.begin():
provider = session.get(WorkflowToolProvider, db_provider.id) if db_provider.id else None app = session.get(App, db_provider.app_id)
if not provider:
raise ValueError("workflow provider not found")
app = session.get(App, provider.app_id)
if not app: if not app:
raise ValueError("app not found") raise ValueError("app not found")
user = session.get(Account, provider.user_id) if provider.user_id else None user = session.get(Account, db_provider.user_id) if db_provider.user_id else None
controller = WorkflowToolProviderController( controller = WorkflowToolProviderController(
entity=ToolProviderEntity( entity=ToolProviderEntity(
identity=ToolProviderIdentity( identity=ToolProviderIdentity(
author=user.name if user else "", author=user.name if user else "",
name=provider.label, name=db_provider.label,
label=I18nObject(en_US=provider.label, zh_Hans=provider.label), label=I18nObject(en_US=db_provider.label, zh_Hans=db_provider.label),
description=I18nObject(en_US=provider.description, zh_Hans=provider.description), description=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description),
icon=provider.icon, icon=db_provider.icon,
), ),
credentials_schema=[], credentials_schema=[],
plugin_id=None, plugin_id=None,
), ),
provider_id=provider.id or "", provider_id="",
) )
controller.tools = [ controller.tools = [
controller._get_db_provider_tool(provider, app, session=session, user=user), controller._get_db_provider_tool(db_provider, app, session=session, user=user),
] ]
return controller return controller

View File

@ -67,12 +67,16 @@ def create_trigger_provider_encrypter_for_subscription(
def delete_cache_for_subscription(tenant_id: str, provider_id: str, subscription_id: str): def delete_cache_for_subscription(tenant_id: str, provider_id: str, subscription_id: str):
cache = TriggerProviderCredentialsCache( TriggerProviderCredentialsCache(
tenant_id=tenant_id, tenant_id=tenant_id,
provider_id=provider_id, provider_id=provider_id,
credential_id=subscription_id, credential_id=subscription_id,
) ).delete()
cache.delete() TriggerProviderPropertiesCache(
tenant_id=tenant_id,
provider_id=provider_id,
subscription_id=subscription_id,
).delete()
def create_trigger_provider_encrypter_for_properties( def create_trigger_provider_encrypter_for_properties(

Some files were not shown because too many files have changed in this diff Show More