Merge branch 'main' into feat/step-one-refactor

Coding On Star 2025-12-29 14:48:06 +08:00 committed by GitHub
commit 7409359afe
414 changed files with 14781 additions and 3878 deletions

.claude/settings.json
View File

@ -0,0 +1,8 @@
{
"enabledPlugins": {
"feature-dev@claude-plugins-official": true,
"context7@claude-plugins-official": true,
"typescript-lsp@claude-plugins-official": true,
"pyright-lsp@claude-plugins-official": true
}
}

View File

@ -1,19 +0,0 @@
{
"permissions": {
"allow": [],
"deny": []
},
"env": {
"__comment": "Environment variables for MCP servers. Override in .claude/settings.local.json with actual values.",
"GITHUB_PERSONAL_ACCESS_TOKEN": "ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
},
"enabledMcpjsonServers": [
"context7",
"sequential-thinking",
"github",
"fetch",
"playwright",
"ide"
],
"enableAllProjectMcpServers": true
}

View File

@ -0,0 +1,483 @@
---
name: component-refactoring
description: Refactor high-complexity React components in Dify frontend. Use when `pnpm analyze-component --json` shows complexity > 50 or lineCount > 300, when the user asks for code splitting, hook extraction, or complexity reduction, or when `pnpm analyze-component` warns to refactor before testing; avoid for simple/well-structured components, third-party wrappers, or when the user explicitly wants testing without refactoring.
---
# Dify Component Refactoring Skill
Refactor high-complexity React components in the Dify frontend codebase with the patterns and workflow below.
> **Complexity Threshold**: Components with complexity > 50 (measured by `pnpm analyze-component`) should be refactored before testing.
## Quick Reference
### Commands (run from `web/`)
Use paths relative to `web/` (e.g., `app/components/...`).
Use `refactor-component` for refactoring prompts and `analyze-component` for testing prompts and metrics.
```bash
cd web
# Generate refactoring prompt
pnpm refactor-component <path>
# Output refactoring analysis as JSON
pnpm refactor-component <path> --json
# Generate testing prompt (after refactoring)
pnpm analyze-component <path>
# Output testing analysis as JSON
pnpm analyze-component <path> --json
```
### Complexity Analysis
```bash
# Analyze component complexity
pnpm analyze-component <path> --json
# Key metrics to check:
# - complexity: normalized score 0-100 (target < 50)
# - maxComplexity: highest single function complexity
# - lineCount: total lines (target < 300)
```
### Complexity Score Interpretation
| Score | Level | Action |
|-------|-------|--------|
| 0-25 | 🟢 Simple | Ready for testing |
| 26-50 | 🟡 Medium | Consider minor refactoring |
| 51-75 | 🟠 Complex | **Refactor before testing** |
| 76-100 | 🔴 Very Complex | **Must refactor** |
## Core Refactoring Patterns
### Pattern 1: Extract Custom Hooks
**When**: Component has complex state management, multiple `useState`/`useEffect`, or business logic mixed with UI.
**Dify Convention**: Place hooks in a `hooks/` subdirectory or alongside the component as `use-<feature>.ts`.
```typescript
// ❌ Before: Complex state logic in component
const Configuration: FC = () => {
const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
const [datasetConfigs, setDatasetConfigs] = useState<DatasetConfigs>(...)
const [completionParams, setCompletionParams] = useState<FormValue>({})
// 50+ lines of state management logic...
return <div>...</div>
}
// ✅ After: Extract to custom hook
// hooks/use-model-config.ts
export const useModelConfig = (appId: string) => {
const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
const [completionParams, setCompletionParams] = useState<FormValue>({})
// Related state management logic here
return { modelConfig, setModelConfig, completionParams, setCompletionParams }
}
// Component becomes cleaner
const Configuration: FC = () => {
const { modelConfig, setModelConfig } = useModelConfig(appId)
return <div>...</div>
}
```
**Dify Examples**:
- `web/app/components/app/configuration/hooks/use-advanced-prompt-config.ts`
- `web/app/components/app/configuration/debug/hooks.tsx`
- `web/app/components/workflow/hooks/use-workflow.ts`
### Pattern 2: Extract Sub-Components
**When**: Single component has multiple UI sections, conditional rendering blocks, or repeated patterns.
**Dify Convention**: Place sub-components in subdirectories or as separate files in the same directory.
```typescript
// ❌ Before: Monolithic JSX with multiple sections
const AppInfo = () => {
return (
<div>
{/* 100 lines of header UI */}
{/* 100 lines of operations UI */}
{/* 100 lines of modals */}
</div>
)
}
// ✅ After: Split into focused components
// app-info/
// ├── index.tsx (orchestration only)
// ├── app-header.tsx (header UI)
// ├── app-operations.tsx (operations UI)
// └── app-modals.tsx (modal management)
const AppInfo = () => {
const { showModal, setShowModal } = useAppInfoModals()
return (
<div>
<AppHeader appDetail={appDetail} />
<AppOperations onAction={handleAction} />
<AppModals show={showModal} onClose={() => setShowModal(null)} />
</div>
)
}
```
**Dify Examples**:
- `web/app/components/app/configuration/` directory structure
- `web/app/components/workflow/nodes/` per-node organization
### Pattern 3: Simplify Conditional Logic
**When**: Deep nesting (> 3 levels), complex ternaries, or multiple `if/else` chains.
```typescript
// ❌ Before: Deeply nested conditionals
const Template = useMemo(() => {
if (appDetail?.mode === AppModeEnum.CHAT) {
switch (locale) {
case LanguagesSupported[1]:
return <TemplateChatZh />
case LanguagesSupported[7]:
return <TemplateChatJa />
default:
return <TemplateChatEn />
}
}
if (appDetail?.mode === AppModeEnum.ADVANCED_CHAT) {
// Another 15 lines...
}
// More conditions...
}, [appDetail, locale])
// ✅ After: Use lookup tables + early returns
const TEMPLATE_MAP = {
[AppModeEnum.CHAT]: {
[LanguagesSupported[1]]: TemplateChatZh,
[LanguagesSupported[7]]: TemplateChatJa,
default: TemplateChatEn,
},
[AppModeEnum.ADVANCED_CHAT]: {
[LanguagesSupported[1]]: TemplateAdvancedChatZh,
// ...
},
}
const Template = useMemo(() => {
const modeTemplates = TEMPLATE_MAP[appDetail?.mode]
if (!modeTemplates) return null
const TemplateComponent = modeTemplates[locale] || modeTemplates.default
return <TemplateComponent appDetail={appDetail} />
}, [appDetail, locale])
```
### Pattern 4: Extract API/Data Logic
**When**: Component directly handles API calls, data transformation, or complex async operations.
**Dify Convention**: Use `@tanstack/react-query` hooks from `web/service/use-*.ts` or create custom data hooks.
```typescript
// ❌ Before: API logic in component
const MCPServiceCard = () => {
const [basicAppConfig, setBasicAppConfig] = useState({})
useEffect(() => {
if (isBasicApp && appId) {
(async () => {
const res = await fetchAppDetail({ url: '/apps', id: appId })
setBasicAppConfig(res?.model_config || {})
})()
}
}, [appId, isBasicApp])
// More API-related logic...
}
// ✅ After: Extract to data hook using React Query
// use-app-config.ts
import { useQuery } from '@tanstack/react-query'
import { get } from '@/service/base'
const NAME_SPACE = 'appConfig'
export const useAppConfig = (appId: string, isBasicApp: boolean) => {
return useQuery({
enabled: isBasicApp && !!appId,
queryKey: [NAME_SPACE, 'detail', appId],
queryFn: () => get<AppDetailResponse>(`/apps/${appId}`),
select: data => data?.model_config || {},
})
}
// Component becomes cleaner
const MCPServiceCard = () => {
const { data: config, isLoading } = useAppConfig(appId, isBasicApp)
// UI only
}
```
**React Query Best Practices in Dify**:
- Define `NAME_SPACE` for query key organization
- Use `enabled` option for conditional fetching
- Use `select` for data transformation
- Export invalidation hooks: `useInvalidXxx`
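The invalidation hook from the last bullet is not shown in the example above. A minimal sketch, assuming the `useInvalid` helper from `web/service/use-base` that Dify's service hooks use elsewhere:
```typescript
// use-app-config.ts (continued) — sketch of the matching invalidation hook
import { useInvalid } from '@/service/use-base'

// Invalidates everything under the 'appConfig' namespace so consumers refetch
export const useInvalidAppConfig = () => {
  return useInvalid([NAME_SPACE])
}
```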
**Dify Examples**:
- `web/service/use-workflow.ts`
- `web/service/use-common.ts`
- `web/service/knowledge/use-dataset.ts`
- `web/service/knowledge/use-document.ts`
### Pattern 5: Extract Modal/Dialog Management
**When**: Component manages multiple modals with complex open/close states.
**Dify Convention**: Modals should be extracted with their state management.
```typescript
// ❌ Before: Multiple modal states in component
const AppInfo = () => {
const [showEditModal, setShowEditModal] = useState(false)
const [showDuplicateModal, setShowDuplicateModal] = useState(false)
const [showConfirmDelete, setShowConfirmDelete] = useState(false)
const [showSwitchModal, setShowSwitchModal] = useState(false)
const [showImportDSLModal, setShowImportDSLModal] = useState(false)
// 5+ more modal states...
}
// ✅ After: Extract to modal management hook
type ModalType = 'edit' | 'duplicate' | 'delete' | 'switch' | 'import' | null
const useAppInfoModals = () => {
const [activeModal, setActiveModal] = useState<ModalType>(null)
const openModal = useCallback((type: ModalType) => setActiveModal(type), [])
const closeModal = useCallback(() => setActiveModal(null), [])
return {
activeModal,
openModal,
closeModal,
isOpen: (type: ModalType) => activeModal === type,
}
}
```
### Pattern 6: Extract Form Logic
**When**: Complex form validation, submission handling, or field transformation.
**Dify Convention**: Use `@tanstack/react-form` patterns from `web/app/components/base/form/`.
```typescript
// ✅ Use existing form infrastructure
import { useAppForm } from '@/app/components/base/form'
const ConfigForm = () => {
const form = useAppForm({
defaultValues: { name: '', description: '' },
onSubmit: handleSubmit,
})
return <form.Provider>...</form.Provider>
}
```
## Dify-Specific Refactoring Guidelines
### 1. Context Provider Extraction
**When**: Component provides complex context values with multiple states.
```typescript
// ❌ Before: Large context value object
const value = {
appId, isAPIKeySet, isTrailFinished, mode, modelModeType,
promptMode, isAdvancedMode, isAgent, isOpenAI, isFunctionCall,
// 50+ more properties...
}
return <ConfigContext.Provider value={value}>...</ConfigContext.Provider>
// ✅ After: Split into domain-specific contexts
<ModelConfigProvider value={modelConfigValue}>
<DatasetConfigProvider value={datasetConfigValue}>
<UIConfigProvider value={uiConfigValue}>
{children}
</UIConfigProvider>
</DatasetConfigProvider>
</ModelConfigProvider>
```
**Dify Reference**: `web/context/` directory structure
### 2. Workflow Node Components
**When**: Refactoring workflow node components (`web/app/components/workflow/nodes/`).
**Conventions**:
- Keep node logic in `use-interactions.ts`
- Extract panel UI to separate files
- Use `_base` components for common patterns
```
nodes/<node-type>/
├── index.tsx # Node registration
├── node.tsx # Node visual component
├── panel.tsx # Configuration panel
├── use-interactions.ts # Node-specific hooks
└── types.ts # Type definitions
```
### 3. Configuration Components
**When**: Refactoring app configuration components.
**Conventions**:
- Separate config sections into subdirectories
- Use existing patterns from `web/app/components/app/configuration/`
- Keep feature toggles in dedicated components
### 4. Tool/Plugin Components
**When**: Refactoring tool-related components (`web/app/components/tools/`).
**Conventions**:
- Follow existing modal patterns
- Use service hooks from `web/service/use-tools.ts`
- Keep provider-specific logic isolated
## Refactoring Workflow
### Step 1: Generate Refactoring Prompt
```bash
pnpm refactor-component <path>
```
This command will:
- Analyze component complexity and features
- Identify specific refactoring actions needed
- Generate a prompt for AI assistant (auto-copied to clipboard on macOS)
- Provide detailed requirements based on detected patterns
### Step 2: Analyze Details
```bash
pnpm analyze-component <path> --json
```
Identify:
- Total complexity score
- Max function complexity
- Line count
- Features detected (state, effects, API, etc.)
### Step 3: Plan
Create a refactoring plan based on detected features:
| Detected Feature | Refactoring Action |
|------------------|-------------------|
| `hasState: true` + `hasEffects: true` | Extract custom hook |
| `hasAPI: true` | Extract data/service hook |
| `hasEvents: true` (many) | Extract event handlers |
| `lineCount > 300` | Split into sub-components |
| `maxComplexity > 50` | Simplify conditional logic |
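As a rough sketch of how this mapping could be applied programmatically — the field names below are assumptions taken from the tables in this skill, not a guaranteed schema of `pnpm analyze-component --json`, so check the real output first:
```typescript
// Hypothetical shape of the analysis result; adjust to the tool's actual JSON.
type ComponentAnalysis = {
  complexity: number
  maxComplexity: number
  lineCount: number
  hasState: boolean
  hasEffects: boolean
  hasAPI: boolean
  hasEvents: boolean
}

// Turn detected features into the refactoring actions from the table above
const planRefactoring = (a: ComponentAnalysis): string[] => {
  const actions: string[] = []
  if (a.hasState && a.hasEffects) actions.push('Extract custom hook')
  if (a.hasAPI) actions.push('Extract data/service hook')
  if (a.hasEvents) actions.push('Extract event handlers')
  if (a.lineCount > 300) actions.push('Split into sub-components')
  if (a.maxComplexity > 50) actions.push('Simplify conditional logic')
  return actions
}
```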
### Step 4: Execute Incrementally
1. **Extract one piece at a time**
2. **Run lint, type-check, and tests after each extraction**
3. **Verify functionality before next step**
```
For each extraction:
┌────────────────────────────────────────┐
│ 1. Extract code │
│ 2. Run: pnpm lint:fix │
│ 3. Run: pnpm type-check:tsgo │
│ 4. Run: pnpm test │
│ 5. Test functionality manually │
│ 6. PASS? → Next extraction │
│ FAIL? → Fix before continuing │
└────────────────────────────────────────┘
```
### Step 5: Verify
After refactoring:
```bash
# Re-run refactor command to verify improvements
pnpm refactor-component <path>
# If complexity < 25 and lines < 200, you'll see:
# ✅ COMPONENT IS WELL-STRUCTURED
# For detailed metrics:
pnpm analyze-component <path> --json
# Target metrics:
# - complexity < 50
# - lineCount < 300
# - maxComplexity < 30
```
## Common Mistakes to Avoid
### ❌ Over-Engineering
```typescript
// ❌ Too many tiny hooks
const useButtonText = () => useState('Click')
const useButtonDisabled = () => useState(false)
const useButtonLoading = () => useState(false)
// ✅ Cohesive hook with related state
const useButtonState = () => {
const [text, setText] = useState('Click')
const [disabled, setDisabled] = useState(false)
const [loading, setLoading] = useState(false)
return { text, setText, disabled, setDisabled, loading, setLoading }
}
```
### ❌ Breaking Existing Patterns
- Follow existing directory structures
- Maintain naming conventions
- Preserve export patterns for compatibility
### ❌ Premature Abstraction
- Only extract when there's clear complexity benefit
- Don't create abstractions for single-use code
- Keep refactored code in the same domain area
## References
### Dify Codebase Examples
- **Hook extraction**: `web/app/components/app/configuration/hooks/`
- **Component splitting**: `web/app/components/app/configuration/`
- **Service hooks**: `web/service/use-*.ts`
- **Workflow patterns**: `web/app/components/workflow/hooks/`
- **Form patterns**: `web/app/components/base/form/`
### Related Skills
- `frontend-testing` - For testing refactored components
- `web/testing/testing.md` - Testing specification

View File

@ -0,0 +1,493 @@
# Complexity Reduction Patterns
This document provides patterns for reducing cognitive complexity in Dify React components.
## Understanding Complexity
### SonarJS Cognitive Complexity
The `pnpm analyze-component` tool uses SonarJS cognitive complexity metrics:
- **Total Complexity**: Sum of all functions' complexity in the file
- **Max Complexity**: Highest single function complexity
### What Increases Complexity
| Pattern | Complexity Impact |
|---------|-------------------|
| `if/else` | +1 per branch |
| Nested conditions | +1 per nesting level |
| `switch` | +1 per `switch` statement |
| `for/while/do` | +1 per loop |
| Logical operator chains (`&&`, `\|\|`) | +1 per operator sequence |
| Nested callbacks | +1 per nesting level |
| `try/catch` | +1 per `catch` |
| Ternary expressions | +1 per nested ternary |
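To see how these increments add up, here is an illustrative count for a small function, using the simplified rules from the table above (an approximation of the SonarJS scoring, not an exact reproduction):
```typescript
// Approximate cognitive-complexity count using the increments listed above.
function describeOrder(order: { items: string[]; rush?: boolean }): string {
  if (order.items.length === 0)                 // +1 (if)
    return 'empty order'
  for (const item of order.items) {             // +1 (loop)
    if (item === 'forbidden')                   // +1 (if) +1 (nesting)
      return 'rejected'
  }
  return order.rush && order.items.length > 3   // +1 (operator sequence)
    ? 'rush: large order'                       // +1 (ternary)
    : 'standard'
}
// Total ≈ 6 by these simplified rules.
```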
## Pattern 1: Replace Conditionals with Lookup Tables
**Before** (complexity: ~15):
```typescript
const Template = useMemo(() => {
if (appDetail?.mode === AppModeEnum.CHAT) {
switch (locale) {
case LanguagesSupported[1]:
return <TemplateChatZh appDetail={appDetail} />
case LanguagesSupported[7]:
return <TemplateChatJa appDetail={appDetail} />
default:
return <TemplateChatEn appDetail={appDetail} />
}
}
if (appDetail?.mode === AppModeEnum.ADVANCED_CHAT) {
switch (locale) {
case LanguagesSupported[1]:
return <TemplateAdvancedChatZh appDetail={appDetail} />
case LanguagesSupported[7]:
return <TemplateAdvancedChatJa appDetail={appDetail} />
default:
return <TemplateAdvancedChatEn appDetail={appDetail} />
}
}
if (appDetail?.mode === AppModeEnum.WORKFLOW) {
// Similar pattern...
}
return null
}, [appDetail, locale])
```
**After** (complexity: ~3):
```typescript
// Define lookup table outside component
const TEMPLATE_MAP: Record<AppModeEnum, Record<string, FC<TemplateProps>>> = {
[AppModeEnum.CHAT]: {
[LanguagesSupported[1]]: TemplateChatZh,
[LanguagesSupported[7]]: TemplateChatJa,
default: TemplateChatEn,
},
[AppModeEnum.ADVANCED_CHAT]: {
[LanguagesSupported[1]]: TemplateAdvancedChatZh,
[LanguagesSupported[7]]: TemplateAdvancedChatJa,
default: TemplateAdvancedChatEn,
},
[AppModeEnum.WORKFLOW]: {
[LanguagesSupported[1]]: TemplateWorkflowZh,
[LanguagesSupported[7]]: TemplateWorkflowJa,
default: TemplateWorkflowEn,
},
// ...
}
// Clean component logic
const Template = useMemo(() => {
if (!appDetail?.mode) return null
const templates = TEMPLATE_MAP[appDetail.mode]
if (!templates) return null
const TemplateComponent = templates[locale] ?? templates.default
return <TemplateComponent appDetail={appDetail} />
}, [appDetail, locale])
```
## Pattern 2: Use Early Returns
**Before** (complexity: ~10):
```typescript
const handleSubmit = () => {
if (isValid) {
if (hasChanges) {
if (isConnected) {
submitData()
} else {
showConnectionError()
}
} else {
showNoChangesMessage()
}
} else {
showValidationError()
}
}
```
**After** (complexity: ~4):
```typescript
const handleSubmit = () => {
if (!isValid) {
showValidationError()
return
}
if (!hasChanges) {
showNoChangesMessage()
return
}
if (!isConnected) {
showConnectionError()
return
}
submitData()
}
```
## Pattern 3: Extract Complex Conditions
**Before** (complexity: high):
```typescript
const canPublish = (() => {
if (mode !== AppModeEnum.COMPLETION) {
if (!isAdvancedMode)
return true
if (modelModeType === ModelModeType.completion) {
if (!hasSetBlockStatus.history || !hasSetBlockStatus.query)
return false
return true
}
return true
}
return !promptEmpty
})()
```
**After** (complexity: lower):
```typescript
// Extract to named functions
const canPublishInCompletionMode = () => !promptEmpty
const canPublishInChatMode = () => {
if (!isAdvancedMode) return true
if (modelModeType !== ModelModeType.completion) return true
return hasSetBlockStatus.history && hasSetBlockStatus.query
}
// Clean main logic
const canPublish = mode === AppModeEnum.COMPLETION
? canPublishInCompletionMode()
: canPublishInChatMode()
```
## Pattern 4: Replace Chained Ternaries
**Before** (complexity: ~5):
```typescript
const statusText = serverActivated
? t('status.running')
: serverPublished
? t('status.inactive')
: appUnpublished
? t('status.unpublished')
: t('status.notConfigured')
```
**After** (complexity: ~2):
```typescript
const getStatusText = () => {
if (serverActivated) return t('status.running')
if (serverPublished) return t('status.inactive')
if (appUnpublished) return t('status.unpublished')
return t('status.notConfigured')
}
const statusText = getStatusText()
```
Or use lookup:
```typescript
const STATUS_TEXT_MAP = {
running: 'status.running',
inactive: 'status.inactive',
unpublished: 'status.unpublished',
notConfigured: 'status.notConfigured',
} as const
const getStatusKey = (): keyof typeof STATUS_TEXT_MAP => {
if (serverActivated) return 'running'
if (serverPublished) return 'inactive'
if (appUnpublished) return 'unpublished'
return 'notConfigured'
}
const statusText = t(STATUS_TEXT_MAP[getStatusKey()])
```
## Pattern 5: Flatten Nested Loops
**Before** (complexity: high):
```typescript
const processData = (items: Item[]) => {
const results: ProcessedItem[] = []
for (const item of items) {
if (item.isValid) {
for (const child of item.children) {
if (child.isActive) {
for (const prop of child.properties) {
if (prop.value !== null) {
results.push({
itemId: item.id,
childId: child.id,
propValue: prop.value,
})
}
}
}
}
}
}
return results
}
```
**After** (complexity: lower):
```typescript
// Use functional approach
const processData = (items: Item[]) => {
return items
.filter(item => item.isValid)
.flatMap(item =>
item.children
.filter(child => child.isActive)
.flatMap(child =>
child.properties
.filter(prop => prop.value !== null)
.map(prop => ({
itemId: item.id,
childId: child.id,
propValue: prop.value,
}))
)
)
}
```
## Pattern 6: Extract Event Handler Logic
**Before** (complexity: high in component):
```typescript
const Component = () => {
const handleSelect = (data: DataSet[]) => {
if (isEqual(data.map(item => item.id), dataSets.map(item => item.id))) {
hideSelectDataSet()
return
}
formattingChangedDispatcher()
let newDatasets = data
if (data.find(item => !item.name)) {
const newSelected = produce(data, (draft) => {
data.forEach((item, index) => {
if (!item.name) {
const newItem = dataSets.find(i => i.id === item.id)
if (newItem)
draft[index] = newItem
}
})
})
setDataSets(newSelected)
newDatasets = newSelected
}
else {
setDataSets(data)
}
hideSelectDataSet()
// 40 more lines of logic...
}
return <div>...</div>
}
```
**After** (complexity: lower):
```typescript
// Extract to hook or utility
const useDatasetSelection = (dataSets: DataSet[], setDataSets: SetState<DataSet[]>) => {
const normalizeSelection = (data: DataSet[]) => {
const hasUnloadedItem = data.some(item => !item.name)
if (!hasUnloadedItem) return data
return produce(data, (draft) => {
data.forEach((item, index) => {
if (!item.name) {
const existing = dataSets.find(i => i.id === item.id)
if (existing) draft[index] = existing
}
})
})
}
const hasSelectionChanged = (newData: DataSet[]) => {
return !isEqual(
newData.map(item => item.id),
dataSets.map(item => item.id)
)
}
return { normalizeSelection, hasSelectionChanged }
}
// Component becomes cleaner
const Component = () => {
const { normalizeSelection, hasSelectionChanged } = useDatasetSelection(dataSets, setDataSets)
const handleSelect = (data: DataSet[]) => {
if (!hasSelectionChanged(data)) {
hideSelectDataSet()
return
}
formattingChangedDispatcher()
const normalized = normalizeSelection(data)
setDataSets(normalized)
hideSelectDataSet()
}
return <div>...</div>
}
```
## Pattern 7: Reduce Boolean Logic Complexity
**Before** (complexity: ~8):
```typescript
const toggleDisabled = hasInsufficientPermissions
|| appUnpublished
|| missingStartNode
|| triggerModeDisabled
|| (isAdvancedApp && !currentWorkflow?.graph)
|| (isBasicApp && !basicAppConfig.updated_at)
```
**After** (complexity: ~3):
```typescript
// Extract meaningful boolean functions
const isAppReady = () => {
if (isAdvancedApp) return !!currentWorkflow?.graph
return !!basicAppConfig.updated_at
}
const hasRequiredPermissions = () => {
return isCurrentWorkspaceEditor && !hasInsufficientPermissions
}
const canToggle = () => {
if (!hasRequiredPermissions()) return false
if (!isAppReady()) return false
if (missingStartNode) return false
if (triggerModeDisabled) return false
return true
}
const toggleDisabled = !canToggle()
```
## Pattern 8: Simplify useMemo/useCallback Dependencies
**Before** (complexity: multiple recalculations):
```typescript
const payload = useMemo(() => {
let parameters: Parameter[] = []
let outputParameters: OutputParameter[] = []
if (!published) {
parameters = (inputs || []).map((item) => ({
name: item.variable,
description: '',
form: 'llm',
required: item.required,
type: item.type,
}))
outputParameters = (outputs || []).map((item) => ({
name: item.variable,
description: '',
type: item.value_type,
}))
}
else if (detail && detail.tool) {
parameters = (inputs || []).map((item) => ({
// Complex transformation...
}))
outputParameters = (outputs || []).map((item) => ({
// Complex transformation...
}))
}
return {
icon: detail?.icon || icon,
label: detail?.label || name,
// ...more fields
}
}, [detail, published, workflowAppId, icon, name, description, inputs, outputs])
```
**After** (complexity: separated concerns):
```typescript
// Separate transformations
const useParameterTransform = (inputs: InputVar[], detail?: ToolDetail, published?: boolean) => {
return useMemo(() => {
if (!published) {
return inputs.map(item => ({
name: item.variable,
description: '',
form: 'llm',
required: item.required,
type: item.type,
}))
}
if (!detail?.tool) return []
return inputs.map(item => ({
name: item.variable,
required: item.required,
type: item.type === 'paragraph' ? 'string' : item.type,
description: detail.tool.parameters.find(p => p.name === item.variable)?.llm_description || '',
form: detail.tool.parameters.find(p => p.name === item.variable)?.form || 'llm',
}))
}, [inputs, detail, published])
}
// Component uses hook
const parameters = useParameterTransform(inputs, detail, published)
const outputParameters = useOutputTransform(outputs, detail, published)
const payload = useMemo(() => ({
icon: detail?.icon || icon,
label: detail?.label || name,
parameters,
outputParameters,
// ...
}), [detail, icon, name, parameters, outputParameters])
```
## Target Metrics After Refactoring
| Metric | Target |
|--------|--------|
| Total Complexity | < 50 |
| Max Function Complexity | < 30 |
| Function Length | < 30 lines |
| Nesting Depth | ≤ 3 levels |
| Conditional Chains | ≤ 3 conditions |

View File

@ -0,0 +1,477 @@
# Component Splitting Patterns
This document provides detailed guidance on splitting large components into smaller, focused components in Dify.
## When to Split Components
Split a component when you identify:
1. **Multiple UI sections** - Distinct visual areas with minimal coupling that can be composed independently
1. **Conditional rendering blocks** - Large `{condition && <JSX />}` blocks
1. **Repeated patterns** - Similar UI structures used multiple times
1. **300+ lines** - Component exceeds manageable size
1. **Modal clusters** - Multiple modals rendered in one component
## Splitting Strategies
### Strategy 1: Section-Based Splitting
Identify visual sections and extract each as a component.
```typescript
// ❌ Before: Monolithic component (500+ lines)
const ConfigurationPage = () => {
return (
<div>
{/* Header Section - 50 lines */}
<div className="header">
<h1>{t('configuration.title')}</h1>
<div className="actions">
{isAdvancedMode && <Badge>Advanced</Badge>}
<ModelParameterModal ... />
<AppPublisher ... />
</div>
</div>
{/* Config Section - 200 lines */}
<div className="config">
<Config />
</div>
{/* Debug Section - 150 lines */}
<div className="debug">
<Debug ... />
</div>
{/* Modals Section - 100 lines */}
{showSelectDataSet && <SelectDataSet ... />}
{showHistoryModal && <EditHistoryModal ... />}
{showUseGPT4Confirm && <Confirm ... />}
</div>
)
}
// ✅ After: Split into focused components
// configuration/
// ├── index.tsx (orchestration)
// ├── configuration-header.tsx
// ├── configuration-content.tsx
// ├── configuration-debug.tsx
// └── configuration-modals.tsx
// configuration-header.tsx
interface ConfigurationHeaderProps {
isAdvancedMode: boolean
onPublish: () => void
}
const ConfigurationHeader: FC<ConfigurationHeaderProps> = ({
isAdvancedMode,
onPublish,
}) => {
const { t } = useTranslation()
return (
<div className="header">
<h1>{t('configuration.title')}</h1>
<div className="actions">
{isAdvancedMode && <Badge>Advanced</Badge>}
<ModelParameterModal ... />
<AppPublisher onPublish={onPublish} />
</div>
</div>
)
}
// index.tsx (orchestration only)
const ConfigurationPage = () => {
const { modelConfig, setModelConfig } = useModelConfig()
const { activeModal, openModal, closeModal } = useModalState()
return (
<div>
<ConfigurationHeader
isAdvancedMode={isAdvancedMode}
onPublish={handlePublish}
/>
<ConfigurationContent
modelConfig={modelConfig}
onConfigChange={setModelConfig}
/>
{!isMobile && (
<ConfigurationDebug
inputs={inputs}
onSetting={handleSetting}
/>
)}
<ConfigurationModals
activeModal={activeModal}
onClose={closeModal}
/>
</div>
)
}
```
### Strategy 2: Conditional Block Extraction
Extract large conditional rendering blocks.
```typescript
// ❌ Before: Large conditional blocks
const AppInfo = () => {
return (
<div>
{expand ? (
<div className="expanded">
{/* 100 lines of expanded view */}
</div>
) : (
<div className="collapsed">
{/* 50 lines of collapsed view */}
</div>
)}
</div>
)
}
// ✅ After: Separate view components
const AppInfoExpanded: FC<AppInfoViewProps> = ({ appDetail, onAction }) => {
return (
<div className="expanded">
{/* Clean, focused expanded view */}
</div>
)
}
const AppInfoCollapsed: FC<AppInfoViewProps> = ({ appDetail, onAction }) => {
return (
<div className="collapsed">
{/* Clean, focused collapsed view */}
</div>
)
}
const AppInfo = () => {
return (
<div>
{expand
? <AppInfoExpanded appDetail={appDetail} onAction={handleAction} />
: <AppInfoCollapsed appDetail={appDetail} onAction={handleAction} />
}
</div>
)
}
```
### Strategy 3: Modal Extraction
Extract modals with their trigger logic.
```typescript
// ❌ Before: Multiple modals in one component
const AppInfo = () => {
const [showEdit, setShowEdit] = useState(false)
const [showDuplicate, setShowDuplicate] = useState(false)
const [showDelete, setShowDelete] = useState(false)
const [showSwitch, setShowSwitch] = useState(false)
const onEdit = async (data) => { /* 20 lines */ }
const onDuplicate = async (data) => { /* 20 lines */ }
const onDelete = async () => { /* 15 lines */ }
return (
<div>
{/* Main content */}
{showEdit && <EditModal onConfirm={onEdit} onClose={() => setShowEdit(false)} />}
{showDuplicate && <DuplicateModal onConfirm={onDuplicate} onClose={() => setShowDuplicate(false)} />}
{showDelete && <DeleteConfirm onConfirm={onDelete} onClose={() => setShowDelete(false)} />}
{showSwitch && <SwitchModal ... />}
</div>
)
}
// ✅ After: Modal manager component
// app-info-modals.tsx
type ModalType = 'edit' | 'duplicate' | 'delete' | 'switch' | null
interface AppInfoModalsProps {
appDetail: AppDetail
activeModal: ModalType
onClose: () => void
onSuccess: () => void
}
const AppInfoModals: FC<AppInfoModalsProps> = ({
appDetail,
activeModal,
onClose,
onSuccess,
}) => {
const handleEdit = async (data) => { /* logic */ }
const handleDuplicate = async (data) => { /* logic */ }
const handleDelete = async () => { /* logic */ }
return (
<>
{activeModal === 'edit' && (
<EditModal
appDetail={appDetail}
onConfirm={handleEdit}
onClose={onClose}
/>
)}
{activeModal === 'duplicate' && (
<DuplicateModal
appDetail={appDetail}
onConfirm={handleDuplicate}
onClose={onClose}
/>
)}
{activeModal === 'delete' && (
<DeleteConfirm
onConfirm={handleDelete}
onClose={onClose}
/>
)}
{activeModal === 'switch' && (
<SwitchModal
appDetail={appDetail}
onClose={onClose}
/>
)}
</>
)
}
// Parent component
const AppInfo = () => {
const { activeModal, openModal, closeModal } = useModalState()
return (
<div>
{/* Main content with openModal triggers */}
<Button onClick={() => openModal('edit')}>Edit</Button>
<AppInfoModals
appDetail={appDetail}
activeModal={activeModal}
onClose={closeModal}
onSuccess={handleSuccess}
/>
</div>
)
}
```
### Strategy 4: List Item Extraction
Extract repeated item rendering.
```typescript
// ❌ Before: Inline item rendering
const OperationsList = () => {
return (
<div>
{operations.map(op => (
<div key={op.id} className="operation-item">
<span className="icon">{op.icon}</span>
<span className="title">{op.title}</span>
<span className="description">{op.description}</span>
<button onClick={() => op.onClick()}>
{op.actionLabel}
</button>
{op.badge && <Badge>{op.badge}</Badge>}
{/* More complex rendering... */}
</div>
))}
</div>
)
}
// ✅ After: Extracted item component
interface OperationItemProps {
operation: Operation
onAction: (id: string) => void
}
const OperationItem: FC<OperationItemProps> = ({ operation, onAction }) => {
return (
<div className="operation-item">
<span className="icon">{operation.icon}</span>
<span className="title">{operation.title}</span>
<span className="description">{operation.description}</span>
<button onClick={() => onAction(operation.id)}>
{operation.actionLabel}
</button>
{operation.badge && <Badge>{operation.badge}</Badge>}
</div>
)
}
const OperationsList = () => {
const handleAction = useCallback((id: string) => {
const op = operations.find(o => o.id === id)
op?.onClick()
}, [operations])
return (
<div>
{operations.map(op => (
<OperationItem
key={op.id}
operation={op}
onAction={handleAction}
/>
))}
</div>
)
}
```
## Directory Structure Patterns
### Pattern A: Flat Structure (Simple Components)
For components with 2-3 sub-components:
```
component-name/
├── index.tsx # Main component
├── sub-component-a.tsx
├── sub-component-b.tsx
└── types.ts # Shared types
```
### Pattern B: Nested Structure (Complex Components)
For components with many sub-components:
```
component-name/
├── index.tsx # Main orchestration
├── types.ts # Shared types
├── hooks/
│ ├── use-feature-a.ts
│ └── use-feature-b.ts
├── components/
│ ├── header/
│ │ └── index.tsx
│ ├── content/
│ │ └── index.tsx
│ └── modals/
│ └── index.tsx
└── utils/
└── helpers.ts
```
### Pattern C: Feature-Based Structure (Dify Standard)
Following Dify's existing patterns:
```
configuration/
├── index.tsx # Main page component
├── base/ # Base/shared components
│ ├── feature-panel/
│ ├── group-name/
│ └── operation-btn/
├── config/ # Config section
│ ├── index.tsx
│ ├── agent/
│ └── automatic/
├── dataset-config/ # Dataset section
│ ├── index.tsx
│ ├── card-item/
│ └── params-config/
├── debug/ # Debug section
│ ├── index.tsx
│ └── hooks.tsx
└── hooks/ # Shared hooks
└── use-advanced-prompt-config.ts
```
## Props Design
### Minimal Props Principle
Pass only what's needed:
```typescript
// ❌ Bad: Passing entire objects when only some fields needed
<ConfigHeader appDetail={appDetail} modelConfig={modelConfig} />
// ✅ Good: Destructure to minimum required
<ConfigHeader
appName={appDetail.name}
isAdvancedMode={modelConfig.isAdvanced}
onPublish={handlePublish}
/>
```
### Callback Props Pattern
Use callbacks for child-to-parent communication:
```typescript
// Parent
const Parent = () => {
const [value, setValue] = useState('')
return (
<Child
value={value}
onChange={setValue}
onSubmit={handleSubmit}
/>
)
}
// Child
interface ChildProps {
value: string
onChange: (value: string) => void
onSubmit: () => void
}
const Child: FC<ChildProps> = ({ value, onChange, onSubmit }) => {
return (
<div>
<input value={value} onChange={e => onChange(e.target.value)} />
<button onClick={onSubmit}>Submit</button>
</div>
)
}
```
### Render Props for Flexibility
When sub-components need parent context:
```typescript
interface ListProps<T> {
items: T[]
renderItem: (item: T, index: number) => React.ReactNode
renderEmpty?: () => React.ReactNode
}
function List<T>({ items, renderItem, renderEmpty }: ListProps<T>) {
if (items.length === 0 && renderEmpty) {
return <>{renderEmpty()}</>
}
return (
<div>
{items.map((item, index) => renderItem(item, index))}
</div>
)
}
// Usage
<List
items={operations}
renderItem={(op, i) => <OperationItem key={i} operation={op} />}
renderEmpty={() => <EmptyState message="No operations" />}
/>
```

View File

@ -0,0 +1,317 @@
# Hook Extraction Patterns
This document provides detailed guidance on extracting custom hooks from complex components in Dify.
## When to Extract Hooks
Extract a custom hook when you identify:
1. **Coupled state groups** - Multiple `useState` hooks that are always used together
1. **Complex effects** - `useEffect` with multiple dependencies or cleanup logic
1. **Business logic** - Data transformations, validations, or calculations
1. **Reusable patterns** - Logic that appears in multiple components
## Extraction Process
### Step 1: Identify State Groups
Look for state variables that are logically related:
```typescript
// ❌ These belong together - extract to hook
const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
const [completionParams, setCompletionParams] = useState<FormValue>({})
const [modelModeType, setModelModeType] = useState<ModelModeType>(...)
// This is model-related state that belongs in useModelConfig()
```
### Step 2: Identify Related Effects
Find effects that modify the grouped state:
```typescript
// ❌ These effects belong with the state above
useEffect(() => {
if (hasFetchedDetail && !modelModeType) {
const mode = currModel?.model_properties.mode
if (mode) {
const newModelConfig = produce(modelConfig, (draft) => {
draft.mode = mode
})
setModelConfig(newModelConfig)
}
}
}, [textGenerationModelList, hasFetchedDetail, modelModeType, currModel])
```
### Step 3: Create the Hook
```typescript
// hooks/use-model-config.ts
import type { FormValue } from '@/app/components/header/account-setting/model-provider-page/declarations'
import type { ModelConfig } from '@/models/debug'
import { produce } from 'immer'
import { useEffect, useState } from 'react'
import { ModelModeType } from '@/types/app'
interface UseModelConfigParams {
initialConfig?: Partial<ModelConfig>
currModel?: { model_properties?: { mode?: ModelModeType } }
hasFetchedDetail: boolean
}
interface UseModelConfigReturn {
modelConfig: ModelConfig
setModelConfig: (config: ModelConfig) => void
completionParams: FormValue
setCompletionParams: (params: FormValue) => void
modelModeType: ModelModeType
}
export const useModelConfig = ({
initialConfig,
currModel,
hasFetchedDetail,
}: UseModelConfigParams): UseModelConfigReturn => {
const [modelConfig, setModelConfig] = useState<ModelConfig>({
provider: 'langgenius/openai/openai',
model_id: 'gpt-3.5-turbo',
mode: ModelModeType.unset,
// ... default values
...initialConfig,
})
const [completionParams, setCompletionParams] = useState<FormValue>({})
const modelModeType = modelConfig.mode
// Fill in the model mode for older app data that is missing it
useEffect(() => {
if (hasFetchedDetail && !modelModeType) {
const mode = currModel?.model_properties?.mode
if (mode) {
setModelConfig(prev => produce(prev, (draft) => {
draft.mode = mode
}))
}
}
}, [hasFetchedDetail, modelModeType, currModel])
return {
modelConfig,
setModelConfig,
completionParams,
setCompletionParams,
modelModeType,
}
}
```
### Step 4: Update Component
```typescript
// Before: 50+ lines of state management
const Configuration: FC = () => {
const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
// ... lots of related state and effects
}
// After: Clean component
const Configuration: FC = () => {
const {
modelConfig,
setModelConfig,
completionParams,
setCompletionParams,
modelModeType,
} = useModelConfig({
currModel,
hasFetchedDetail,
})
// Component now focuses on UI
}
```
## Naming Conventions
### Hook Names
- Use `use` prefix: `useModelConfig`, `useDatasetConfig`
- Be specific: `useAdvancedPromptConfig` not `usePrompt`
- Include domain: `useWorkflowVariables`, `useMCPServer`
### File Names
- Kebab-case: `use-model-config.ts`
- Place in `hooks/` subdirectory when multiple hooks exist
- Place alongside component for single-use hooks
### Return Type Names
- Suffix with `Return`: `UseModelConfigReturn`
- Suffix params with `Params`: `UseModelConfigParams`
## Common Hook Patterns in Dify
### 1. Data Fetching Hook (React Query)
```typescript
// Pattern: Use @tanstack/react-query for data fetching
import { useQuery, useQueryClient } from '@tanstack/react-query'
import { get } from '@/service/base'
import { useInvalid } from '@/service/use-base'
const NAME_SPACE = 'appConfig'
// Query keys for cache management
export const appConfigQueryKeys = {
detail: (appId: string) => [NAME_SPACE, 'detail', appId] as const,
}
// Main data hook
export const useAppConfig = (appId: string) => {
return useQuery({
enabled: !!appId,
queryKey: appConfigQueryKeys.detail(appId),
queryFn: () => get<AppDetailResponse>(`/apps/${appId}`),
select: data => data?.model_config || null,
})
}
// Invalidation hook for refreshing data
export const useInvalidAppConfig = () => {
return useInvalid([NAME_SPACE])
}
// Usage in component
const Component = () => {
const { data: config, isLoading, error, refetch } = useAppConfig(appId)
const invalidAppConfig = useInvalidAppConfig()
const handleRefresh = () => {
invalidAppConfig() // Invalidates cache and triggers refetch
}
return <div>...</div>
}
```
### 2. Form State Hook
```typescript
// Pattern: Form state + validation + submission
export const useConfigForm = (initialValues: ConfigFormValues) => {
const [values, setValues] = useState(initialValues)
const [errors, setErrors] = useState<Record<string, string>>({})
const [isSubmitting, setIsSubmitting] = useState(false)
const validate = useCallback(() => {
const newErrors: Record<string, string> = {}
if (!values.name) newErrors.name = 'Name is required'
setErrors(newErrors)
return Object.keys(newErrors).length === 0
}, [values])
const handleChange = useCallback((field: string, value: any) => {
setValues(prev => ({ ...prev, [field]: value }))
}, [])
const handleSubmit = useCallback(async (onSubmit: (values: ConfigFormValues) => Promise<void>) => {
if (!validate()) return
setIsSubmitting(true)
try {
await onSubmit(values)
} finally {
setIsSubmitting(false)
}
}, [values, validate])
return { values, errors, isSubmitting, handleChange, handleSubmit }
}
```
### 3. Modal State Hook
```typescript
// Pattern: Multiple modal management
type ModalType = 'edit' | 'delete' | 'duplicate' | null
export const useModalState = () => {
const [activeModal, setActiveModal] = useState<ModalType>(null)
const [modalData, setModalData] = useState<any>(null)
const openModal = useCallback((type: ModalType, data?: any) => {
setActiveModal(type)
setModalData(data)
}, [])
const closeModal = useCallback(() => {
setActiveModal(null)
setModalData(null)
}, [])
return {
activeModal,
modalData,
openModal,
closeModal,
isOpen: useCallback((type: ModalType) => activeModal === type, [activeModal]),
}
}
```
### 4. Toggle/Boolean Hook
```typescript
// Pattern: Boolean state with convenience methods
export const useToggle = (initialValue = false) => {
const [value, setValue] = useState(initialValue)
const toggle = useCallback(() => setValue(v => !v), [])
const setTrue = useCallback(() => setValue(true), [])
const setFalse = useCallback(() => setValue(false), [])
return [value, { toggle, setTrue, setFalse, set: setValue }] as const
}
// Usage
const [isExpanded, { toggle, setTrue: expand, setFalse: collapse }] = useToggle()
```
## Testing Extracted Hooks
After extraction, test hooks in isolation:
```typescript
// use-model-config.spec.ts
import { renderHook, act } from '@testing-library/react'
import { ModelModeType } from '@/types/app'
import { useModelConfig } from './use-model-config'
describe('useModelConfig', () => {
it('should initialize with default values', () => {
const { result } = renderHook(() => useModelConfig({
hasFetchedDetail: false,
}))
expect(result.current.modelConfig.provider).toBe('langgenius/openai/openai')
expect(result.current.modelModeType).toBe(ModelModeType.unset)
})
it('should update model config', () => {
const { result } = renderHook(() => useModelConfig({
hasFetchedDetail: true,
}))
act(() => {
result.current.setModelConfig({
...result.current.modelConfig,
model_id: 'gpt-4',
})
})
expect(result.current.modelConfig.model_id).toBe('gpt-4')
})
})
```

View File

@ -318,5 +318,5 @@ For more detailed information, refer to:
- `web/vitest.config.ts` - Vitest configuration
- `web/vitest.setup.ts` - Test environment setup
- `web/testing/analyze-component.js` - Component analysis tool
- `web/scripts/analyze-component.js` - Component analysis tool
- Modules are not mocked automatically. Global mocks live in `web/vitest.setup.ts` (for example `react-i18next`, `next/image`); mock other modules like `ky` or `mime` locally in test files.

View File

@ -22,12 +22,12 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@v6
uses: astral-sh/setup-uv@v7
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
@ -57,7 +57,7 @@ jobs:
run: sh .github/workflows/expose_service_ports.sh
- name: Set up Sandbox
uses: hoverkraft-tech/compose-action@v2.0.2
uses: hoverkraft-tech/compose-action@v2
with:
compose-file: |
docker/docker-compose.middleware.yaml

View File

@ -12,12 +12,28 @@ jobs:
if: github.repository == 'langgenius/dify'
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6
- name: Check Docker Compose inputs
id: docker-compose-changes
uses: tj-actions/changed-files@v46
with:
files: |
docker/generate_docker_compose
docker/.env.example
docker/docker-compose-template.yaml
docker/docker-compose.yaml
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- uses: astral-sh/setup-uv@v6
- uses: astral-sh/setup-uv@v7
- name: Generate Docker Compose
if: steps.docker-compose-changes.outputs.any_changed == 'true'
run: |
cd docker
./generate_docker_compose
- run: |
cd api

View File

@ -90,7 +90,7 @@ jobs:
touch "/tmp/digests/${sanitized_digest}"
- name: Upload digest
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: digests-${{ matrix.context }}-${{ env.PLATFORM_PAIR }}
path: /tmp/digests/*

View File

@ -13,13 +13,13 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
fetch-depth: 0
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@v6
uses: astral-sh/setup-uv@v7
with:
enable-cache: true
python-version: "3.12"
@ -63,13 +63,13 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
fetch-depth: 0
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@v6
uses: astral-sh/setup-uv@v7
with:
enable-cache: true
python-version: "3.12"

View File

@ -27,7 +27,7 @@ jobs:
vdb-changed: ${{ steps.changes.outputs.vdb }}
migration-changed: ${{ steps.changes.outputs.migration }}
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6
- uses: dorny/paths-filter@v3
id: changes
with:
@ -38,6 +38,7 @@ jobs:
- '.github/workflows/api-tests.yml'
web:
- 'web/**'
- '.github/workflows/web-tests.yml'
vdb:
- 'api/core/rag/datasource/**'
- 'docker/**'

View File

@ -19,13 +19,13 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
persist-credentials: false
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v46
uses: tj-actions/changed-files@v47
with:
files: |
api/**
@ -33,7 +33,7 @@ jobs:
- name: Setup UV and Python
if: steps.changed-files.outputs.any_changed == 'true'
uses: astral-sh/setup-uv@v6
uses: astral-sh/setup-uv@v7
with:
enable-cache: false
python-version: "3.12"
@ -68,15 +68,17 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
persist-credentials: false
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v46
uses: tj-actions/changed-files@v47
with:
files: web/**
files: |
web/**
.github/workflows/style.yml
- name: Install pnpm
uses: pnpm/action-setup@v4
@ -85,7 +87,7 @@ jobs:
run_install: false
- name: Setup NodeJS
uses: actions/setup-node@v4
uses: actions/setup-node@v6
if: steps.changed-files.outputs.any_changed == 'true'
with:
node-version: 22
@ -108,50 +110,20 @@ jobs:
working-directory: ./web
run: pnpm run type-check:tsgo
docker-compose-template:
name: Docker Compose Template
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
persist-credentials: false
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v46
with:
files: |
docker/generate_docker_compose
docker/.env.example
docker/docker-compose-template.yaml
docker/docker-compose.yaml
- name: Generate Docker Compose
if: steps.changed-files.outputs.any_changed == 'true'
run: |
cd docker
./generate_docker_compose
- name: Check for changes
if: steps.changed-files.outputs.any_changed == 'true'
run: git diff --exit-code
superlinter:
name: SuperLinter
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
fetch-depth: 0
persist-credentials: false
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v46
uses: tj-actions/changed-files@v47
with:
files: |
**.sh

View File

@ -25,12 +25,12 @@ jobs:
working-directory: sdks/nodejs-client
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6
with:
persist-credentials: false
- name: Use Node.js ${{ matrix.node-version }}
uses: actions/setup-node@v4
uses: actions/setup-node@v6
with:
node-version: ${{ matrix.node-version }}
cache: ''

View File

@ -18,7 +18,7 @@ jobs:
run:
working-directory: web
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6
with:
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}
@ -51,7 +51,7 @@ jobs:
- name: Set up Node.js
if: env.FILES_CHANGED == 'true'
uses: actions/setup-node@v4
uses: actions/setup-node@v6
with:
node-version: 'lts/*'
cache: pnpm

View File

@ -19,19 +19,19 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
persist-credentials: false
- name: Free Disk Space
uses: endersonmenezes/free-disk-space@v2
uses: endersonmenezes/free-disk-space@v3
with:
remove_dotnet: true
remove_haskell: true
remove_tool_cache: true
- name: Setup UV and Python
uses: astral-sh/setup-uv@v6
uses: astral-sh/setup-uv@v7
with:
enable-cache: true
python-version: ${{ matrix.python-version }}

View File

@ -18,7 +18,7 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
persist-credentials: false
@ -29,7 +29,7 @@ jobs:
run_install: false
- name: Setup Node.js
uses: actions/setup-node@v4
uses: actions/setup-node@v6
with:
node-version: 22
cache: pnpm
@ -360,7 +360,7 @@ jobs:
- name: Upload Coverage Artifact
if: steps.coverage-summary.outputs.has_coverage == 'true'
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: web-coverage-report
path: web/coverage

View File

@ -1,34 +0,0 @@
{
"mcpServers": {
"context7": {
"type": "http",
"url": "https://mcp.context7.com/mcp"
},
"sequential-thinking": {
"type": "stdio",
"command": "npx",
"args": ["-y", "@modelcontextprotocol/server-sequential-thinking"],
"env": {}
},
"github": {
"type": "stdio",
"command": "npx",
"args": ["-y", "@modelcontextprotocol/server-github"],
"env": {
"GITHUB_PERSONAL_ACCESS_TOKEN": "${GITHUB_PERSONAL_ACCESS_TOKEN}"
}
},
"fetch": {
"type": "stdio",
"command": "uvx",
"args": ["mcp-server-fetch"],
"env": {}
},
"playwright": {
"type": "stdio",
"command": "npx",
"args": ["-y", "@playwright/mcp@latest"],
"env": {}
}
}
}

View File

@ -0,0 +1,57 @@
import os
from email.message import Message
from urllib.parse import quote
from flask import Response
HTML_MIME_TYPES = frozenset({"text/html", "application/xhtml+xml"})
HTML_EXTENSIONS = frozenset({"html", "htm"})
def _normalize_mime_type(mime_type: str | None) -> str:
if not mime_type:
return ""
message = Message()
message["Content-Type"] = mime_type
return message.get_content_type().strip().lower()
def _is_html_extension(extension: str | None) -> bool:
if not extension:
return False
return extension.lstrip(".").lower() in HTML_EXTENSIONS
def is_html_content(mime_type: str | None, filename: str | None, extension: str | None = None) -> bool:
normalized_mime_type = _normalize_mime_type(mime_type)
if normalized_mime_type in HTML_MIME_TYPES:
return True
if _is_html_extension(extension):
return True
if filename:
return _is_html_extension(os.path.splitext(filename)[1])
return False
def enforce_download_for_html(
response: Response,
*,
mime_type: str | None,
filename: str | None,
extension: str | None = None,
) -> bool:
if not is_html_content(mime_type, filename, extension):
return False
if filename:
encoded_filename = quote(filename)
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
else:
response.headers["Content-Disposition"] = "attachment"
response.headers["Content-Type"] = "application/octet-stream"
response.headers["X-Content-Type-Options"] = "nosniff"
return True

View File

@ -1,8 +1,9 @@
import base64
from typing import Literal
from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field
from werkzeug.exceptions import BadRequest
from controllers.console import console_ns
@ -15,22 +16,8 @@ DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class SubscriptionQuery(BaseModel):
plan: str = Field(..., description="Subscription plan")
interval: str = Field(..., description="Billing interval")
@field_validator("plan")
@classmethod
def validate_plan(cls, value: str) -> str:
if value not in [CloudPlan.PROFESSIONAL, CloudPlan.TEAM]:
raise ValueError("Invalid plan")
return value
@field_validator("interval")
@classmethod
def validate_interval(cls, value: str) -> str:
if value not in {"month", "year"}:
raise ValueError("Invalid interval")
return value
plan: Literal[CloudPlan.PROFESSIONAL, CloudPlan.TEAM] = Field(..., description="Subscription plan")
interval: Literal["month", "year"] = Field(..., description="Billing interval")
class PartnerTenantsPayload(BaseModel):

View File

@ -1,6 +1,5 @@
import logging
from typing import Literal
from uuid import UUID
from flask import request
from flask_restx import marshal_with
@ -26,6 +25,7 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
from core.model_runtime.errors.invoke import InvokeError
from fields.message_fields import message_infinite_scroll_pagination_fields
from libs import helper
from libs.helper import UUIDStrOrEmpty
from libs.login import current_account_with_tenant
from models.model import AppMode
from services.app_generate_service import AppGenerateService
@ -44,8 +44,8 @@ logger = logging.getLogger(__name__)
class MessageListQuery(BaseModel):
conversation_id: UUID
first_id: UUID | None = None
conversation_id: UUIDStrOrEmpty
first_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)

View File

@ -1,5 +1,3 @@
from uuid import UUID
from flask import request
from flask_restx import fields, marshal_with
from pydantic import BaseModel, Field
@ -10,19 +8,19 @@ from controllers.console import console_ns
from controllers.console.explore.error import NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource
from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField
from libs.helper import TimestampField, UUIDStrOrEmpty
from libs.login import current_account_with_tenant
from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService
class SavedMessageListQuery(BaseModel):
last_id: UUID | None = None
last_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)
class SavedMessageCreatePayload(BaseModel):
message_id: UUID
message_id: UUIDStrOrEmpty
register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload)

View File

@ -1,6 +1,8 @@
from flask_restx import Resource, reqparse
from flask_restx import Resource
from pydantic import BaseModel
from werkzeug.exceptions import Forbidden
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.entities.model_entities import ModelType
@ -10,10 +12,20 @@ from models import TenantAccountRole
from services.model_load_balancing_service import ModelLoadBalancingService
class LoadBalancingCredentialPayload(BaseModel):
model: str
model_type: ModelType
credentials: dict[str, object]
register_schema_models(console_ns, LoadBalancingCredentialPayload)
@console_ns.route(
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/credentials-validate"
)
class LoadBalancingCredentialsValidateApi(Resource):
@console_ns.expect(console_ns.models[LoadBalancingCredentialPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -24,20 +36,7 @@ class LoadBalancingCredentialsValidateApi(Resource):
tenant_id = current_tenant_id
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
args = parser.parse_args()
payload = LoadBalancingCredentialPayload.model_validate(console_ns.payload or {})
# validate model load balancing credentials
model_load_balancing_service = ModelLoadBalancingService()
@ -49,9 +48,9 @@ class LoadBalancingCredentialsValidateApi(Resource):
model_load_balancing_service.validate_load_balancing_credentials(
tenant_id=tenant_id,
provider=provider,
model=args["model"],
model_type=args["model_type"],
credentials=args["credentials"],
model=payload.model,
model_type=payload.model_type,
credentials=payload.credentials,
)
except CredentialsValidateFailedError as ex:
result = False
@ -69,6 +68,7 @@ class LoadBalancingCredentialsValidateApi(Resource):
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate"
)
class LoadBalancingConfigCredentialsValidateApi(Resource):
@console_ns.expect(console_ns.models[LoadBalancingCredentialPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -79,20 +79,7 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
tenant_id = current_tenant_id
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
args = parser.parse_args()
payload = LoadBalancingCredentialPayload.model_validate(console_ns.payload or {})
# validate model load balancing config credentials
model_load_balancing_service = ModelLoadBalancingService()
@ -104,9 +91,9 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
model_load_balancing_service.validate_load_balancing_credentials(
tenant_id=tenant_id,
provider=provider,
model=args["model"],
model_type=args["model_type"],
credentials=args["credentials"],
model=payload.model,
model_type=payload.model_type,
credentials=payload.credentials,
config_id=config_id,
)
except CredentialsValidateFailedError as ex:

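The two load-balancing endpoints above replace `reqparse` with a shared Pydantic payload validated via `model_validate(console_ns.payload)`. A minimal, self-contained sketch of that validation style; the `ModelType` enum here is a simplified stand-in for the real one, and `protected_namespaces=()` is an assumption added only to silence Pydantic's `model_*` field-name warning:

```python
from enum import StrEnum

from pydantic import BaseModel, ConfigDict, ValidationError


class ModelType(StrEnum):
    # stand-in for core.model_runtime.entities.model_entities.ModelType
    LLM = "llm"
    TEXT_EMBEDDING = "text-embedding"


class LoadBalancingCredentialPayload(BaseModel):
    # relaxed so the `model` / `model_type` field names do not trigger a namespace warning
    model_config = ConfigDict(protected_namespaces=())

    model: str
    model_type: ModelType
    credentials: dict[str, object]


payload = LoadBalancingCredentialPayload.model_validate(
    {"model": "gpt-4o", "model_type": "llm", "credentials": {"api_key": "sk-..."}}
)
print(payload.model_type)  # ModelType.LLM

try:
    LoadBalancingCredentialPayload.model_validate(
        {"model": "gpt-4o", "model_type": "bogus", "credentials": {}}
    )
except ValidationError as exc:
    print("rejected:", exc.errors()[0]["loc"])  # ('model_type',)
```

Invalid values are rejected at the request boundary instead of inside the handler, which is what the `console_ns.expect(...)` schema registration documents for Swagger.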
View File

@ -1,5 +1,6 @@
import io
from typing import Literal
from collections.abc import Mapping
from typing import Any, Literal
from flask import request, send_file
from flask_restx import Resource
@ -141,6 +142,15 @@ class ParserDynamicOptions(BaseModel):
provider_type: Literal["tool", "trigger"]
class ParserDynamicOptionsWithCredentials(BaseModel):
plugin_id: str
provider: str
action: str
parameter: str
credential_id: str
credentials: Mapping[str, Any]
class PluginPermissionSettingsPayload(BaseModel):
install_permission: TenantPluginPermission.InstallPermission = TenantPluginPermission.InstallPermission.EVERYONE
debug_permission: TenantPluginPermission.DebugPermission = TenantPluginPermission.DebugPermission.EVERYONE
@ -183,6 +193,7 @@ reg(ParserGithubUpgrade)
reg(ParserUninstall)
reg(ParserPermissionChange)
reg(ParserDynamicOptions)
reg(ParserDynamicOptionsWithCredentials)
reg(ParserPreferencesChange)
reg(ParserExcludePlugin)
reg(ParserReadme)
@ -657,6 +668,37 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
return jsonable_encoder({"options": options})
@console_ns.route("/workspaces/current/plugin/parameters/dynamic-options-with-credentials")
class PluginFetchDynamicSelectOptionsWithCredentialsApi(Resource):
@console_ns.expect(console_ns.models[ParserDynamicOptionsWithCredentials.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
"""Fetch dynamic options using credentials directly (for edit mode)."""
current_user, tenant_id = current_account_with_tenant()
user_id = current_user.id
args = ParserDynamicOptionsWithCredentials.model_validate(console_ns.payload)
try:
options = PluginParameterService.get_dynamic_select_options_with_credentials(
tenant_id=tenant_id,
user_id=user_id,
plugin_id=args.plugin_id,
provider=args.provider,
action=args.action,
parameter=args.parameter,
credential_id=args.credential_id,
credentials=args.credentials,
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder({"options": options})
@console_ns.route("/workspaces/current/plugin/preferences/change")
class PluginChangePreferencesApi(Resource):
@console_ns.expect(console_ns.models[ParserPreferencesChange.__name__])

View File

@ -1,4 +1,5 @@
import io
import logging
from urllib.parse import urlparse
from flask import make_response, redirect, request, send_file
@ -17,6 +18,7 @@ from controllers.console.wraps import (
is_admin_or_owner_required,
setup_required,
)
from core.db.session_factory import session_factory
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
from core.helper.tool_provider_cache import ToolProviderListCache
from core.mcp.auth.auth_flow import auth, handle_callback
@ -40,6 +42,8 @@ from services.tools.tools_manage_service import ToolCommonService
from services.tools.tools_transform_service import ToolTransformService
from services.tools.workflow_tools_manage_service import WorkflowToolManageService
logger = logging.getLogger(__name__)
def is_valid_url(url: str) -> bool:
if not url:
@ -945,8 +949,8 @@ class ToolProviderMCPApi(Resource):
configuration = MCPConfiguration.model_validate(args["configuration"])
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
# Create provider in transaction
with Session(db.engine) as session, session.begin():
# 1) Create provider in a short transaction (no network I/O inside)
with session_factory.create_session() as session, session.begin():
service = MCPToolManageService(session=session)
result = service.create_provider(
tenant_id=tenant_id,
@ -962,7 +966,28 @@ class ToolProviderMCPApi(Resource):
authentication=authentication,
)
# Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
# 2) Try to fetch tools immediately after creation so they appear without a second save.
# Perform network I/O outside any DB session to avoid holding locks.
try:
reconnect = MCPToolManageService.reconnect_with_url(
server_url=args["server_url"],
headers=args.get("headers") or {},
timeout=configuration.timeout,
sse_read_timeout=configuration.sse_read_timeout,
)
# Update just-created provider with authed/tools in a new short transaction
with session_factory.create_session() as session, session.begin():
service = MCPToolManageService(session=session)
db_provider = service.get_provider(provider_id=result.id, tenant_id=tenant_id)
db_provider.authed = reconnect.authed
db_provider.tools = reconnect.tools
result = ToolTransformService.mcp_provider_to_user_provider(db_provider, for_list=True)
except Exception:
# Best-effort: if initial fetch fails (e.g., auth required), return created provider as-is
logger.warning("Failed to fetch MCP tools after creation", exc_info=True)
# Final cache invalidation to ensure list views are up to date
ToolProviderListCache.invalidate_cache(tenant_id)
return jsonable_encoder(result)

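The creation flow above deliberately splits the work into two short transactions with the MCP network call in between, so no DB locks are held during network I/O. A rough illustration of that shape using stand-in functions rather than `session_factory` or `MCPToolManageService`:

```python
def create_provider_short_txn() -> dict:
    # phase 1: INSERT and COMMIT quickly; no network I/O while the transaction is open
    return {"id": "provider-1", "authed": False, "tools": "[]"}


def fetch_tools_over_network(server_url: str) -> dict:
    # the slow MCP call runs between the two transactions, with no DB session held
    return {"authed": True, "tools": '[{"name": "demo_tool"}]'}


provider = create_provider_short_txn()
try:
    provider.update(fetch_tools_over_network("https://mcp.example.com/sse"))  # hypothetical URL
except Exception:
    # best-effort, mirroring the handler: if the fetch fails the created provider is returned as-is
    pass
print(provider)
```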
View File

@ -1,11 +1,15 @@
import logging
from collections.abc import Mapping
from typing import Any
from flask import make_response, redirect, request
from flask_restx import Resource, reqparse
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden
from configs import dify_config
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from controllers.web.error import NotFoundError
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin_daemon import CredentialType
@ -32,6 +36,32 @@ from ..wraps import (
logger = logging.getLogger(__name__)
class TriggerSubscriptionUpdateRequest(BaseModel):
"""Request payload for updating a trigger subscription"""
name: str | None = Field(default=None, description="The name for the subscription")
credentials: Mapping[str, Any] | None = Field(default=None, description="The credentials for the subscription")
parameters: Mapping[str, Any] | None = Field(default=None, description="The parameters for the subscription")
properties: Mapping[str, Any] | None = Field(default=None, description="The properties for the subscription")
class TriggerSubscriptionVerifyRequest(BaseModel):
"""Request payload for verifying subscription credentials."""
credentials: Mapping[str, Any] = Field(description="The credentials to verify")
console_ns.schema_model(
TriggerSubscriptionUpdateRequest.__name__,
TriggerSubscriptionUpdateRequest.model_json_schema(ref_template="#/definitions/{model}"),
)
console_ns.schema_model(
TriggerSubscriptionVerifyRequest.__name__,
TriggerSubscriptionVerifyRequest.model_json_schema(ref_template="#/definitions/{model}"),
)
@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/icon")
class TriggerProviderIconApi(Resource):
@setup_required
@ -155,16 +185,16 @@ parser_api = (
@console_ns.route(
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/verify/<path:subscription_builder_id>",
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/verify-and-update/<path:subscription_builder_id>",
)
class TriggerSubscriptionBuilderVerifyApi(Resource):
class TriggerSubscriptionBuilderVerifyAndUpdateApi(Resource):
@console_ns.expect(parser_api)
@setup_required
@login_required
@edit_permission_required
@account_initialization_required
def post(self, provider, subscription_builder_id):
"""Verify a subscription instance for a trigger provider"""
"""Verify and update a subscription instance for a trigger provider"""
user = current_user
assert user.current_tenant_id is not None
@ -289,6 +319,83 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
raise ValueError(str(e)) from e
@console_ns.route(
"/workspaces/current/trigger-provider/<path:subscription_id>/subscriptions/update",
)
class TriggerSubscriptionUpdateApi(Resource):
@console_ns.expect(console_ns.models[TriggerSubscriptionUpdateRequest.__name__])
@setup_required
@login_required
@edit_permission_required
@account_initialization_required
def post(self, subscription_id: str):
"""Update a subscription instance"""
user = current_user
assert user.current_tenant_id is not None
args = TriggerSubscriptionUpdateRequest.model_validate(console_ns.payload)
subscription = TriggerProviderService.get_subscription_by_id(
tenant_id=user.current_tenant_id,
subscription_id=subscription_id,
)
if not subscription:
raise NotFoundError(f"Subscription {subscription_id} not found")
provider_id = TriggerProviderID(subscription.provider_id)
try:
# rename only
if (
args.name is not None
and args.credentials is None
and args.parameters is None
and args.properties is None
):
TriggerProviderService.update_trigger_subscription(
tenant_id=user.current_tenant_id,
subscription_id=subscription_id,
name=args.name,
)
return 200
# otherwise rebuild the subscription so the provider recreates it automatically
match subscription.credential_type:
case CredentialType.UNAUTHORIZED:
TriggerProviderService.update_trigger_subscription(
tenant_id=user.current_tenant_id,
subscription_id=subscription_id,
name=args.name,
properties=args.properties,
)
return 200
case CredentialType.API_KEY | CredentialType.OAUTH2:
if args.credentials:
new_credentials: dict[str, Any] = {
key: value if value != HIDDEN_VALUE else subscription.credentials.get(key, UNKNOWN_VALUE)
for key, value in args.credentials.items()
}
else:
new_credentials = subscription.credentials
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=user.current_tenant_id,
name=args.name,
provider_id=provider_id,
subscription_id=subscription_id,
credentials=new_credentials,
parameters=args.parameters or subscription.parameters,
)
return 200
case _:
raise BadRequest("Invalid credential type")
except ValueError as e:
raise BadRequest(str(e))
except Exception as e:
logger.exception("Error updating subscription", exc_info=e)
raise
@console_ns.route(
"/workspaces/current/trigger-provider/<path:subscription_id>/subscriptions/delete",
)
@ -576,3 +683,38 @@ class TriggerOAuthClientManageApi(Resource):
except Exception as e:
logger.exception("Error removing OAuth client", exc_info=e)
raise
@console_ns.route(
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/verify/<path:subscription_id>",
)
class TriggerSubscriptionVerifyApi(Resource):
@console_ns.expect(console_ns.models[TriggerSubscriptionVerifyRequest.__name__])
@setup_required
@login_required
@edit_permission_required
@account_initialization_required
def post(self, provider, subscription_id):
"""Verify credentials for an existing subscription (edit mode only)"""
user = current_user
assert user.current_tenant_id is not None
verify_request: TriggerSubscriptionVerifyRequest = TriggerSubscriptionVerifyRequest.model_validate(
console_ns.payload
)
try:
result = TriggerProviderService.verify_subscription_credentials(
tenant_id=user.current_tenant_id,
user_id=user.id,
provider_id=TriggerProviderID(provider),
subscription_id=subscription_id,
credentials=verify_request.credentials,
)
return result
except ValueError as e:
logger.warning("Credential verification failed", exc_info=e)
raise BadRequest(str(e)) from e
except Exception as e:
logger.exception("Error verifying subscription credentials", exc_info=e)
raise BadRequest(str(e)) from e

View File

@ -7,6 +7,7 @@ from werkzeug.exceptions import NotFound
import services
from controllers.common.errors import UnsupportedFileTypeError
from controllers.common.file_response import enforce_download_for_html
from controllers.files import files_ns
from extensions.ext_database import db
from services.account_service import TenantService
@ -138,6 +139,13 @@ class FilePreviewApi(Resource):
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
response.headers["Content-Type"] = "application/octet-stream"
enforce_download_for_html(
response,
mime_type=upload_file.mime_type,
filename=upload_file.name,
extension=upload_file.extension,
)
return response

View File

@ -6,6 +6,7 @@ from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, NotFound
from controllers.common.errors import UnsupportedFileTypeError
from controllers.common.file_response import enforce_download_for_html
from controllers.files import files_ns
from core.tools.signature import verify_tool_file_signature
from core.tools.tool_file_manager import ToolFileManager
@ -78,4 +79,11 @@ class ToolFileApi(Resource):
encoded_filename = quote(tool_file.name)
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
enforce_download_for_html(
response,
mime_type=tool_file.mimetype,
filename=tool_file.name,
extension=extension,
)
return response

View File

@ -5,6 +5,7 @@ from flask import Response, request
from flask_restx import Resource
from pydantic import BaseModel, Field
from controllers.common.file_response import enforce_download_for_html
from controllers.common.schema import register_schema_model
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import (
@ -183,6 +184,13 @@ class FilePreviewApi(Resource):
# Override content-type for downloads to force download
response.headers["Content-Type"] = "application/octet-stream"
enforce_download_for_html(
response,
mime_type=upload_file.mime_type,
filename=upload_file.name,
extension=upload_file.extension,
)
# Add caching headers for performance
response.headers["Cache-Control"] = "public, max-age=3600" # Cache for 1 hour

View File

@ -13,7 +13,6 @@ from controllers.service_api.dataset.error import DatasetInUseError, DatasetName
from controllers.service_api.wraps import (
DatasetApiResource,
cloud_edition_billing_rate_limit_check,
validate_dataset_token,
)
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
@ -460,9 +459,8 @@ class DatasetTagsApi(DatasetApiResource):
401: "Unauthorized - invalid API token",
}
)
@validate_dataset_token
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
def get(self, _, dataset_id):
def get(self, _):
"""Get all knowledge type tags."""
assert isinstance(current_user, Account)
cid = current_user.current_tenant_id
@ -482,8 +480,7 @@ class DatasetTagsApi(DatasetApiResource):
}
)
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
@validate_dataset_token
def post(self, _, dataset_id):
def post(self, _):
"""Add a knowledge type tag."""
assert isinstance(current_user, Account)
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
@ -506,8 +503,7 @@ class DatasetTagsApi(DatasetApiResource):
}
)
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
@validate_dataset_token
def patch(self, _, dataset_id):
def patch(self, _):
assert isinstance(current_user, Account)
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
@ -533,9 +529,8 @@ class DatasetTagsApi(DatasetApiResource):
403: "Forbidden - insufficient permissions",
}
)
@validate_dataset_token
@edit_permission_required
def delete(self, _, dataset_id):
def delete(self, _):
"""Delete a knowledge type tag."""
payload = TagDeletePayload.model_validate(service_api_ns.payload or {})
TagService.delete_tag(payload.tag_id)
@ -555,8 +550,7 @@ class DatasetTagBindingApi(DatasetApiResource):
403: "Forbidden - insufficient permissions",
}
)
@validate_dataset_token
def post(self, _, dataset_id):
def post(self, _):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
assert isinstance(current_user, Account)
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
@ -580,8 +574,7 @@ class DatasetTagUnbindingApi(DatasetApiResource):
403: "Forbidden - insufficient permissions",
}
)
@validate_dataset_token
def post(self, _, dataset_id):
def post(self, _):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
assert isinstance(current_user, Account)
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
@ -604,7 +597,6 @@ class DatasetTagsBindingStatusApi(DatasetApiResource):
401: "Unauthorized - invalid API token",
}
)
@validate_dataset_token
def get(self, _, *args, **kwargs):
"""Get all knowledge type tags."""
dataset_id = kwargs.get("dataset_id")

View File

@ -90,6 +90,7 @@ class AppQueueManager:
"""
self._clear_task_belong_cache()
self._q.put(None)
self._graph_runtime_state = None # Release reference to allow GC to reclaim memory
def _clear_task_belong_cache(self) -> None:
"""

View File

@ -1,9 +1,14 @@
from collections.abc import Mapping
from textwrap import dedent
from typing import Any
from core.helper.code_executor.template_transformer import TemplateTransformer
class Jinja2TemplateTransformer(TemplateTransformer):
# Use separate placeholder for base64-encoded template to avoid confusion
_template_b64_placeholder: str = "{{template_b64}}"
@classmethod
def transform_response(cls, response: str):
"""
@ -13,18 +18,35 @@ class Jinja2TemplateTransformer(TemplateTransformer):
"""
return {"result": cls.extract_result_str_from_response(response)}
@classmethod
def assemble_runner_script(cls, code: str, inputs: Mapping[str, Any]) -> str:
"""
Override base class to use base64 encoding for template code.
This prevents issues with special characters (quotes, newlines) in templates
breaking the generated Python script. Fixes #26818.
"""
script = cls.get_runner_script()
# Encode template as base64 to safely embed any content including quotes
code_b64 = cls.serialize_code(code)
script = script.replace(cls._template_b64_placeholder, code_b64)
inputs_str = cls.serialize_inputs(inputs)
script = script.replace(cls._inputs_placeholder, inputs_str)
return script
@classmethod
def get_runner_script(cls) -> str:
runner_script = dedent(f"""
# declare main function
def main(**inputs):
import jinja2
template = jinja2.Template('''{cls._code_placeholder}''')
return template.render(**inputs)
import jinja2
import json
from base64 import b64decode
# declare main function
def main(**inputs):
# Decode base64-encoded template to handle special characters safely
template_code = b64decode('{cls._template_b64_placeholder}').decode('utf-8')
template = jinja2.Template(template_code)
return template.render(**inputs)
# decode and prepare input dict
inputs_obj = json.loads(b64decode('{cls._inputs_placeholder}').decode('utf-8'))

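The rewritten runner script embeds the Jinja2 template as base64 instead of splicing it into a `'''...'''` literal. A small sketch of why that is safe for templates containing quotes or newlines (assumes `jinja2` is installed; this is not the actual runner script):

```python
from base64 import b64decode, b64encode

import jinja2

# A template that would break a '''...''' literal if pasted in directly.
template = "Hello '''{{ name }}'''\nsecond line with \"quotes\""
template_b64 = b64encode(template.encode("utf-8")).decode("utf-8")

# Only the base64 string is embedded in the generated script, so quotes and
# newlines in the original template survive untouched.
recovered = b64decode(template_b64).decode("utf-8")
print(jinja2.Template(recovered).render(name="Dify"))
```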
View File

@ -13,6 +13,15 @@ class TemplateTransformer(ABC):
_inputs_placeholder: str = "{{inputs}}"
_result_tag: str = "<<RESULT>>"
@classmethod
def serialize_code(cls, code: str) -> str:
"""
Serialize template code to base64 to safely embed in generated script.
This prevents issues with special characters like quotes breaking the script.
"""
code_bytes = code.encode("utf-8")
return b64encode(code_bytes).decode("utf-8")
@classmethod
def transform_caller(cls, code: str, inputs: Mapping[str, Any]) -> tuple[str, str]:
"""

View File

@ -118,13 +118,11 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
# Build the request manually to preserve the Host header
# httpx may override the Host header when using a proxy, so we use
# the request API to explicitly set headers before sending
request = client.build_request(method=method, url=url, **kwargs)
# If user explicitly provided a Host header, ensure it's preserved
headers = {k: v for k, v in headers.items() if k.lower() != "host"}
if user_provided_host is not None:
request.headers["Host"] = user_provided_host
response = client.send(request)
headers["host"] = user_provided_host
kwargs["headers"] = headers
response = client.request(method=method, url=url, **kwargs)
# Check for SSRF protection by Squid proxy
if response.status_code in (401, 403):

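The simplified `make_request` path now normalizes the caller's headers before delegating to `client.request`. A tiny sketch of just that header handling, with a hypothetical host value:

```python
user_provided_host = "internal.example.com"  # hypothetical caller-supplied Host
headers = {"Host": "something-else.example.com", "X-Trace-Id": "abc123"}

# Drop any existing Host variants (case-insensitive), then re-add the explicit value
# so the outgoing request carries exactly that Host header.
headers = {k: v for k, v in headers.items() if k.lower() != "host"}
if user_provided_host is not None:
    headers["host"] = user_provided_host

print(headers)  # {'X-Trace-Id': 'abc123', 'host': 'internal.example.com'}
```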
View File

@ -1,6 +1,6 @@
import json
import logging
from typing import Any
from typing import Any, cast
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
from extensions.ext_redis import redis_client, redis_fallback
@ -50,7 +50,9 @@ class ToolProviderListCache:
redis_client.delete(cache_key)
else:
# Invalidate all caches for this tenant
pattern = f"tool_providers:tenant_id:{tenant_id}:*"
keys = list(redis_client.scan_iter(pattern))
if keys:
redis_client.delete(*keys)
keys = ["builtin", "model", "api", "workflow", "mcp"]
pipeline = redis_client.pipeline()
for key in keys:
cache_key = ToolProviderListCache._generate_cache_key(tenant_id, cast(ToolProviderTypeApiLiteral, key))
pipeline.delete(cache_key)
pipeline.execute()

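Cache invalidation now deletes the five known provider-type keys through a pipeline instead of scanning the keyspace. A hedged sketch of the same pattern with redis-py; the exact key suffix here is an assumption standing in for `_generate_cache_key`:

```python
import redis

redis_client = redis.Redis()
tenant_id = "tenant-123"  # hypothetical tenant id

pipe = redis_client.pipeline()
for provider_type in ("builtin", "model", "api", "workflow", "mcp"):
    # key layout assumed to mirror the cache's key generator
    pipe.delete(f"tool_providers:tenant_id:{tenant_id}:provider_type:{provider_type}")
pipe.execute()  # all five DELETEs go out in one round trip; no SCAN over the keyspace
```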
View File

@ -313,17 +313,20 @@ class StreamableHTTPTransport:
if is_initialization:
self._maybe_extract_session_id_from_response(response)
content_type = cast(str, response.headers.get(CONTENT_TYPE, "").lower())
# Per https://modelcontextprotocol.io/specification/2025-06-18/basic#notifications:
# The server MUST NOT send a response to notifications.
if isinstance(message.root, JSONRPCRequest):
content_type = cast(str, response.headers.get(CONTENT_TYPE, "").lower())
if content_type.startswith(JSON):
self._handle_json_response(response, ctx.server_to_client_queue)
elif content_type.startswith(SSE):
self._handle_sse_response(response, ctx)
else:
self._handle_unexpected_content_type(
content_type,
ctx.server_to_client_queue,
)
if content_type.startswith(JSON):
self._handle_json_response(response, ctx.server_to_client_queue)
elif content_type.startswith(SSE):
self._handle_sse_response(response, ctx)
else:
self._handle_unexpected_content_type(
content_type,
ctx.server_to_client_queue,
)
def _handle_json_response(
self,

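The transport change above only parses a response body when the outgoing message was a `JSONRPCRequest`, since the MCP spec forbids servers from answering notifications. A minimal illustration of the request-vs-notification distinction, using plain dicts rather than the real `JSONRPCMessage` types:

```python
def expects_response(message: dict) -> bool:
    # JSON-RPC 2.0: a message with an "id" is a request and gets a response;
    # a message without one is a notification and must not be answered.
    return "id" in message


print(expects_response({"jsonrpc": "2.0", "method": "tools/list", "id": 1}))        # True
print(expects_response({"jsonrpc": "2.0", "method": "notifications/initialized"}))  # False
```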
View File

@ -76,7 +76,7 @@ class PluginParameter(BaseModel):
auto_generate: PluginParameterAutoGenerate | None = None
template: PluginParameterTemplate | None = None
required: bool = False
default: Union[float, int, str, bool] | None = None
default: Union[float, int, str, bool, list, dict] | None = None
min: Union[float, int] | None = None
max: Union[float, int] | None = None
precision: int | None = None

View File

@ -13,7 +13,7 @@ from core.model_runtime.entities.model_entities import ModelType
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.embedding.retrieval import RetrievalSegments
from core.rag.embedding.retrieval import RetrievalChildChunk, RetrievalSegments
from core.rag.entities.metadata_entities import MetadataCondition
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
@ -381,10 +381,9 @@ class RetrievalService:
records = []
include_segment_ids = set()
segment_child_map = {}
segment_file_map = {}
valid_dataset_documents = {}
image_doc_ids = []
image_doc_ids: list[Any] = []
child_index_node_ids = []
index_node_ids = []
doc_to_document_map = {}
@ -417,28 +416,39 @@ class RetrievalService:
child_index_node_ids = [i for i in child_index_node_ids if i]
index_node_ids = [i for i in index_node_ids if i]
segment_ids = []
segment_ids: list[str] = []
index_node_segments: list[DocumentSegment] = []
segments: list[DocumentSegment] = []
attachment_map = {}
child_chunk_map = {}
doc_segment_map = {}
attachment_map: dict[str, list[dict[str, Any]]] = {}
child_chunk_map: dict[str, list[ChildChunk]] = {}
doc_segment_map: dict[str, list[str]] = {}
with session_factory.create_session() as session:
attachments = cls.get_segment_attachment_infos(image_doc_ids, session)
for attachment in attachments:
segment_ids.append(attachment["segment_id"])
attachment_map[attachment["segment_id"]] = attachment
doc_segment_map[attachment["segment_id"]] = attachment["attachment_id"]
if attachment["segment_id"] in attachment_map:
attachment_map[attachment["segment_id"]].append(attachment["attachment_info"])
else:
attachment_map[attachment["segment_id"]] = [attachment["attachment_info"]]
if attachment["segment_id"] in doc_segment_map:
doc_segment_map[attachment["segment_id"]].append(attachment["attachment_id"])
else:
doc_segment_map[attachment["segment_id"]] = [attachment["attachment_id"]]
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id.in_(child_index_node_ids))
child_index_nodes = session.execute(child_chunk_stmt).scalars().all()
for i in child_index_nodes:
segment_ids.append(i.segment_id)
child_chunk_map[i.segment_id] = i
doc_segment_map[i.segment_id] = i.index_node_id
if i.segment_id in child_chunk_map:
child_chunk_map[i.segment_id].append(i)
else:
child_chunk_map[i.segment_id] = [i]
if i.segment_id in doc_segment_map:
doc_segment_map[i.segment_id].append(i.index_node_id)
else:
doc_segment_map[i.segment_id] = [i.index_node_id]
if index_node_ids:
document_segment_stmt = select(DocumentSegment).where(
@ -448,7 +458,7 @@ class RetrievalService:
)
index_node_segments = session.execute(document_segment_stmt).scalars().all() # type: ignore
for index_node_segment in index_node_segments:
doc_segment_map[index_node_segment.id] = index_node_segment.index_node_id
doc_segment_map[index_node_segment.id] = [index_node_segment.index_node_id]
if segment_ids:
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.enabled == True,
@ -461,95 +471,86 @@ class RetrievalService:
segments.extend(index_node_segments)
for segment in segments:
doc_id = doc_segment_map.get(segment.id)
child_chunk = child_chunk_map.get(segment.id)
attachment_info = attachment_map.get(segment.id)
child_chunks: list[ChildChunk] = child_chunk_map.get(segment.id, [])
attachment_infos: list[dict[str, Any]] = attachment_map.get(segment.id, [])
ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get(segment.document_id)
if doc_id:
document = doc_to_document_map[doc_id]
ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get(
document.metadata.get("document_id")
)
if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
if segment.id not in include_segment_ids:
include_segment_ids.add(segment.id)
if child_chunk:
if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
if segment.id not in include_segment_ids:
include_segment_ids.add(segment.id)
if child_chunks or attachment_infos:
child_chunk_details = []
max_score = 0.0
for child_chunk in child_chunks:
document = doc_to_document_map[child_chunk.index_node_id]
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0) if document else 0.0,
}
map_detail = {
"max_score": document.metadata.get("score", 0.0) if document else 0.0,
"child_chunks": [child_chunk_detail],
}
segment_child_map[segment.id] = map_detail
record = {
"segment": segment,
child_chunk_details.append(child_chunk_detail)
max_score = max(max_score, document.metadata.get("score", 0.0) if document else 0.0)
for attachment_info in attachment_infos:
file_document = doc_to_document_map[attachment_info["id"]]
max_score = max(
max_score, file_document.metadata.get("score", 0.0) if file_document else 0.0
)
map_detail = {
"max_score": max_score,
"child_chunks": child_chunk_details,
}
if attachment_info:
segment_file_map[segment.id] = [attachment_info]
records.append(record)
else:
if child_chunk:
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0),
}
if segment.id in segment_child_map:
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail) # type: ignore
segment_child_map[segment.id]["max_score"] = max(
segment_child_map[segment.id]["max_score"],
document.metadata.get("score", 0.0) if document else 0.0,
)
else:
segment_child_map[segment.id] = {
"max_score": document.metadata.get("score", 0.0) if document else 0.0,
"child_chunks": [child_chunk_detail],
}
if attachment_info:
if segment.id in segment_file_map:
segment_file_map[segment.id].append(attachment_info)
else:
segment_file_map[segment.id] = [attachment_info]
else:
if segment.id not in include_segment_ids:
include_segment_ids.add(segment.id)
record = {
"segment": segment,
"score": document.metadata.get("score", 0.0), # type: ignore
}
if attachment_info:
segment_file_map[segment.id] = [attachment_info]
records.append(record)
else:
if attachment_info:
attachment_infos = segment_file_map.get(segment.id, [])
if attachment_info not in attachment_infos:
attachment_infos.append(attachment_info)
segment_file_map[segment.id] = attachment_infos
segment_child_map[segment.id] = map_detail
record: dict[str, Any] = {
"segment": segment,
}
records.append(record)
else:
if segment.id not in include_segment_ids:
include_segment_ids.add(segment.id)
max_score = 0.0
segment_document = doc_to_document_map.get(segment.index_node_id)
if segment_document:
max_score = max(max_score, segment_document.metadata.get("score", 0.0))
for attachment_info in attachment_infos:
file_doc = doc_to_document_map.get(attachment_info["id"])
if file_doc:
max_score = max(max_score, file_doc.metadata.get("score", 0.0))
record = {
"segment": segment,
"score": max_score,
}
records.append(record)
# Add child chunks information to records
for record in records:
if record["segment"].id in segment_child_map:
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore
record["score"] = segment_child_map[record["segment"].id]["max_score"] # type: ignore
if record["segment"].id in segment_file_map:
record["files"] = segment_file_map[record["segment"].id] # type: ignore[assignment]
if record["segment"].id in attachment_map:
record["files"] = attachment_map[record["segment"].id] # type: ignore[assignment]
result = []
result: list[RetrievalSegments] = []
for record in records:
# Extract segment
segment = record["segment"]
# Extract child_chunks, ensuring it's a list or None
child_chunks = record.get("child_chunks")
if not isinstance(child_chunks, list):
child_chunks = None
raw_child_chunks = record.get("child_chunks")
child_chunks_list: list[RetrievalChildChunk] | None = None
if isinstance(raw_child_chunks, list):
# Sort by score descending
sorted_chunks = sorted(raw_child_chunks, key=lambda x: x.get("score", 0.0), reverse=True)
child_chunks_list = [
RetrievalChildChunk(
id=chunk["id"],
content=chunk["content"],
score=chunk.get("score", 0.0),
position=chunk["position"],
)
for chunk in sorted_chunks
]
# Extract files, ensuring it's a list or None
files = record.get("files")
@ -566,11 +567,11 @@ class RetrievalService:
# Create RetrievalSegments object
retrieval_segment = RetrievalSegments(
segment=segment, child_chunks=child_chunks, score=score, files=files
segment=segment, child_chunks=child_chunks_list, score=score, files=files
)
result.append(retrieval_segment)
return result
return sorted(result, key=lambda x: x.score if x.score is not None else 0.0, reverse=True)
except Exception as e:
db.session.rollback()
raise e

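The retrieval rewrite above turns `child_chunk_map`, `attachment_map`, and `doc_segment_map` into one-to-many maps so a single segment can carry several child chunks and attachments, with the record score taken as the maximum. A toy sketch of that grouping and score handling, using made-up rows:

```python
from collections import defaultdict

# Made-up retrieval rows: two child chunks hit the same segment, one hits another.
rows = [
    {"segment_id": "s1", "content": "chunk a", "score": 0.9},
    {"segment_id": "s1", "content": "chunk b", "score": 0.7},
    {"segment_id": "s2", "content": "chunk c", "score": 0.8},
]

child_chunk_map: dict[str, list[dict]] = defaultdict(list)
for row in rows:
    child_chunk_map[row["segment_id"]].append(row)  # accumulate, never overwrite

for segment_id, chunks in child_chunk_map.items():
    chunks.sort(key=lambda c: c["score"], reverse=True)  # child chunks ordered by score
    print(segment_id, [c["content"] for c in chunks], "max:", chunks[0]["score"])
```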
View File

@ -255,7 +255,10 @@ class PGVector(BaseVector):
return
with self._get_cursor() as cur:
cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
cur.execute("SELECT 1 FROM pg_extension WHERE extname = 'vector'")
if not cur.fetchone():
cur.execute("CREATE EXTENSION vector")
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
# PG hnsw index only support 2000 dimension or less
# ref: https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing

View File

@ -7,7 +7,7 @@ from collections.abc import Generator, Mapping
from typing import Any, Union, cast
from flask import Flask, current_app
from sqlalchemy import and_, or_, select
from sqlalchemy import and_, literal, or_, select
from sqlalchemy.orm import Session
from core.app.app_config.entities import (
@ -1036,7 +1036,7 @@ class DatasetRetrieval:
if automatic_metadata_filters:
conditions = []
for sequence, filter in enumerate(automatic_metadata_filters):
self._process_metadata_filter_func(
self.process_metadata_filter_func(
sequence,
filter.get("condition"), # type: ignore
filter.get("metadata_name"), # type: ignore
@ -1072,7 +1072,7 @@ class DatasetRetrieval:
value=expected_value,
)
)
filters = self._process_metadata_filter_func(
filters = self.process_metadata_filter_func(
sequence,
condition.comparison_operator,
metadata_name,
@ -1168,8 +1168,9 @@ class DatasetRetrieval:
return None
return automatic_metadata_filters
def _process_metadata_filter_func(
self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
@classmethod
def process_metadata_filter_func(
cls, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
):
if value is None and condition not in ("empty", "not empty"):
return filters
@ -1218,6 +1219,20 @@ class DatasetRetrieval:
case "" | ">=":
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() >= value)
case "in" | "not in":
if isinstance(value, str):
value_list = [v.strip() for v in value.split(",") if v.strip()]
elif isinstance(value, (list, tuple)):
value_list = [str(v) for v in value if v is not None]
else:
value_list = [str(value)] if value is not None else []
if not value_list:
# `field in []` is False, `field not in []` is True
filters.append(literal(condition == "not in"))
else:
op = json_field.in_ if condition == "in" else json_field.notin_
filters.append(op(value_list))
case _:
pass

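The new `in` / `not in` branch first normalizes the filter value into a list of strings and falls back to a constant filter when the list is empty. A standalone sketch of that normalization (the helper name is illustrative only, not part of the codebase):

```python
def normalize_in_values(value):
    if isinstance(value, str):
        return [v.strip() for v in value.split(",") if v.strip()]
    if isinstance(value, (list, tuple)):
        return [str(v) for v in value if v is not None]
    return [str(value)] if value is not None else []


print(normalize_in_values("a, b, ,c"))   # ['a', 'b', 'c']
print(normalize_in_values(["x", None]))  # ['x']
print(normalize_in_values(None))         # [] -> the `in` filter degrades to literal(False)
```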
View File

@ -153,11 +153,11 @@ class ToolInvokeMessage(BaseModel):
@classmethod
def transform_variable_value(cls, values):
"""
Only basic types and lists are allowed.
Only basic types, lists, and None are allowed.
"""
value = values.get("variable_value")
if not isinstance(value, dict | list | str | int | float | bool):
raise ValueError("Only basic types and lists are allowed.")
if value is not None and not isinstance(value, dict | list | str | int | float | bool):
raise ValueError("Only basic types, lists, and None are allowed.")
# if stream is true, the value must be a string
if values.get("stream"):

View File

@ -6,7 +6,15 @@ from typing import Any
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPConnectionError
from core.mcp.types import AudioContent, CallToolResult, ImageContent, TextContent
from core.mcp.types import (
AudioContent,
BlobResourceContents,
CallToolResult,
EmbeddedResource,
ImageContent,
TextContent,
TextResourceContents,
)
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
@ -53,10 +61,19 @@ class MCPTool(Tool):
for content in result.content:
if isinstance(content, TextContent):
yield from self._process_text_content(content)
elif isinstance(content, ImageContent):
yield self._process_image_content(content)
elif isinstance(content, AudioContent):
yield self._process_audio_content(content)
elif isinstance(content, ImageContent | AudioContent):
yield self.create_blob_message(
blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType}
)
elif isinstance(content, EmbeddedResource):
resource = content.resource
if isinstance(resource, TextResourceContents):
yield self.create_text_message(resource.text)
elif isinstance(resource, BlobResourceContents):
mime_type = resource.mimeType or "application/octet-stream"
yield self.create_blob_message(blob=base64.b64decode(resource.blob), meta={"mime_type": mime_type})
else:
raise ToolInvokeError(f"Unsupported embedded resource type: {type(resource)}")
else:
logger.warning("Unsupported content type=%s", type(content))
@ -101,14 +118,6 @@ class MCPTool(Tool):
for item in json_list:
yield self.create_json_message(item)
def _process_image_content(self, content: ImageContent) -> ToolInvokeMessage:
"""Process image content and return a blob message."""
return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType})
def _process_audio_content(self, content: AudioContent) -> ToolInvokeMessage:
"""Process audio content and return a blob message."""
return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType})
def fork_tool_runtime(self, runtime: ToolRuntime) -> "MCPTool":
return MCPTool(
entity=self.entity,

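The MCP tool now also unwraps `EmbeddedResource` contents, emitting text for text resources and a blob with a fallback MIME type for blob resources. A simplified dispatch sketch using dataclass stand-ins instead of the real `core.mcp.types` models:

```python
import base64
from dataclasses import dataclass


@dataclass
class TextResourceContents:  # simplified stand-in for core.mcp.types
    text: str


@dataclass
class BlobResourceContents:  # simplified stand-in for core.mcp.types
    blob: str                # base64-encoded payload
    mimeType: str | None = None


def handle_resource(resource):
    if isinstance(resource, TextResourceContents):
        return ("text", resource.text)
    if isinstance(resource, BlobResourceContents):
        mime = resource.mimeType or "application/octet-stream"  # fallback MIME type
        return ("blob", base64.b64decode(resource.blob), mime)
    raise ValueError(f"Unsupported embedded resource type: {type(resource)}")


print(handle_resource(TextResourceContents(text="hello")))
print(handle_resource(BlobResourceContents(blob=base64.b64encode(b"\x00\x01").decode())))
```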
View File

@ -5,6 +5,7 @@ from sqlalchemy.orm import Session
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.db.session_factory import session_factory
from core.plugin.entities.parameters import PluginParameterOption
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
@ -47,33 +48,30 @@ class WorkflowToolProviderController(ToolProviderController):
@classmethod
def from_db(cls, db_provider: WorkflowToolProvider) -> "WorkflowToolProviderController":
with Session(db.engine, expire_on_commit=False) as session, session.begin():
provider = session.get(WorkflowToolProvider, db_provider.id) if db_provider.id else None
if not provider:
raise ValueError("workflow provider not found")
app = session.get(App, provider.app_id)
with session_factory.create_session() as session, session.begin():
app = session.get(App, db_provider.app_id)
if not app:
raise ValueError("app not found")
user = session.get(Account, provider.user_id) if provider.user_id else None
user = session.get(Account, db_provider.user_id) if db_provider.user_id else None
controller = WorkflowToolProviderController(
entity=ToolProviderEntity(
identity=ToolProviderIdentity(
author=user.name if user else "",
name=provider.label,
label=I18nObject(en_US=provider.label, zh_Hans=provider.label),
description=I18nObject(en_US=provider.description, zh_Hans=provider.description),
icon=provider.icon,
name=db_provider.label,
label=I18nObject(en_US=db_provider.label, zh_Hans=db_provider.label),
description=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description),
icon=db_provider.icon,
),
credentials_schema=[],
plugin_id=None,
),
provider_id=provider.id or "",
provider_id="",
)
controller.tools = [
controller._get_db_provider_tool(provider, app, session=session, user=user),
controller._get_db_provider_tool(db_provider, app, session=session, user=user),
]
return controller

View File

@ -67,12 +67,16 @@ def create_trigger_provider_encrypter_for_subscription(
def delete_cache_for_subscription(tenant_id: str, provider_id: str, subscription_id: str):
cache = TriggerProviderCredentialsCache(
TriggerProviderCredentialsCache(
tenant_id=tenant_id,
provider_id=provider_id,
credential_id=subscription_id,
)
cache.delete()
).delete()
TriggerProviderPropertiesCache(
tenant_id=tenant_id,
provider_id=provider_id,
subscription_id=subscription_id,
).delete()
def create_trigger_provider_encrypter_for_properties(

View File

@ -6,7 +6,7 @@ from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast
from sqlalchemy import and_, func, literal, or_, select
from sqlalchemy import and_, func, or_, select
from sqlalchemy.orm import sessionmaker
from core.app.app_config.entities import DatasetRetrieveConfigEntity
@ -460,7 +460,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
if automatic_metadata_filters:
conditions = []
for sequence, filter in enumerate(automatic_metadata_filters):
self._process_metadata_filter_func(
DatasetRetrieval.process_metadata_filter_func(
sequence,
filter.get("condition", ""),
filter.get("metadata_name", ""),
@ -504,7 +504,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
value=expected_value,
)
)
filters = self._process_metadata_filter_func(
filters = DatasetRetrieval.process_metadata_filter_func(
sequence,
condition.comparison_operator,
metadata_name,
@ -603,87 +603,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
return [], usage
return automatic_metadata_filters, usage
def _process_metadata_filter_func(
self, sequence: int, condition: str, metadata_name: str, value: Any, filters: list[Any]
) -> list[Any]:
if value is None and condition not in ("empty", "not empty"):
return filters
json_field = Document.doc_metadata[metadata_name].as_string()
match condition:
case "contains":
filters.append(json_field.like(f"%{value}%"))
case "not contains":
filters.append(json_field.notlike(f"%{value}%"))
case "start with":
filters.append(json_field.like(f"{value}%"))
case "end with":
filters.append(json_field.like(f"%{value}"))
case "in":
if isinstance(value, str):
value_list = [v.strip() for v in value.split(",") if v.strip()]
elif isinstance(value, (list, tuple)):
value_list = [str(v) for v in value if v is not None]
else:
value_list = [str(value)] if value is not None else []
if not value_list:
filters.append(literal(False))
else:
filters.append(json_field.in_(value_list))
case "not in":
if isinstance(value, str):
value_list = [v.strip() for v in value.split(",") if v.strip()]
elif isinstance(value, (list, tuple)):
value_list = [str(v) for v in value if v is not None]
else:
value_list = [str(value)] if value is not None else []
if not value_list:
filters.append(literal(True))
else:
filters.append(json_field.notin_(value_list))
case "is" | "=":
if isinstance(value, str):
filters.append(json_field == value)
elif isinstance(value, (int, float)):
filters.append(Document.doc_metadata[metadata_name].as_float() == value)
case "is not" | "":
if isinstance(value, str):
filters.append(json_field != value)
elif isinstance(value, (int, float)):
filters.append(Document.doc_metadata[metadata_name].as_float() != value)
case "empty":
filters.append(Document.doc_metadata[metadata_name].is_(None))
case "not empty":
filters.append(Document.doc_metadata[metadata_name].isnot(None))
case "before" | "<":
filters.append(Document.doc_metadata[metadata_name].as_float() < value)
case "after" | ">":
filters.append(Document.doc_metadata[metadata_name].as_float() > value)
case "" | "<=":
filters.append(Document.doc_metadata[metadata_name].as_float() <= value)
case "" | ">=":
filters.append(Document.doc_metadata[metadata_name].as_float() >= value)
case _:
pass
return filters
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,

View File

@ -281,7 +281,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
# handle invoke result
text = invoke_result.message.content or ""
text = invoke_result.message.get_text_content()
if not isinstance(text, str):
raise InvalidTextContentTypeError(f"Invalid text content type: {type(text)}. Expected str.")

View File

@ -1,6 +1,6 @@
[project]
name = "dify-api"
version = "1.11.1"
version = "1.11.2"
requires-python = ">=3.11,<3.13"
dependencies = [

View File

@ -14,7 +14,8 @@ from enums.quota_type import QuotaType, unlimited
from extensions.otel import AppGenerateHandler, trace_span
from models.model import Account, App, AppMode, EndUser
from models.workflow import Workflow
from services.errors.app import InvokeRateLimitError, QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError
from services.errors.app import QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError
from services.workflow_service import WorkflowService

View File

@ -21,7 +21,7 @@ from models.model import App, EndUser
from models.trigger import WorkflowTriggerLog
from models.workflow import Workflow
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
from services.errors.app import InvokeRateLimitError, QuotaExceededError, WorkflowNotFoundError
from services.errors.app import QuotaExceededError, WorkflowNotFoundError, WorkflowQuotaLimitError
from services.workflow.entities import AsyncTriggerResponse, TriggerData, WorkflowTaskData
from services.workflow.queue_dispatcher import QueueDispatcherManager, QueuePriority
from services.workflow_service import WorkflowService
@ -141,7 +141,7 @@ class AsyncWorkflowService:
trigger_log_repo.update(trigger_log)
session.commit()
raise InvokeRateLimitError(
raise WorkflowQuotaLimitError(
f"Workflow execution quota limit reached for tenant {trigger_data.tenant_id}"
) from e

View File

@ -3458,7 +3458,7 @@ class SegmentService:
if keyword:
query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
query = query.order_by(DocumentSegment.position.asc())
query = query.order_by(DocumentSegment.position.asc(), DocumentSegment.id.asc())
paginated_segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
return paginated_segments.items, paginated_segments.total

View File

@ -110,5 +110,5 @@ class EnterpriseService:
if not app_id:
raise ValueError("app_id must be provided.")
body = {"appId": app_id}
EnterpriseRequest.send_request("DELETE", "/webapp/clean", json=body)
params = {"appId": app_id}
EnterpriseRequest.send_request("DELETE", "/webapp/clean", params=params)

View File

@ -18,8 +18,8 @@ class WorkflowIdFormatError(Exception):
pass
class InvokeRateLimitError(Exception):
"""Raised when rate limit is exceeded for workflow invocations."""
class WorkflowQuotaLimitError(Exception):
"""Raised when workflow execution quota is exceeded (for async/background workflows)."""
pass

View File

@ -105,3 +105,49 @@ class PluginParameterService:
)
.options
)
@staticmethod
def get_dynamic_select_options_with_credentials(
tenant_id: str,
user_id: str,
plugin_id: str,
provider: str,
action: str,
parameter: str,
credential_id: str,
credentials: Mapping[str, Any],
) -> Sequence[PluginParameterOption]:
"""
Get dynamic select options using provided credentials directly.
Used for edit mode when credentials have been modified but not yet saved.
Security: credential_id is validated against tenant_id to ensure
users can only access their own credentials.
"""
from constants import HIDDEN_VALUE
# Get original subscription to replace hidden values (with tenant_id check for security)
original_subscription = TriggerProviderService.get_subscription_by_id(tenant_id, credential_id)
if not original_subscription:
raise ValueError(f"Subscription {credential_id} not found")
# Replace [__HIDDEN__] with original values
resolved_credentials: dict[str, Any] = {
key: (original_subscription.credentials.get(key) if value == HIDDEN_VALUE else value)
for key, value in credentials.items()
}
return (
DynamicSelectClient()
.fetch_dynamic_select_options(
tenant_id,
user_id,
plugin_id,
provider,
action,
resolved_credentials,
original_subscription.credential_type or CredentialType.UNAUTHORIZED.value,
parameter,
)
.options
)

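The `[__HIDDEN__]` handling above (and the similar blocks in the trigger endpoints) merges masked fields back with their stored originals before using the credentials. A tiny worked example of that merge; the sentinel value is assumed to mirror `constants.HIDDEN_VALUE`:

```python
HIDDEN_VALUE = "[__HIDDEN__]"  # assumed to mirror constants.HIDDEN_VALUE

stored = {"api_key": "sk-original", "region": "us-east-1"}
incoming = {"api_key": HIDDEN_VALUE, "region": "eu-west-1"}

resolved = {
    key: (stored.get(key) if value == HIDDEN_VALUE else value)
    for key, value in incoming.items()
}
print(resolved)  # {'api_key': 'sk-original', 'region': 'eu-west-1'}
```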
View File

@ -286,12 +286,12 @@ class BuiltinToolManageService:
session.add(db_provider)
session.commit()
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
except Exception as e:
session.rollback()
raise ValueError(str(e))
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id, "builtin")
return {"result": "success"}
@staticmethod

View File

@ -319,8 +319,14 @@ class MCPToolManageService:
except MCPError as e:
raise ValueError(f"Failed to connect to MCP server: {e}")
# Update database with retrieved tools
db_provider.tools = json.dumps([tool.model_dump() for tool in tools])
# Update database with retrieved tools (ensure description is a non-null string)
tools_payload = []
for tool in tools:
data = tool.model_dump()
if data.get("description") is None:
data["description"] = ""
tools_payload.append(data)
db_provider.tools = json.dumps(tools_payload)
db_provider.authed = True
db_provider.updated_at = datetime.now()
self._session.flush()
@ -620,6 +626,21 @@ class MCPToolManageService:
server_url_hash=new_server_url_hash,
)
@staticmethod
def reconnect_with_url(
*,
server_url: str,
headers: dict[str, str],
timeout: float | None,
sse_read_timeout: float | None,
) -> ReconnectResult:
return MCPToolManageService._reconnect_with_url(
server_url=server_url,
headers=headers,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
)
@staticmethod
def _reconnect_with_url(
*,
@ -642,9 +663,16 @@ class MCPToolManageService:
sse_read_timeout=sse_read_timeout,
) as mcp_client:
tools = mcp_client.list_tools()
# Ensure tool descriptions are non-null in payload
tools_payload = []
for t in tools:
d = t.model_dump()
if d.get("description") is None:
d["description"] = ""
tools_payload.append(d)
return ReconnectResult(
authed=True,
tools=json.dumps([tool.model_dump() for tool in tools]),
tools=json.dumps(tools_payload),
encrypted_credentials=EMPTY_CREDENTIALS_JSON,
)
except MCPAuthError:

View File

@ -5,8 +5,8 @@ from datetime import datetime
from typing import Any
from sqlalchemy import or_, select
from sqlalchemy.orm import Session
from core.db.session_factory import session_factory
from core.helper.tool_provider_cache import ToolProviderListCache
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool_provider import ToolProviderController
@ -68,26 +68,27 @@ class WorkflowToolManageService:
if workflow is None:
raise ValueError(f"Workflow not found for app {workflow_app_id}")
with Session(db.engine, expire_on_commit=False) as session, session.begin():
workflow_tool_provider = WorkflowToolProvider(
tenant_id=tenant_id,
user_id=user_id,
app_id=workflow_app_id,
name=name,
label=label,
icon=json.dumps(icon),
description=description,
parameter_configuration=json.dumps(parameters),
privacy_policy=privacy_policy,
version=workflow.version,
)
session.add(workflow_tool_provider)
workflow_tool_provider = WorkflowToolProvider(
tenant_id=tenant_id,
user_id=user_id,
app_id=workflow_app_id,
name=name,
label=label,
icon=json.dumps(icon),
description=description,
parameter_configuration=json.dumps(parameters),
privacy_policy=privacy_policy,
version=workflow.version,
)
try:
WorkflowToolProviderController.from_db(workflow_tool_provider)
except Exception as e:
raise ValueError(str(e))
with session_factory.create_session() as session, session.begin():
session.add(workflow_tool_provider)
if labels is not None:
ToolLabelManager.update_tool_labels(
ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels

View File

@ -94,16 +94,23 @@ class TriggerProviderService:
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
for subscription in subscriptions:
encrypter, _ = create_trigger_provider_encrypter_for_subscription(
credential_encrypter, _ = create_trigger_provider_encrypter_for_subscription(
tenant_id=tenant_id,
controller=provider_controller,
subscription=subscription,
)
subscription.credentials = dict(
encrypter.mask_credentials(dict(encrypter.decrypt(subscription.credentials)))
credential_encrypter.mask_credentials(dict(credential_encrypter.decrypt(subscription.credentials)))
)
subscription.properties = dict(encrypter.mask_credentials(dict(encrypter.decrypt(subscription.properties))))
subscription.parameters = dict(encrypter.mask_credentials(dict(encrypter.decrypt(subscription.parameters))))
properties_encrypter, _ = create_trigger_provider_encrypter_for_properties(
tenant_id=tenant_id,
controller=provider_controller,
subscription=subscription,
)
subscription.properties = dict(
properties_encrypter.mask_credentials(dict(properties_encrypter.decrypt(subscription.properties)))
)
subscription.parameters = dict(subscription.parameters)
count = workflows_in_use_map.get(subscription.id)
subscription.workflows_in_use = count if count is not None else 0
@ -209,6 +216,101 @@ class TriggerProviderService:
logger.exception("Failed to add trigger provider")
raise ValueError(str(e))
@classmethod
def update_trigger_subscription(
cls,
tenant_id: str,
subscription_id: str,
name: str | None = None,
properties: Mapping[str, Any] | None = None,
parameters: Mapping[str, Any] | None = None,
credentials: Mapping[str, Any] | None = None,
credential_expires_at: int | None = None,
expires_at: int | None = None,
) -> None:
"""
Update an existing trigger subscription.
:param tenant_id: Tenant ID
:param subscription_id: Subscription instance ID
:param name: Optional new name for this subscription
:param properties: Optional new properties
:param parameters: Optional new parameters
:param credentials: Optional new credentials
:param credential_expires_at: Optional new credential expiration timestamp
:param expires_at: Optional new expiration timestamp
:return: None
"""
with Session(db.engine, expire_on_commit=False) as session:
# Use distributed lock to prevent race conditions on the same subscription
lock_key = f"trigger_subscription_update_lock:{tenant_id}_{subscription_id}"
with redis_client.lock(lock_key, timeout=20):
subscription: TriggerSubscription | None = (
session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
)
if not subscription:
raise ValueError(f"Trigger subscription {subscription_id} not found")
provider_id = TriggerProviderID(subscription.provider_id)
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
# Check for name uniqueness if name is being updated
if name is not None and name != subscription.name:
existing = (
session.query(TriggerSubscription)
.filter_by(tenant_id=tenant_id, provider_id=str(provider_id), name=name)
.first()
)
if existing:
raise ValueError(f"Subscription name '{name}' already exists for this provider")
subscription.name = name
# Update properties if provided
if properties is not None:
properties_encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=provider_controller.get_properties_schema(),
cache=NoOpProviderCredentialCache(),
)
# Handle hidden values - preserve original encrypted values
original_properties = properties_encrypter.decrypt(subscription.properties)
new_properties: dict[str, Any] = {
key: value if value != HIDDEN_VALUE else original_properties.get(key, UNKNOWN_VALUE)
for key, value in properties.items()
}
subscription.properties = dict(properties_encrypter.encrypt(new_properties))
# Update parameters if provided
if parameters is not None:
subscription.parameters = dict(parameters)
# Update credentials if provided
if credentials is not None:
credential_type = CredentialType.of(subscription.credential_type)
credential_encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=provider_controller.get_credential_schema_config(credential_type),
cache=NoOpProviderCredentialCache(),
)
subscription.credentials = dict(credential_encrypter.encrypt(dict(credentials)))
# Update credential expiration timestamp if provided
if credential_expires_at is not None:
subscription.credential_expires_at = credential_expires_at
# Update expiration timestamp if provided
if expires_at is not None:
subscription.expires_at = expires_at
session.commit()
# Clear subscription cache
delete_cache_for_subscription(
tenant_id=tenant_id,
provider_id=subscription.provider_id,
subscription_id=subscription.id,
)
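`update_trigger_subscription` wraps its read-modify-write in a per-subscription Redis lock so concurrent edits cannot interleave. A minimal sketch of that locking pattern with redis-py (the ids are hypothetical):

```python
import redis

redis_client = redis.Redis()
tenant_id, subscription_id = "t1", "sub-1"  # hypothetical ids
lock_key = f"trigger_subscription_update_lock:{tenant_id}_{subscription_id}"

with redis_client.lock(lock_key, timeout=20):
    # the read-modify-write of the subscription row happens here; the 20s timeout
    # means a crashed holder cannot block later updates forever
    pass
```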
@classmethod
def get_subscription_by_id(cls, tenant_id: str, subscription_id: str | None = None) -> TriggerSubscription | None:
"""
@ -257,17 +359,18 @@ class TriggerProviderService:
raise ValueError(f"Trigger provider subscription {subscription_id} not found")
credential_type: CredentialType = CredentialType.of(subscription.credential_type)
provider_id = TriggerProviderID(subscription.provider_id)
provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
tenant_id=tenant_id, provider_id=provider_id
)
encrypter, _ = create_trigger_provider_encrypter_for_subscription(
tenant_id=tenant_id,
controller=provider_controller,
subscription=subscription,
)
is_auto_created: bool = credential_type in [CredentialType.OAUTH2, CredentialType.API_KEY]
if is_auto_created:
provider_id = TriggerProviderID(subscription.provider_id)
provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
tenant_id=tenant_id, provider_id=provider_id
)
encrypter, _ = create_trigger_provider_encrypter_for_subscription(
tenant_id=tenant_id,
controller=provider_controller,
subscription=subscription,
)
try:
TriggerManager.unsubscribe_trigger(
tenant_id=tenant_id,
@ -280,8 +383,8 @@ class TriggerProviderService:
except Exception as e:
logger.exception("Error unsubscribing trigger", exc_info=e)
# Clear cache
session.delete(subscription)
# Clear cache
delete_cache_for_subscription(
tenant_id=tenant_id,
provider_id=subscription.provider_id,
@ -688,3 +791,188 @@ class TriggerProviderService:
)
subscription.properties = dict(properties_encrypter.decrypt(subscription.properties))
return subscription
@classmethod
def verify_subscription_credentials(
cls,
tenant_id: str,
user_id: str,
provider_id: TriggerProviderID,
subscription_id: str,
credentials: Mapping[str, Any],
) -> dict[str, Any]:
"""
Verify credentials for an existing subscription without updating it.
This is used in edit mode to validate new credentials before rebuild.
:param tenant_id: Tenant ID
:param user_id: User ID
:param provider_id: Provider identifier
:param subscription_id: Subscription ID
:param credentials: New credentials to verify
:return: dict with 'verified' boolean
"""
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
if not provider_controller:
raise ValueError(f"Provider {provider_id} not found")
subscription = cls.get_subscription_by_id(
tenant_id=tenant_id,
subscription_id=subscription_id,
)
if not subscription:
raise ValueError(f"Subscription {subscription_id} not found")
credential_type = CredentialType.of(subscription.credential_type)
# For API Key, validate the new credentials
if credential_type == CredentialType.API_KEY:
new_credentials: dict[str, Any] = {
key: value if value != HIDDEN_VALUE else subscription.credentials.get(key, UNKNOWN_VALUE)
for key, value in credentials.items()
}
try:
provider_controller.validate_credentials(user_id, credentials=new_credentials)
return {"verified": True}
except Exception as e:
raise ValueError(f"Invalid credentials: {e}") from e
return {"verified": True}
@classmethod
def rebuild_trigger_subscription(
cls,
tenant_id: str,
provider_id: TriggerProviderID,
subscription_id: str,
credentials: Mapping[str, Any],
parameters: Mapping[str, Any],
name: str | None = None,
) -> None:
"""
Rebuild an existing subscription in place.
Unsubscribes the current subscription and re-subscribes with the merged credentials
and parameters, keeping the same subscription_id and endpoint_id so the webhook URL
remains unchanged.
:param tenant_id: Tenant ID
:param provider_id: Provider identifier
:param subscription_id: Subscription ID
:param credentials: Credentials for the subscription (HIDDEN_VALUE entries keep the stored values)
:param parameters: Parameters for the subscription
:param name: Optional new name for the subscription
:return: None
"""
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
if not provider_controller:
raise ValueError(f"Provider {provider_id} not found")
# Use distributed lock to prevent race conditions on the same subscription
lock_key = f"trigger_subscription_rebuild_lock:{tenant_id}_{subscription_id}"
with redis_client.lock(lock_key, timeout=20):
with Session(db.engine, expire_on_commit=False) as session:
try:
# Get subscription within the transaction
subscription: TriggerSubscription | None = (
session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
)
if not subscription:
raise ValueError(f"Subscription {subscription_id} not found")
credential_type = CredentialType.of(subscription.credential_type)
if credential_type not in [CredentialType.OAUTH2, CredentialType.API_KEY]:
raise ValueError("Credential type not supported for rebuild")
# Decrypt existing credentials for merging
credential_encrypter, _ = create_trigger_provider_encrypter_for_subscription(
tenant_id=tenant_id,
controller=provider_controller,
subscription=subscription,
)
decrypted_credentials = dict(credential_encrypter.decrypt(subscription.credentials))
# Merge credentials: if caller passed HIDDEN_VALUE, retain existing decrypted value
merged_credentials: dict[str, Any] = {
key: value if value != HIDDEN_VALUE else decrypted_credentials.get(key, UNKNOWN_VALUE)
for key, value in credentials.items()
}
user_id = subscription.user_id
# TODO: Try to invoke the update API of the plugin trigger provider.
# FALLBACK: If the update API is not implemented,
# delete the previous subscription and create a new one.
# Unsubscribe the previous subscription (external call, but we'll handle errors)
try:
TriggerManager.unsubscribe_trigger(
tenant_id=tenant_id,
user_id=user_id,
provider_id=provider_id,
subscription=subscription.to_entity(),
credentials=decrypted_credentials,
credential_type=credential_type,
)
except Exception as e:
logger.exception("Error unsubscribing trigger during rebuild", exc_info=e)
# Continue anyway - the subscription might already be deleted externally
# Create a new subscription with the same subscription_id and endpoint_id (external call)
new_subscription: TriggerSubscriptionEntity = TriggerManager.subscribe_trigger(
tenant_id=tenant_id,
user_id=user_id,
provider_id=provider_id,
endpoint=generate_plugin_trigger_endpoint_url(subscription.endpoint_id),
parameters=parameters,
credentials=merged_credentials,
credential_type=credential_type,
)
# Update the subscription in the same transaction
# Inline update logic to reuse the same session
if name is not None and name != subscription.name:
existing = (
session.query(TriggerSubscription)
.filter_by(tenant_id=tenant_id, provider_id=str(provider_id), name=name)
.first()
)
if existing and existing.id != subscription.id:
raise ValueError(f"Subscription name '{name}' already exists for this provider")
subscription.name = name
# Update parameters
subscription.parameters = dict(parameters)
# Update credentials with merged (and encrypted) values
subscription.credentials = dict(credential_encrypter.encrypt(merged_credentials))
# Update properties
if new_subscription.properties:
properties_encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=provider_controller.get_properties_schema(),
cache=NoOpProviderCredentialCache(),
)
subscription.properties = dict(properties_encrypter.encrypt(dict(new_subscription.properties)))
# Update expiration timestamp
if new_subscription.expires_at is not None:
subscription.expires_at = new_subscription.expires_at
# Commit the transaction
session.commit()
# Clear subscription cache
delete_cache_for_subscription(
tenant_id=tenant_id,
provider_id=subscription.provider_id,
subscription_id=subscription.id,
)
except Exception as e:
# Rollback on any error
session.rollback()
logger.exception("Failed to rebuild trigger subscription", exc_info=e)
raise
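The credential merge that both `update` and `rebuild_trigger_subscription` rely on can be read in isolation. A minimal sketch of the idea, with assumed sentinel values standing in for the real `constants.HIDDEN_VALUE` / `constants.UNKNOWN_VALUE`:

```python
from typing import Any, Mapping

HIDDEN_VALUE = "[__HIDDEN__]"    # assumed sentinel; the real one lives in `constants`
UNKNOWN_VALUE = "[__UNKNOWN__]"  # assumed fallback for keys with no stored value


def merge_credentials(incoming: Mapping[str, Any], existing: Mapping[str, Any]) -> dict[str, Any]:
    """Keep the stored value whenever the caller submits the HIDDEN_VALUE placeholder."""
    return {
        key: value if value != HIDDEN_VALUE else existing.get(key, UNKNOWN_VALUE)
        for key, value in incoming.items()
    }


# The masked api_key keeps its stored value; api_secret is replaced with the new one.
merged = merge_credentials(
    {"api_key": HIDDEN_VALUE, "api_secret": "new-secret"},
    {"api_key": "stored-key", "api_secret": "stored-secret"},
)
assert merged == {"api_key": "stored-key", "api_secret": "new-secret"}
```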

View File

@ -453,11 +453,12 @@ class TriggerSubscriptionBuilderService:
if not subscription_builder:
return None
# response to validation endpoint
controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
tenant_id=subscription_builder.tenant_id, provider_id=TriggerProviderID(subscription_builder.provider_id)
)
try:
# response to validation endpoint
controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
tenant_id=subscription_builder.tenant_id,
provider_id=TriggerProviderID(subscription_builder.provider_id),
)
dispatch_response: TriggerDispatchResponse = controller.dispatch(
request=request,
subscription=subscription_builder.to_subscription(),

View File

@ -863,10 +863,18 @@ class WebhookService:
not_found_in_cache.append(node_id)
continue
with Session(db.engine) as session:
try:
# lock the concurrent webhook trigger creation
redis_client.lock(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:apps:{app.id}:lock", timeout=10)
lock_key = f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:apps:{app.id}:lock"
lock = redis_client.lock(lock_key, timeout=10)
lock_acquired = False
try:
# acquire the lock with blocking and timeout
lock_acquired = lock.acquire(blocking=True, blocking_timeout=10)
if not lock_acquired:
logger.warning("Failed to acquire lock for webhook sync, app %s", app.id)
raise RuntimeError("Failed to acquire lock for webhook trigger synchronization")
with Session(db.engine) as session:
# fetch the non-cached nodes from DB
all_records = session.scalars(
select(WorkflowWebhookTrigger).where(
@ -903,11 +911,16 @@ class WebhookService:
session.delete(nodes_id_in_db[node_id])
redis_client.delete(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:{app.id}:{node_id}")
session.commit()
except Exception:
logger.exception("Failed to sync webhook relationships for app %s", app.id)
raise
finally:
redis_client.delete(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:apps:{app.id}:lock")
except Exception:
logger.exception("Failed to sync webhook relationships for app %s", app.id)
raise
finally:
# release the lock only if it was acquired
if lock_acquired:
try:
lock.release()
except Exception:
logger.exception("Failed to release lock for webhook sync, app %s", app.id)
@classmethod
def generate_webhook_id(cls) -> str:
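The webhook fix above replaces a lock object that was created but never acquired with an explicit blocking acquire and a guarded release. A standalone sketch of that pattern using the redis-py lock API (function and key names here are illustrative, not the service's):

```python
import logging

logger = logging.getLogger(__name__)


def run_with_app_lock(redis_client, app_id: str, work) -> None:
    """Acquire a per-app Redis lock, run `work()`, and release only if the lock was acquired."""
    lock = redis_client.lock(f"webhook:apps:{app_id}:lock", timeout=10)
    lock_acquired = False
    try:
        lock_acquired = lock.acquire(blocking=True, blocking_timeout=10)
        if not lock_acquired:
            raise RuntimeError(f"Failed to acquire lock for app {app_id}")
        work()
    finally:
        if lock_acquired:
            try:
                lock.release()
            except Exception:
                logger.exception("Failed to release lock for app %s", app_id)
```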

View File

@ -7,11 +7,14 @@ CODE_LANGUAGE = CodeLanguage.JINJA2
def test_jinja2():
"""Test basic Jinja2 template rendering."""
template = "Hello {{template}}"
# Template must be base64 encoded to match the new safe embedding approach
template_b64 = base64.b64encode(template.encode("utf-8")).decode("utf-8")
inputs = base64.b64encode(b'{"template": "World"}').decode("utf-8")
code = (
Jinja2TemplateTransformer.get_runner_script()
.replace(Jinja2TemplateTransformer._code_placeholder, template)
.replace(Jinja2TemplateTransformer._template_b64_placeholder, template_b64)
.replace(Jinja2TemplateTransformer._inputs_placeholder, inputs)
)
result = CodeExecutor.execute_code(
@ -21,6 +24,7 @@ def test_jinja2():
def test_jinja2_with_code_template():
"""Test template rendering via the high-level workflow API."""
result = CodeExecutor.execute_workflow_code_template(
language=CODE_LANGUAGE, code="Hello {{template}}", inputs={"template": "World"}
)
@ -28,7 +32,64 @@ def test_jinja2_with_code_template():
def test_jinja2_get_runner_script():
"""Test that runner script contains required placeholders."""
runner_script = Jinja2TemplateTransformer.get_runner_script()
assert runner_script.count(Jinja2TemplateTransformer._code_placeholder) == 1
assert runner_script.count(Jinja2TemplateTransformer._template_b64_placeholder) == 1
assert runner_script.count(Jinja2TemplateTransformer._inputs_placeholder) == 1
assert runner_script.count(Jinja2TemplateTransformer._result_tag) == 2
def test_jinja2_template_with_special_characters():
"""
Test that templates with special characters (quotes, newlines) render correctly.
This is a regression test for issue #26818 where textarea pre-fill values
containing special characters would break template rendering.
"""
# Template with triple quotes, single quotes, double quotes, and newlines
template = """<html>
<body>
<input value="{{ task.get('Task ID', '') }}"/>
<textarea>{{ task.get('Issues', 'No issues reported') }}</textarea>
<p>Status: "{{ status }}"</p>
<pre>'''code block'''</pre>
</body>
</html>"""
inputs = {"task": {"Task ID": "TASK-123", "Issues": "Line 1\nLine 2\nLine 3"}, "status": "completed"}
result = CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code=template, inputs=inputs)
# Verify the template rendered correctly with all special characters
output = result["result"]
assert 'value="TASK-123"' in output
assert "<textarea>Line 1\nLine 2\nLine 3</textarea>" in output
assert 'Status: "completed"' in output
assert "'''code block'''" in output
def test_jinja2_template_with_html_textarea_prefill():
"""
Specific test for HTML textarea with Jinja2 variable pre-fill.
Verifies fix for issue #26818.
"""
template = "<textarea name='notes'>{{ notes }}</textarea>"
notes_content = "This is a multi-line note.\nWith special chars: 'single' and \"double\" quotes."
inputs = {"notes": notes_content}
result = CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code=template, inputs=inputs)
expected_output = f"<textarea name='notes'>{notes_content}</textarea>"
assert result["result"] == expected_output
def test_jinja2_assemble_runner_script_encodes_template():
"""Test that assemble_runner_script properly base64 encodes the template."""
template = "Hello {{ name }}!"
inputs = {"name": "World"}
script = Jinja2TemplateTransformer.assemble_runner_script(template, inputs)
# The template should be base64 encoded in the script
template_b64 = base64.b64encode(template.encode("utf-8")).decode("utf-8")
assert template_b64 in script
# The raw template should NOT appear in the script (it's encoded)
assert "Hello {{ name }}!" not in script

View File

@ -0,0 +1,682 @@
from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from core.plugin.entities.plugin_daemon import CredentialType
from core.trigger.entities.entities import Subscription as TriggerSubscriptionEntity
from extensions.ext_database import db
from models.provider_ids import TriggerProviderID
from models.trigger import TriggerSubscription
from services.trigger.trigger_provider_service import TriggerProviderService
class TestTriggerProviderService:
"""Integration tests for TriggerProviderService using testcontainers."""
@pytest.fixture
def mock_external_service_dependencies(self):
"""Mock setup for external service dependencies."""
with (
patch("services.trigger.trigger_provider_service.TriggerManager") as mock_trigger_manager,
patch("services.trigger.trigger_provider_service.redis_client") as mock_redis_client,
patch("services.trigger.trigger_provider_service.delete_cache_for_subscription") as mock_delete_cache,
patch("services.account_service.FeatureService") as mock_account_feature_service,
):
# Setup default mock returns
mock_provider_controller = MagicMock()
mock_provider_controller.get_credential_schema_config.return_value = MagicMock()
mock_provider_controller.get_properties_schema.return_value = MagicMock()
mock_trigger_manager.get_trigger_provider.return_value = mock_provider_controller
# Mock redis lock
mock_lock = MagicMock()
mock_lock.__enter__ = MagicMock(return_value=None)
mock_lock.__exit__ = MagicMock(return_value=None)
mock_redis_client.lock.return_value = mock_lock
# Setup account feature service mock
mock_account_feature_service.get_system_features.return_value.is_allow_register = True
yield {
"trigger_manager": mock_trigger_manager,
"redis_client": mock_redis_client,
"delete_cache": mock_delete_cache,
"provider_controller": mock_provider_controller,
"account_feature_service": mock_account_feature_service,
}
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
"""
Helper method to create a test account and tenant for testing.
Args:
db_session_with_containers: Database session from testcontainers infrastructure
mock_external_service_dependencies: Mock dependencies
Returns:
tuple: (account, tenant) - Created account and tenant instances
"""
fake = Faker()
from services.account_service import AccountService, TenantService
# Setup mocks for account creation
mock_external_service_dependencies[
"account_feature_service"
].get_system_features.return_value.is_allow_register = True
mock_external_service_dependencies[
"trigger_manager"
].get_trigger_provider.return_value = mock_external_service_dependencies["provider_controller"]
# Create account and tenant
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
return account, tenant
def _create_test_subscription(
self,
db_session_with_containers,
tenant_id,
user_id,
provider_id,
credential_type,
credentials,
mock_external_service_dependencies,
):
"""
Helper method to create a test trigger subscription.
Args:
db_session_with_containers: Database session
tenant_id: Tenant ID
user_id: User ID
provider_id: Provider ID
credential_type: Credential type
credentials: Credentials dict
mock_external_service_dependencies: Mock dependencies
Returns:
TriggerSubscription: Created subscription instance
"""
fake = Faker()
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.helper.provider_encryption import create_provider_encrypter
# Use mock provider controller to encrypt credentials
provider_controller = mock_external_service_dependencies["provider_controller"]
# Create encrypter for credentials
credential_encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=provider_controller.get_credential_schema_config(credential_type),
cache=NoOpProviderCredentialCache(),
)
subscription = TriggerSubscription(
name=fake.word(),
tenant_id=tenant_id,
user_id=user_id,
provider_id=str(provider_id),
endpoint_id=fake.uuid4(),
parameters={"param1": "value1"},
properties={"prop1": "value1"},
credentials=dict(credential_encrypter.encrypt(credentials)),
credential_type=credential_type.value,
credential_expires_at=-1,
expires_at=-1,
)
db.session.add(subscription)
db.session.commit()
db.session.refresh(subscription)
return subscription
def test_rebuild_trigger_subscription_success_with_merged_credentials(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test successful rebuild with credential merging (HIDDEN_VALUE handling).
This test verifies:
- Credentials are properly merged (HIDDEN_VALUE replaced with existing values)
- Single transaction wraps all operations
- Merged credentials are used for subscribe and update
- Database state is correctly updated
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.API_KEY
# Create initial subscription with credentials
original_credentials = {"api_key": "original-secret-key", "api_secret": "original-secret"}
subscription = self._create_test_subscription(
db_session_with_containers,
tenant.id,
account.id,
provider_id,
credential_type,
original_credentials,
mock_external_service_dependencies,
)
# Prepare new credentials with HIDDEN_VALUE for api_key (should keep original)
# and new value for api_secret (should update)
new_credentials = {
"api_key": HIDDEN_VALUE, # Should be replaced with original
"api_secret": "new-secret-value", # Should be updated
}
# Mock subscribe_trigger to return a new subscription entity
new_subscription_entity = TriggerSubscriptionEntity(
endpoint=subscription.endpoint_id,
parameters={"param1": "value1"},
properties={"prop1": "new_prop_value"},
expires_at=1234567890,
)
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
# Mock unsubscribe_trigger
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
# Execute rebuild
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=tenant.id,
provider_id=provider_id,
subscription_id=subscription.id,
credentials=new_credentials,
parameters={"param1": "updated_value"},
name="updated_name",
)
# Verify unsubscribe was called with decrypted original credentials
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.assert_called_once()
unsubscribe_call_args = mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.call_args
assert unsubscribe_call_args.kwargs["tenant_id"] == tenant.id
assert unsubscribe_call_args.kwargs["provider_id"] == provider_id
assert unsubscribe_call_args.kwargs["credential_type"] == credential_type
# Verify subscribe was called with merged credentials (api_key from original, api_secret new)
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.assert_called_once()
subscribe_call_args = mock_external_service_dependencies["trigger_manager"].subscribe_trigger.call_args
subscribe_credentials = subscribe_call_args.kwargs["credentials"]
assert subscribe_credentials["api_key"] == original_credentials["api_key"] # Merged from original
assert subscribe_credentials["api_secret"] == "new-secret-value" # New value
# Verify database state was updated
db.session.refresh(subscription)
assert subscription.name == "updated_name"
assert subscription.parameters == {"param1": "updated_value"}
# Verify credentials in DB were updated with merged values (decrypt to check)
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.helper.provider_encryption import create_provider_encrypter
# Use mock provider controller to decrypt credentials
provider_controller = mock_external_service_dependencies["provider_controller"]
credential_encrypter, _ = create_provider_encrypter(
tenant_id=tenant.id,
config=provider_controller.get_credential_schema_config(credential_type),
cache=NoOpProviderCredentialCache(),
)
decrypted_db_credentials = dict(credential_encrypter.decrypt(subscription.credentials))
assert decrypted_db_credentials["api_key"] == original_credentials["api_key"]
assert decrypted_db_credentials["api_secret"] == "new-secret-value"
# Verify cache was cleared
mock_external_service_dependencies["delete_cache"].assert_called_once_with(
tenant_id=tenant.id,
provider_id=subscription.provider_id,
subscription_id=subscription.id,
)
def test_rebuild_trigger_subscription_with_all_new_credentials(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test rebuild when all credentials are new (no HIDDEN_VALUE).
This test verifies:
- All new credentials are used when no HIDDEN_VALUE is present
- Merged credentials contain only new values
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.API_KEY
# Create initial subscription
original_credentials = {"api_key": "original-key", "api_secret": "original-secret"}
subscription = self._create_test_subscription(
db_session_with_containers,
tenant.id,
account.id,
provider_id,
credential_type,
original_credentials,
mock_external_service_dependencies,
)
# All new credentials (no HIDDEN_VALUE)
new_credentials = {
"api_key": "completely-new-key",
"api_secret": "completely-new-secret",
}
new_subscription_entity = TriggerSubscriptionEntity(
endpoint=subscription.endpoint_id,
parameters={},
properties={},
expires_at=-1,
)
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
# Execute rebuild
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=tenant.id,
provider_id=provider_id,
subscription_id=subscription.id,
credentials=new_credentials,
parameters={},
)
# Verify subscribe was called with all new credentials
subscribe_call_args = mock_external_service_dependencies["trigger_manager"].subscribe_trigger.call_args
subscribe_credentials = subscribe_call_args.kwargs["credentials"]
assert subscribe_credentials["api_key"] == "completely-new-key"
assert subscribe_credentials["api_secret"] == "completely-new-secret"
def test_rebuild_trigger_subscription_with_all_hidden_values(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test rebuild when all credentials are HIDDEN_VALUE (preserve all existing).
This test verifies:
- All HIDDEN_VALUE credentials are replaced with existing values
- Original credentials are preserved
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.API_KEY
original_credentials = {"api_key": "original-key", "api_secret": "original-secret"}
subscription = self._create_test_subscription(
db_session_with_containers,
tenant.id,
account.id,
provider_id,
credential_type,
original_credentials,
mock_external_service_dependencies,
)
# All HIDDEN_VALUE (should preserve all original)
new_credentials = {
"api_key": HIDDEN_VALUE,
"api_secret": HIDDEN_VALUE,
}
new_subscription_entity = TriggerSubscriptionEntity(
endpoint=subscription.endpoint_id,
parameters={},
properties={},
expires_at=-1,
)
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
# Execute rebuild
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=tenant.id,
provider_id=provider_id,
subscription_id=subscription.id,
credentials=new_credentials,
parameters={},
)
# Verify subscribe was called with all original credentials
subscribe_call_args = mock_external_service_dependencies["trigger_manager"].subscribe_trigger.call_args
subscribe_credentials = subscribe_call_args.kwargs["credentials"]
assert subscribe_credentials["api_key"] == original_credentials["api_key"]
assert subscribe_credentials["api_secret"] == original_credentials["api_secret"]
def test_rebuild_trigger_subscription_with_missing_key_uses_unknown_value(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test rebuild when HIDDEN_VALUE is used for a key that doesn't exist in original.
This test verifies:
- UNKNOWN_VALUE is used when HIDDEN_VALUE key doesn't exist in original credentials
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.API_KEY
# Original has only api_key
original_credentials = {"api_key": "original-key"}
subscription = self._create_test_subscription(
db_session_with_containers,
tenant.id,
account.id,
provider_id,
credential_type,
original_credentials,
mock_external_service_dependencies,
)
# HIDDEN_VALUE for non-existent key should use UNKNOWN_VALUE
new_credentials = {
"api_key": HIDDEN_VALUE,
"non_existent_key": HIDDEN_VALUE, # This key doesn't exist in original
}
new_subscription_entity = TriggerSubscriptionEntity(
endpoint=subscription.endpoint_id,
parameters={},
properties={},
expires_at=-1,
)
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
# Execute rebuild
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=tenant.id,
provider_id=provider_id,
subscription_id=subscription.id,
credentials=new_credentials,
parameters={},
)
# Verify subscribe was called with original api_key and UNKNOWN_VALUE for missing key
subscribe_call_args = mock_external_service_dependencies["trigger_manager"].subscribe_trigger.call_args
subscribe_credentials = subscribe_call_args.kwargs["credentials"]
assert subscribe_credentials["api_key"] == original_credentials["api_key"]
assert subscribe_credentials["non_existent_key"] == UNKNOWN_VALUE
def test_rebuild_trigger_subscription_rollback_on_error(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test that transaction is rolled back on error.
This test verifies:
- Database transaction is rolled back when an error occurs
- Original subscription state is preserved
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.API_KEY
original_credentials = {"api_key": "original-key"}
subscription = self._create_test_subscription(
db_session_with_containers,
tenant.id,
account.id,
provider_id,
credential_type,
original_credentials,
mock_external_service_dependencies,
)
original_name = subscription.name
original_parameters = subscription.parameters.copy()
# Make subscribe_trigger raise an error
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.side_effect = ValueError(
"Subscribe failed"
)
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
# Execute rebuild and expect error
with pytest.raises(ValueError, match="Subscribe failed"):
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=tenant.id,
provider_id=provider_id,
subscription_id=subscription.id,
credentials={"api_key": "new-key"},
parameters={},
)
# Verify subscription state was not changed (rolled back)
db.session.refresh(subscription)
assert subscription.name == original_name
assert subscription.parameters == original_parameters
def test_rebuild_trigger_subscription_unsubscribe_error_continues(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test that unsubscribe errors are handled gracefully and operation continues.
This test verifies:
- Unsubscribe errors are caught and logged but don't stop the rebuild
- Rebuild continues even if unsubscribe fails
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.API_KEY
original_credentials = {"api_key": "original-key"}
subscription = self._create_test_subscription(
db_session_with_containers,
tenant.id,
account.id,
provider_id,
credential_type,
original_credentials,
mock_external_service_dependencies,
)
# Make unsubscribe_trigger raise an error (should be caught and continue)
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.side_effect = ValueError(
"Unsubscribe failed"
)
new_subscription_entity = TriggerSubscriptionEntity(
endpoint=subscription.endpoint_id,
parameters={},
properties={},
expires_at=-1,
)
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
# Execute rebuild - should succeed despite unsubscribe error
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=tenant.id,
provider_id=provider_id,
subscription_id=subscription.id,
credentials={"api_key": "new-key"},
parameters={},
)
# Verify subscribe was still called (operation continued)
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.assert_called_once()
# Verify subscription was updated
db.session.refresh(subscription)
assert subscription.parameters == {}
def test_rebuild_trigger_subscription_subscription_not_found(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test error when subscription is not found.
This test verifies:
- Proper error is raised when subscription doesn't exist
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
fake_subscription_id = fake.uuid4()
with pytest.raises(ValueError, match="not found"):
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=tenant.id,
provider_id=provider_id,
subscription_id=fake_subscription_id,
credentials={},
parameters={},
)
def test_rebuild_trigger_subscription_provider_not_found(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test error when provider is not found.
This test verifies:
- Proper error is raised when provider doesn't exist
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_id = TriggerProviderID("non_existent_org/non_existent_plugin/non_existent_provider")
# Make get_trigger_provider return None
mock_external_service_dependencies["trigger_manager"].get_trigger_provider.return_value = None
with pytest.raises(ValueError, match="Provider.*not found"):
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=tenant.id,
provider_id=provider_id,
subscription_id=fake.uuid4(),
credentials={},
parameters={},
)
def test_rebuild_trigger_subscription_unsupported_credential_type(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test error when credential type is not supported for rebuild.
This test verifies:
- Proper error is raised for unsupported credential types (not OAUTH2 or API_KEY)
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.UNAUTHORIZED # Not supported
subscription = self._create_test_subscription(
db_session_with_containers,
tenant.id,
account.id,
provider_id,
credential_type,
{},
mock_external_service_dependencies,
)
with pytest.raises(ValueError, match="Credential type not supported for rebuild"):
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=tenant.id,
provider_id=provider_id,
subscription_id=subscription.id,
credentials={},
parameters={},
)
def test_rebuild_trigger_subscription_name_uniqueness_check(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test that name uniqueness is checked when updating name.
This test verifies:
- Error is raised when new name conflicts with existing subscription
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.API_KEY
# Create first subscription
subscription1 = self._create_test_subscription(
db_session_with_containers,
tenant.id,
account.id,
provider_id,
credential_type,
{"api_key": "key1"},
mock_external_service_dependencies,
)
# Create second subscription with different name
subscription2 = self._create_test_subscription(
db_session_with_containers,
tenant.id,
account.id,
provider_id,
credential_type,
{"api_key": "key2"},
mock_external_service_dependencies,
)
new_subscription_entity = TriggerSubscriptionEntity(
endpoint=subscription2.endpoint_id,
parameters={},
properties={},
expires_at=-1,
)
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
# Try to rename subscription2 to subscription1's name (should fail)
with pytest.raises(ValueError, match="already exists"):
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=tenant.id,
provider_id=provider_id,
subscription_id=subscription2.id,
credentials={"api_key": "new-key"},
parameters={},
name=subscription1.name, # Conflicting name
)

View File

@ -705,3 +705,207 @@ class TestWorkflowToolManageService:
db.session.refresh(created_tool)
assert created_tool.name == first_tool_name
assert created_tool.updated_at is not None
def test_create_workflow_tool_with_file_parameter_default(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test workflow tool creation with FILE parameter having a file object as default.
This test verifies:
- FILE parameters can have file object defaults
- The default value (dict with id/base64Url) is properly handled
- Tool creation succeeds without Pydantic validation errors
Related issue: Array[File] default value causes Pydantic validation errors.
"""
fake = Faker()
# Create test data
app, account, workflow = self._create_test_app_and_account(
db_session_with_containers, mock_external_service_dependencies
)
# Create workflow graph with a FILE variable that has a default value
workflow_graph = {
"nodes": [
{
"id": "start_node",
"data": {
"type": "start",
"variables": [
{
"variable": "document",
"label": "Document",
"type": "file",
"required": False,
"default": {"id": fake.uuid4(), "base64Url": ""},
}
],
},
}
]
}
workflow.graph = json.dumps(workflow_graph)
# Setup workflow tool parameters with FILE type
file_parameters = [
{
"name": "document",
"description": "Upload a document",
"form": "form",
"type": "file",
"required": False,
}
]
# Execute the method under test
# Note: from_db is mocked, so this test primarily validates the parameter configuration
result = WorkflowToolManageService.create_workflow_tool(
user_id=account.id,
tenant_id=account.current_tenant.id,
workflow_app_id=app.id,
name=fake.word(),
label=fake.word(),
icon={"type": "emoji", "emoji": "📄"},
description=fake.text(max_nb_chars=200),
parameters=file_parameters,
)
# Verify the result
assert result == {"result": "success"}
def test_create_workflow_tool_with_files_parameter_default(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test workflow tool creation with FILES (Array[File]) parameter having file objects as default.
This test verifies:
- FILES parameters can have a list of file objects as default
- The default value (list of dicts with id/base64Url) is properly handled
- Tool creation succeeds without Pydantic validation errors
Related issue: Array[File] default value causes 4 Pydantic validation errors
because PluginParameter.default only accepts Union[float, int, str, bool] | None.
"""
fake = Faker()
# Create test data
app, account, workflow = self._create_test_app_and_account(
db_session_with_containers, mock_external_service_dependencies
)
# Create workflow graph with a FILE_LIST variable that has a default value
workflow_graph = {
"nodes": [
{
"id": "start_node",
"data": {
"type": "start",
"variables": [
{
"variable": "documents",
"label": "Documents",
"type": "file-list",
"required": False,
"default": [
{"id": fake.uuid4(), "base64Url": ""},
{"id": fake.uuid4(), "base64Url": ""},
],
}
],
},
}
]
}
workflow.graph = json.dumps(workflow_graph)
# Setup workflow tool parameters with FILES type
files_parameters = [
{
"name": "documents",
"description": "Upload multiple documents",
"form": "form",
"type": "files",
"required": False,
}
]
# Execute the method under test
# Note: from_db is mocked, so this test primarily validates the parameter configuration
result = WorkflowToolManageService.create_workflow_tool(
user_id=account.id,
tenant_id=account.current_tenant.id,
workflow_app_id=app.id,
name=fake.word(),
label=fake.word(),
icon={"type": "emoji", "emoji": "📁"},
description=fake.text(max_nb_chars=200),
parameters=files_parameters,
)
# Verify the result
assert result == {"result": "success"}
def test_create_workflow_tool_db_commit_before_validation(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test that database commit happens before validation, causing DB pollution on validation failure.
This test verifies the second bug:
- WorkflowToolProvider is committed to database BEFORE from_db validation
- If validation fails, the record remains in the database
- Subsequent attempts fail with "Tool already exists" error
This demonstrates why we need to validate BEFORE database commit.
"""
fake = Faker()
# Create test data
app, account, workflow = self._create_test_app_and_account(
db_session_with_containers, mock_external_service_dependencies
)
tool_name = fake.word()
# Mock from_db to raise validation error
mock_external_service_dependencies["workflow_tool_provider_controller"].from_db.side_effect = ValueError(
"Validation failed: default parameter type mismatch"
)
# Attempt to create workflow tool (will fail at validation stage)
with pytest.raises(ValueError) as exc_info:
WorkflowToolManageService.create_workflow_tool(
user_id=account.id,
tenant_id=account.current_tenant.id,
workflow_app_id=app.id,
name=tool_name,
label=fake.word(),
icon={"type": "emoji", "emoji": "🔧"},
description=fake.text(max_nb_chars=200),
parameters=self._create_test_workflow_tool_parameters(),
)
assert "Validation failed" in str(exc_info.value)
# Check whether the tool was left behind in the database despite the validation failure
from extensions.ext_database import db
tool_count = (
db.session.query(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == account.current_tenant.id,
WorkflowToolProvider.name == tool_name,
)
.count()
)
# Expected behavior after the fix: the record does not exist because the
# transaction is rolled back before anything is committed. With the current bug
# the record may still be present, so the assertion stays commented out and this
# test only documents the problem.
# assert tool_count == 0  # Enable once validation happens before the commit
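A hedged sketch of the ordering this test argues for, validate before any database write, using hypothetical names rather than the actual service code:

```python
def create_tool(session, provider_row, validate) -> None:
    """Validate the provider configuration before persisting anything."""
    validate(provider_row)     # raises on invalid parameters, so nothing is written
    session.add(provider_row)
    session.commit()           # reached only for configurations that passed validation
```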

View File

@ -12,10 +12,12 @@ class TestJinja2CodeExecutor(CodeExecutorTestMixin):
_, Jinja2TemplateTransformer = self.jinja2_imports
template = "Hello {{template}}"
# Template must be base64 encoded to match the new safe embedding approach
template_b64 = base64.b64encode(template.encode("utf-8")).decode("utf-8")
inputs = base64.b64encode(b'{"template": "World"}').decode("utf-8")
code = (
Jinja2TemplateTransformer.get_runner_script()
.replace(Jinja2TemplateTransformer._code_placeholder, template)
.replace(Jinja2TemplateTransformer._template_b64_placeholder, template_b64)
.replace(Jinja2TemplateTransformer._inputs_placeholder, inputs)
)
result = CodeExecutor.execute_code(
@ -37,6 +39,34 @@ class TestJinja2CodeExecutor(CodeExecutorTestMixin):
_, Jinja2TemplateTransformer = self.jinja2_imports
runner_script = Jinja2TemplateTransformer.get_runner_script()
assert runner_script.count(Jinja2TemplateTransformer._code_placeholder) == 1
assert runner_script.count(Jinja2TemplateTransformer._template_b64_placeholder) == 1
assert runner_script.count(Jinja2TemplateTransformer._inputs_placeholder) == 1
assert runner_script.count(Jinja2TemplateTransformer._result_tag) == 2
def test_jinja2_template_with_special_characters(self, flask_app_with_containers):
"""
Test that templates with special characters (quotes, newlines) render correctly.
This is a regression test for issue #26818 where textarea pre-fill values
containing special characters would break template rendering.
"""
CodeExecutor, CodeLanguage = self.code_executor_imports
# Template with triple quotes, single quotes, double quotes, and newlines
template = """<html>
<body>
<input value="{{ task.get('Task ID', '') }}"/>
<textarea>{{ task.get('Issues', 'No issues reported') }}</textarea>
<p>Status: "{{ status }}"</p>
<pre>'''code block'''</pre>
</body>
</html>"""
inputs = {"task": {"Task ID": "TASK-123", "Issues": "Line 1\nLine 2\nLine 3"}, "status": "completed"}
result = CodeExecutor.execute_workflow_code_template(language=CodeLanguage.JINJA2, code=template, inputs=inputs)
# Verify the template rendered correctly with all special characters
output = result["result"]
assert 'value="TASK-123"' in output
assert "<textarea>Line 1\nLine 2\nLine 3</textarea>" in output
assert 'Status: "completed"' in output
assert "'''code block'''" in output

View File

@ -0,0 +1,46 @@
from flask import Response
from controllers.common.file_response import enforce_download_for_html, is_html_content
class TestFileResponseHelpers:
def test_is_html_content_detects_mime_type(self):
mime_type = "text/html; charset=UTF-8"
result = is_html_content(mime_type, filename="file.txt", extension="txt")
assert result is True
def test_is_html_content_detects_extension(self):
result = is_html_content("text/plain", filename="report.html", extension=None)
assert result is True
def test_enforce_download_for_html_sets_headers(self):
response = Response("payload", mimetype="text/html")
updated = enforce_download_for_html(
response,
mime_type="text/html",
filename="unsafe.html",
extension="html",
)
assert updated is True
assert "attachment" in response.headers["Content-Disposition"]
assert response.headers["Content-Type"] == "application/octet-stream"
assert response.headers["X-Content-Type-Options"] == "nosniff"
def test_enforce_download_for_html_no_change_for_non_html(self):
response = Response("payload", mimetype="text/plain")
updated = enforce_download_for_html(
response,
mime_type="text/plain",
filename="notes.txt",
extension="txt",
)
assert updated is False
assert "Content-Disposition" not in response.headers
assert "X-Content-Type-Options" not in response.headers

View File

@ -0,0 +1,145 @@
"""Unit tests for load balancing credential validation APIs."""
from __future__ import annotations
import builtins
import importlib
import sys
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from flask import Flask
from flask.views import MethodView
from werkzeug.exceptions import Forbidden
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
if not hasattr(builtins, "MethodView"):
builtins.MethodView = MethodView # type: ignore[attr-defined]
from models.account import TenantAccountRole
@pytest.fixture
def app() -> Flask:
app = Flask(__name__)
app.config["TESTING"] = True
return app
@pytest.fixture
def load_balancing_module(monkeypatch: pytest.MonkeyPatch):
"""Reload controller module with lightweight decorators for testing."""
from controllers.console import console_ns, wraps
from libs import login
def _noop(func):
return func
monkeypatch.setattr(login, "login_required", _noop)
monkeypatch.setattr(wraps, "setup_required", _noop)
monkeypatch.setattr(wraps, "account_initialization_required", _noop)
def _noop_route(*args, **kwargs): # type: ignore[override]
def _decorator(cls):
return cls
return _decorator
monkeypatch.setattr(console_ns, "route", _noop_route)
module_name = "controllers.console.workspace.load_balancing_config"
sys.modules.pop(module_name, None)
module = importlib.import_module(module_name)
return module
def _mock_user(role: TenantAccountRole) -> SimpleNamespace:
return SimpleNamespace(current_role=role)
def _prepare_context(module, monkeypatch: pytest.MonkeyPatch, role=TenantAccountRole.OWNER):
user = _mock_user(role)
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "tenant-123"))
mock_service = MagicMock()
monkeypatch.setattr(module, "ModelLoadBalancingService", lambda: mock_service)
return mock_service
def _request_payload():
return {"model": "gpt-4o", "model_type": ModelType.LLM, "credentials": {"api_key": "sk-***"}}
def test_validate_credentials_success(app: Flask, load_balancing_module, monkeypatch: pytest.MonkeyPatch):
service = _prepare_context(load_balancing_module, monkeypatch)
with app.test_request_context(
"/workspaces/current/model-providers/openai/models/load-balancing-configs/credentials-validate",
method="POST",
json=_request_payload(),
):
response = load_balancing_module.LoadBalancingCredentialsValidateApi().post(provider="openai")
assert response == {"result": "success"}
service.validate_load_balancing_credentials.assert_called_once_with(
tenant_id="tenant-123",
provider="openai",
model="gpt-4o",
model_type=ModelType.LLM,
credentials={"api_key": "sk-***"},
)
def test_validate_credentials_returns_error_message(app: Flask, load_balancing_module, monkeypatch: pytest.MonkeyPatch):
service = _prepare_context(load_balancing_module, monkeypatch)
service.validate_load_balancing_credentials.side_effect = CredentialsValidateFailedError("invalid credentials")
with app.test_request_context(
"/workspaces/current/model-providers/openai/models/load-balancing-configs/credentials-validate",
method="POST",
json=_request_payload(),
):
response = load_balancing_module.LoadBalancingCredentialsValidateApi().post(provider="openai")
assert response == {"result": "error", "error": "invalid credentials"}
def test_validate_credentials_requires_privileged_role(
app: Flask, load_balancing_module, monkeypatch: pytest.MonkeyPatch
):
_prepare_context(load_balancing_module, monkeypatch, role=TenantAccountRole.NORMAL)
with app.test_request_context(
"/workspaces/current/model-providers/openai/models/load-balancing-configs/credentials-validate",
method="POST",
json=_request_payload(),
):
api = load_balancing_module.LoadBalancingCredentialsValidateApi()
with pytest.raises(Forbidden):
api.post(provider="openai")
def test_validate_credentials_with_config_id(app: Flask, load_balancing_module, monkeypatch: pytest.MonkeyPatch):
service = _prepare_context(load_balancing_module, monkeypatch)
with app.test_request_context(
"/workspaces/current/model-providers/openai/models/load-balancing-configs/cfg-1/credentials-validate",
method="POST",
json=_request_payload(),
):
response = load_balancing_module.LoadBalancingConfigCredentialsValidateApi().post(
provider="openai", config_id="cfg-1"
)
assert response == {"result": "success"}
service.validate_load_balancing_credentials.assert_called_once_with(
tenant_id="tenant-123",
provider="openai",
model="gpt-4o",
model_type=ModelType.LLM,
credentials={"api_key": "sk-***"},
config_id="cfg-1",
)

View File

@ -0,0 +1,103 @@
import json
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from flask_restx import Api
from controllers.console.workspace.tool_providers import ToolProviderMCPApi
from core.db.session_factory import configure_session_factory
from extensions.ext_database import db
from services.tools.mcp_tools_manage_service import ReconnectResult
# Backward-compat fixtures referenced by @pytest.mark.usefixtures in this file.
# They are intentionally no-ops because the test already patches the required
# behaviors explicitly via @patch and context managers below.
@pytest.fixture
def _mock_cache():
return
@pytest.fixture
def _mock_user_tenant():
return
@pytest.fixture
def client():
app = Flask(__name__)
app.config["TESTING"] = True
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"
api = Api(app)
api.add_resource(ToolProviderMCPApi, "/console/api/workspaces/current/tool-provider/mcp")
db.init_app(app)
# Configure session factory used by controller code
with app.app_context():
configure_session_factory(db.engine)
return app.test_client()
@patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant", return_value=(MagicMock(id="u1"), "t1")
)
@patch("controllers.console.workspace.tool_providers.ToolProviderListCache.invalidate_cache", return_value=None)
@patch("controllers.console.workspace.tool_providers.Session")
@patch("controllers.console.workspace.tool_providers.MCPToolManageService._reconnect_with_url")
@pytest.mark.usefixtures("_mock_cache", "_mock_user_tenant")
def test_create_mcp_provider_populates_tools(
mock_reconnect, mock_session, mock_invalidate_cache, mock_current_account_with_tenant, client
):
# Arrange: reconnect returns tools immediately
mock_reconnect.return_value = ReconnectResult(
authed=True,
tools=json.dumps(
[{"name": "ping", "description": "ok", "inputSchema": {"type": "object"}, "outputSchema": {}}]
),
encrypted_credentials="{}",
)
# Fake service.create_provider -> returns object with id for reload
svc = MagicMock()
create_result = MagicMock()
create_result.id = "provider-1"
svc.create_provider.return_value = create_result
svc.get_provider.return_value = MagicMock(id="provider-1", tenant_id="t1") # used by reload path
mock_session.return_value.__enter__.return_value = MagicMock()
# Patch MCPToolManageService constructed inside controller
with patch("controllers.console.workspace.tool_providers.MCPToolManageService", return_value=svc):
payload = {
"server_url": "http://example.com/mcp",
"name": "demo",
"icon": "😀",
"icon_type": "emoji",
"icon_background": "#000",
"server_identifier": "demo-sid",
"configuration": {"timeout": 5, "sse_read_timeout": 30},
"headers": {},
"authentication": {},
}
# Act
with (
patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"), # bypass setup_required DB check
patch("controllers.console.wraps.current_account_with_tenant", return_value=(MagicMock(id="u1"), "t1")),
patch("libs.login.check_csrf_token", return_value=None), # bypass CSRF in login_required
patch("libs.login._get_user", return_value=MagicMock(id="u1", is_authenticated=True)), # login
patch(
"services.tools.tools_transform_service.ToolTransformService.mcp_provider_to_user_provider",
return_value={"id": "provider-1", "tools": [{"name": "ping"}]},
),
):
resp = client.post(
"/console/api/workspaces/current/tool-provider/mcp",
data=json.dumps(payload),
content_type="application/json",
)
# Assert
assert resp.status_code == 200
body = resp.get_json()
assert body.get("id") == "provider-1"
# If the transformed response contains a tools field, make sure it is non-empty
assert isinstance(body.get("tools"), list)
assert body["tools"]

View File

@ -41,6 +41,7 @@ class TestFilePreviewApi:
upload_file = Mock(spec=UploadFile)
upload_file.id = str(uuid.uuid4())
upload_file.name = "test_file.jpg"
upload_file.extension = "jpg"
upload_file.mime_type = "image/jpeg"
upload_file.size = 1024
upload_file.key = "storage/key/test_file.jpg"
@ -210,6 +211,19 @@ class TestFilePreviewApi:
assert mock_upload_file.name in response.headers["Content-Disposition"]
assert response.headers["Content-Type"] == "application/octet-stream"
def test_build_file_response_html_forces_attachment(self, file_preview_api, mock_upload_file):
"""Test HTML files are forced to download"""
mock_generator = Mock()
mock_upload_file.mime_type = "text/html"
mock_upload_file.name = "unsafe.html"
mock_upload_file.extension = "html"
response = file_preview_api._build_file_response(mock_generator, mock_upload_file, False)
assert "attachment" in response.headers["Content-Disposition"]
assert response.headers["Content-Type"] == "application/octet-stream"
assert response.headers["X-Content-Type-Options"] == "nosniff"
def test_build_file_response_audio_video(self, file_preview_api, mock_upload_file):
"""Test file response building for audio/video files"""
mock_generator = Mock()

View File

@ -1,11 +1,9 @@
import secrets
from unittest.mock import MagicMock, patch
import pytest
from core.helper.ssrf_proxy import (
SSRF_DEFAULT_MAX_RETRIES,
STATUS_FORCELIST,
_get_user_provided_host_header,
make_request,
)
@ -14,11 +12,10 @@ from core.helper.ssrf_proxy import (
@patch("core.helper.ssrf_proxy._get_ssrf_client")
def test_successful_request(mock_get_client):
mock_client = MagicMock()
mock_request = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 200
mock_client.send.return_value = mock_response
mock_client.build_request.return_value = mock_request
mock_client.request.return_value = mock_response
mock_get_client.return_value = mock_client
response = make_request("GET", "http://example.com")
@ -28,11 +25,10 @@ def test_successful_request(mock_get_client):
@patch("core.helper.ssrf_proxy._get_ssrf_client")
def test_retry_exceed_max_retries(mock_get_client):
mock_client = MagicMock()
mock_request = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 500
mock_client.send.return_value = mock_response
mock_client.build_request.return_value = mock_request
mock_client.request.return_value = mock_response
mock_get_client.return_value = mock_client
with pytest.raises(Exception) as e:
@ -40,32 +36,6 @@ def test_retry_exceed_max_retries(mock_get_client):
assert str(e.value) == f"Reached maximum retries ({SSRF_DEFAULT_MAX_RETRIES - 1}) for URL http://example.com"
@patch("core.helper.ssrf_proxy._get_ssrf_client")
def test_retry_logic_success(mock_get_client):
mock_client = MagicMock()
mock_request = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 200
side_effects = []
for _ in range(SSRF_DEFAULT_MAX_RETRIES):
status_code = secrets.choice(STATUS_FORCELIST)
retry_response = MagicMock()
retry_response.status_code = status_code
side_effects.append(retry_response)
side_effects.append(mock_response)
mock_client.send.side_effect = side_effects
mock_client.build_request.return_value = mock_request
mock_get_client.return_value = mock_client
response = make_request("GET", "http://example.com", max_retries=SSRF_DEFAULT_MAX_RETRIES)
assert response.status_code == 200
assert mock_client.send.call_count == SSRF_DEFAULT_MAX_RETRIES + 1
assert mock_client.build_request.call_count == SSRF_DEFAULT_MAX_RETRIES + 1
class TestGetUserProvidedHostHeader:
"""Tests for _get_user_provided_host_header function."""
@ -111,14 +81,12 @@ def test_host_header_preservation_without_user_header(mock_get_client):
mock_response = MagicMock()
mock_response.status_code = 200
mock_client.send.return_value = mock_response
mock_client.build_request.return_value = mock_request
mock_client.request.return_value = mock_response
mock_get_client.return_value = mock_client
response = make_request("GET", "http://example.com")
assert response.status_code == 200
# build_request should be called without headers dict containing Host
mock_client.build_request.assert_called_once()
# Host should not be set if not provided by user
assert "Host" not in mock_request.headers or mock_request.headers.get("Host") is None
@ -132,31 +100,10 @@ def test_host_header_preservation_with_user_header(mock_get_client):
mock_response = MagicMock()
mock_response.status_code = 200
mock_client.send.return_value = mock_response
mock_client.build_request.return_value = mock_request
mock_client.request.return_value = mock_response
mock_get_client.return_value = mock_client
custom_host = "custom.example.com:8080"
response = make_request("GET", "http://example.com", headers={"Host": custom_host})
assert response.status_code == 200
# Verify build_request was called
mock_client.build_request.assert_called_once()
# Verify the Host header was set on the request object
assert mock_request.headers.get("Host") == custom_host
mock_client.send.assert_called_once_with(mock_request)
@patch("core.helper.ssrf_proxy._get_ssrf_client")
@pytest.mark.parametrize("host_key", ["host", "HOST"])
def test_host_header_preservation_case_insensitive(mock_get_client, host_key):
"""Test that Host header is preserved regardless of case."""
mock_client = MagicMock()
mock_request = MagicMock()
mock_request.headers = {}
mock_response = MagicMock()
mock_response.status_code = 200
mock_client.send.return_value = mock_response
mock_client.build_request.return_value = mock_request
mock_get_client.return_value = mock_client
response = make_request("GET", "http://example.com", headers={host_key: "api.example.com"})
assert mock_request.headers.get("Host") == "api.example.com"
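The Host-header cases above only assert observable behavior. For orientation, a minimal sketch of a case-insensitive lookup in the spirit of `_get_user_provided_host_header` is shown below; the behavior is inferred from the parametrized test, not copied from the real implementation.

```python
def _get_user_provided_host_header_sketch(headers: dict[str, str] | None) -> str | None:
    """Case-insensitive 'Host' lookup, as implied by the parametrized test above."""
    if not headers:
        return None
    for key, value in headers.items():
        if key.lower() == "host":
            return value
    return None


# Example: both spellings resolve to the same header value.
assert _get_user_provided_host_header_sketch({"HOST": "api.example.com"}) == "api.example.com"
assert _get_user_provided_host_header_sketch(None) is None
```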

View File

@ -96,9 +96,6 @@ class TestToolProviderListCache:
ToolProviderListCache.invalidate_cache(tenant_id)
mock_redis_client.scan_iter.assert_called_once_with(f"tool_providers:tenant_id:{tenant_id}:*")
mock_redis_client.delete.assert_called_once_with(*mock_keys)
def test_invalidate_cache_no_keys(self, mock_redis_client):
"""Test invalidate cache - no cache keys for tenant"""
tenant_id = "tenant_123"

View File

@ -0,0 +1,327 @@
import unittest
from unittest.mock import MagicMock, patch
import pytest
from core.rag.datasource.vdb.pgvector.pgvector import (
PGVector,
PGVectorConfig,
)
class TestPGVector(unittest.TestCase):
def setUp(self):
self.config = PGVectorConfig(
host="localhost",
port=5432,
user="test_user",
password="test_password",
database="test_db",
min_connection=1,
max_connection=5,
pg_bigm=False,
)
self.collection_name = "test_collection"
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
def test_init(self, mock_pool_class):
"""Test PGVector initialization."""
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
pgvector = PGVector(self.collection_name, self.config)
assert pgvector._collection_name == self.collection_name
assert pgvector.table_name == f"embedding_{self.collection_name}"
assert pgvector.get_type() == "pgvector"
assert pgvector.pool is not None
assert pgvector.pg_bigm is False
assert pgvector.index_hash is not None
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
def test_init_with_pg_bigm(self, mock_pool_class):
"""Test PGVector initialization with pg_bigm enabled."""
config = PGVectorConfig(
host="localhost",
port=5432,
user="test_user",
password="test_password",
database="test_db",
min_connection=1,
max_connection=5,
pg_bigm=True,
)
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
pgvector = PGVector(self.collection_name, config)
assert pgvector.pg_bigm is True
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
def test_create_collection_basic(self, mock_redis, mock_pool_class):
"""Test basic collection creation."""
# Mock Redis operations
mock_lock = MagicMock()
mock_lock.__enter__ = MagicMock()
mock_lock.__exit__ = MagicMock()
mock_redis.lock.return_value = mock_lock
mock_redis.get.return_value = None
mock_redis.set.return_value = None
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
# Mock connection and cursor
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.getconn.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.return_value = [1] # vector extension exists
pgvector = PGVector(self.collection_name, self.config)
pgvector._create_collection(1536)
# Verify SQL execution calls
assert mock_cursor.execute.called
# Check that CREATE TABLE was called with correct dimension
create_table_calls = [call for call in mock_cursor.execute.call_args_list if "CREATE TABLE" in str(call)]
assert len(create_table_calls) == 1
assert "vector(1536)" in create_table_calls[0][0][0]
# Check that CREATE INDEX was called (dimension <= 2000)
create_index_calls = [
call for call in mock_cursor.execute.call_args_list if "CREATE INDEX" in str(call) and "hnsw" in str(call)
]
assert len(create_index_calls) == 1
# Verify Redis cache was set
mock_redis.set.assert_called_once()
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
def test_create_collection_with_large_dimension(self, mock_redis, mock_pool_class):
"""Test collection creation with dimension > 2000 (no HNSW index)."""
# Mock Redis operations
mock_lock = MagicMock()
mock_lock.__enter__ = MagicMock()
mock_lock.__exit__ = MagicMock()
mock_redis.lock.return_value = mock_lock
mock_redis.get.return_value = None
mock_redis.set.return_value = None
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
# Mock connection and cursor
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.getconn.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.return_value = [1] # vector extension exists
pgvector = PGVector(self.collection_name, self.config)
pgvector._create_collection(3072) # Dimension > 2000
# Check that CREATE TABLE was called
create_table_calls = [call for call in mock_cursor.execute.call_args_list if "CREATE TABLE" in str(call)]
assert len(create_table_calls) == 1
assert "vector(3072)" in create_table_calls[0][0][0]
# Check that HNSW index was NOT created (dimension > 2000)
hnsw_index_calls = [call for call in mock_cursor.execute.call_args_list if "hnsw" in str(call)]
assert len(hnsw_index_calls) == 0
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
def test_create_collection_with_pg_bigm(self, mock_redis, mock_pool_class):
"""Test collection creation with pg_bigm enabled."""
config = PGVectorConfig(
host="localhost",
port=5432,
user="test_user",
password="test_password",
database="test_db",
min_connection=1,
max_connection=5,
pg_bigm=True,
)
# Mock Redis operations
mock_lock = MagicMock()
mock_lock.__enter__ = MagicMock()
mock_lock.__exit__ = MagicMock()
mock_redis.lock.return_value = mock_lock
mock_redis.get.return_value = None
mock_redis.set.return_value = None
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
# Mock connection and cursor
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.getconn.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.return_value = [1] # vector extension exists
pgvector = PGVector(self.collection_name, config)
pgvector._create_collection(1536)
# Check that pg_bigm index was created
bigm_index_calls = [call for call in mock_cursor.execute.call_args_list if "gin_bigm_ops" in str(call)]
assert len(bigm_index_calls) == 1
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
def test_create_collection_creates_vector_extension(self, mock_redis, mock_pool_class):
"""Test that vector extension is created if it doesn't exist."""
# Mock Redis operations
mock_lock = MagicMock()
mock_lock.__enter__ = MagicMock()
mock_lock.__exit__ = MagicMock()
mock_redis.lock.return_value = mock_lock
mock_redis.get.return_value = None
mock_redis.set.return_value = None
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
# Mock connection and cursor
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.getconn.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
# First call: vector extension doesn't exist
mock_cursor.fetchone.return_value = None
pgvector = PGVector(self.collection_name, self.config)
pgvector._create_collection(1536)
# Check that CREATE EXTENSION was called
create_extension_calls = [
call for call in mock_cursor.execute.call_args_list if "CREATE EXTENSION vector" in str(call)
]
assert len(create_extension_calls) == 1
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
def test_create_collection_with_cache_hit(self, mock_redis, mock_pool_class):
"""Test that collection creation is skipped when cache exists."""
# Mock Redis operations - cache exists
mock_lock = MagicMock()
mock_lock.__enter__ = MagicMock()
mock_lock.__exit__ = MagicMock()
mock_redis.lock.return_value = mock_lock
mock_redis.get.return_value = 1 # Cache exists
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
# Mock connection and cursor
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.getconn.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
pgvector = PGVector(self.collection_name, self.config)
pgvector._create_collection(1536)
# Check that no SQL was executed (early return due to cache)
assert mock_cursor.execute.call_count == 0
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
def test_create_collection_with_redis_lock(self, mock_redis, mock_pool_class):
"""Test that Redis lock is used during collection creation."""
# Mock Redis operations
mock_lock = MagicMock()
mock_lock.__enter__ = MagicMock()
mock_lock.__exit__ = MagicMock()
mock_redis.lock.return_value = mock_lock
mock_redis.get.return_value = None
mock_redis.set.return_value = None
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
# Mock connection and cursor
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.getconn.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.return_value = [1] # vector extension exists
pgvector = PGVector(self.collection_name, self.config)
pgvector._create_collection(1536)
# Verify Redis lock was acquired with correct lock name
mock_redis.lock.assert_called_once_with("vector_indexing_test_collection_lock", timeout=20)
# Verify lock context manager was entered and exited
mock_lock.__enter__.assert_called_once()
mock_lock.__exit__.assert_called_once()
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
def test_get_cursor_context_manager(self, mock_pool_class):
"""Test that _get_cursor properly manages connection lifecycle."""
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.getconn.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
pgvector = PGVector(self.collection_name, self.config)
with pgvector._get_cursor() as cur:
assert cur == mock_cursor
# Verify connection lifecycle methods were called
mock_pool.getconn.assert_called_once()
mock_cursor.close.assert_called_once()
mock_conn.commit.assert_called_once()
mock_pool.putconn.assert_called_once_with(mock_conn)
@pytest.mark.parametrize(
"invalid_config_override",
[
{"host": ""}, # Test empty host
{"port": 0}, # Test invalid port
{"user": ""}, # Test empty user
{"password": ""}, # Test empty password
{"database": ""}, # Test empty database
{"min_connection": 0}, # Test invalid min_connection
{"max_connection": 0}, # Test invalid max_connection
{"min_connection": 10, "max_connection": 5}, # Test min > max
],
)
def test_config_validation_parametrized(invalid_config_override):
"""Test configuration validation for various invalid inputs using parametrize."""
config = {
"host": "localhost",
"port": 5432,
"user": "test_user",
"password": "test_password",
"database": "test_db",
"min_connection": 1,
"max_connection": 5,
}
config.update(invalid_config_override)
with pytest.raises(ValueError):
PGVectorConfig(**config)
if __name__ == "__main__":
unittest.main()
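`test_get_cursor_context_manager` pins down the connection lifecycle only through mock assertions. A rough sketch of that happy path is given below (not the actual `PGVector._get_cursor`, which may also handle rollback), assuming a psycopg2-style connection pool.

```python
from contextlib import contextmanager


@contextmanager
def get_cursor_sketch(pool):
    # Happy-path lifecycle implied by the test: getconn -> cursor -> close -> commit -> putconn.
    conn = pool.getconn()
    cur = conn.cursor()
    try:
        yield cur
    finally:
        cur.close()
        conn.commit()
        pool.putconn(conn)
```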

View File

@ -0,0 +1,873 @@
"""
Unit tests for DatasetRetrieval.process_metadata_filter_func.
This module provides comprehensive test coverage for the process_metadata_filter_func
method in the DatasetRetrieval class, which is responsible for building SQLAlchemy
filter expressions based on metadata filtering conditions.
Conditions Tested:
==================
1. **String Conditions**: contains, not contains, start with, end with
2. **Equality Conditions**: is / =, is not / ≠
3. **Null Conditions**: empty, not empty
4. **Numeric Comparisons**: before / <, after / >, ≤ / <=, ≥ / >=
5. **List Conditions**: in
6. **Edge Cases**: None values, different data types (str, int, float)
Test Architecture:
==================
- Direct instantiation of DatasetRetrieval
- Mocking of DatasetDocument model attributes
- Verification of SQLAlchemy filter expressions
- Follows Arrange-Act-Assert (AAA) pattern
Running Tests:
==============
# Run all tests in this module
uv run --project api pytest \
api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py -v
# Run a specific test
uv run --project api pytest \
api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py::\
TestProcessMetadataFilterFunc::test_contains_condition -v
"""
from unittest.mock import MagicMock
import pytest
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
class TestProcessMetadataFilterFunc:
"""
Comprehensive test suite for process_metadata_filter_func method.
This test class validates all metadata filtering conditions supported by
the DatasetRetrieval class, including string operations, numeric comparisons,
null checks, and list operations.
Method Signature:
==================
def process_metadata_filter_func(
self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
) -> list:
The method builds SQLAlchemy filter expressions by:
1. Validating value is not None (except for empty/not empty conditions)
2. Using DatasetDocument.doc_metadata JSON field operations
3. Adding appropriate SQLAlchemy expressions to the filters list
4. Returning the updated filters list
Mocking Strategy:
==================
- Mock DatasetDocument.doc_metadata to avoid database dependencies
- Verify filter expressions are created correctly
- Test with various data types (str, int, float, list)
"""
@pytest.fixture
def retrieval(self):
"""
Create a DatasetRetrieval instance for testing.
Returns:
DatasetRetrieval: Instance to test process_metadata_filter_func
"""
return DatasetRetrieval()
@pytest.fixture
def mock_doc_metadata(self):
"""
Mock the DatasetDocument.doc_metadata JSON field.
The method uses DatasetDocument.doc_metadata[metadata_name] to access
JSON fields. We mock this to avoid database dependencies.
Returns:
Mock: Mocked doc_metadata attribute
"""
mock_metadata_field = MagicMock()
# Create mock for string access
mock_string_access = MagicMock()
mock_string_access.like = MagicMock()
mock_string_access.notlike = MagicMock()
mock_string_access.__eq__ = MagicMock(return_value=MagicMock())
mock_string_access.__ne__ = MagicMock(return_value=MagicMock())
mock_string_access.in_ = MagicMock(return_value=MagicMock())
# Create mock for float access (for numeric comparisons)
mock_float_access = MagicMock()
mock_float_access.__eq__ = MagicMock(return_value=MagicMock())
mock_float_access.__ne__ = MagicMock(return_value=MagicMock())
mock_float_access.__lt__ = MagicMock(return_value=MagicMock())
mock_float_access.__gt__ = MagicMock(return_value=MagicMock())
mock_float_access.__le__ = MagicMock(return_value=MagicMock())
mock_float_access.__ge__ = MagicMock(return_value=MagicMock())
# Create mock for null checks
mock_null_access = MagicMock()
mock_null_access.is_ = MagicMock(return_value=MagicMock())
mock_null_access.isnot = MagicMock(return_value=MagicMock())
# Setup __getitem__ to return appropriate mock based on usage
def getitem_side_effect(name):
if name in ["author", "title", "category"]:
return mock_string_access
elif name in ["year", "price", "rating"]:
return mock_float_access
else:
return mock_string_access
mock_metadata_field.__getitem__ = MagicMock(side_effect=getitem_side_effect)
mock_metadata_field.as_string.return_value = mock_string_access
mock_metadata_field.as_float.return_value = mock_float_access
# Attach null-check helpers to the subscripted field mocks
mock_string_access.is_ = mock_null_access.is_
mock_string_access.isnot = mock_null_access.isnot
return mock_metadata_field
# ==================== String Condition Tests ====================
def test_contains_condition_string_value(self, retrieval):
"""
Test 'contains' condition with string value.
Verifies:
- Filters list is populated with LIKE expression
- Pattern matching uses %value% syntax
"""
filters = []
sequence = 0
condition = "contains"
metadata_name = "author"
value = "John"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_not_contains_condition(self, retrieval):
"""
Test 'not contains' condition.
Verifies:
- Filters list is populated with NOT LIKE expression
- Pattern matching uses %value% syntax with negation
"""
filters = []
sequence = 0
condition = "not contains"
metadata_name = "title"
value = "banned"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_start_with_condition(self, retrieval):
"""
Test 'start with' condition.
Verifies:
- Filters list is populated with LIKE expression
- Pattern matching uses value% syntax
"""
filters = []
sequence = 0
condition = "start with"
metadata_name = "category"
value = "tech"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_end_with_condition(self, retrieval):
"""
Test 'end with' condition.
Verifies:
- Filters list is populated with LIKE expression
- Pattern matching uses %value syntax
"""
filters = []
sequence = 0
condition = "end with"
metadata_name = "filename"
value = ".pdf"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
# ==================== Equality Condition Tests ====================
def test_is_condition_with_string_value(self, retrieval):
"""
Test 'is' (=) condition with string value.
Verifies:
- Filters list is populated with equality expression
- String comparison is used
"""
filters = []
sequence = 0
condition = "is"
metadata_name = "author"
value = "Jane Doe"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_equals_condition_with_string_value(self, retrieval):
"""
Test '=' condition with string value.
Verifies:
- Same behavior as 'is' condition
- String comparison is used
"""
filters = []
sequence = 0
condition = "="
metadata_name = "category"
value = "technology"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_is_condition_with_int_value(self, retrieval):
"""
Test 'is' condition with integer value.
Verifies:
- Numeric comparison is used
- as_float() is called on the metadata field
"""
filters = []
sequence = 0
condition = "is"
metadata_name = "year"
value = 2023
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_is_condition_with_float_value(self, retrieval):
"""
Test 'is' condition with float value.
Verifies:
- Numeric comparison is used
- as_float() is called on the metadata field
"""
filters = []
sequence = 0
condition = "is"
metadata_name = "price"
value = 19.99
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_is_not_condition_with_string_value(self, retrieval):
"""
Test 'is not' (≠) condition with string value.
Verifies:
- Filters list is populated with inequality expression
- String comparison is used
"""
filters = []
sequence = 0
condition = "is not"
metadata_name = "author"
value = "Unknown"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_not_equals_condition(self, retrieval):
"""
Test '≠' condition with string value.
Verifies:
- Same behavior as 'is not' condition
- Inequality expression is used
"""
filters = []
sequence = 0
condition = ""
metadata_name = "category"
value = "archived"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_is_not_condition_with_numeric_value(self, retrieval):
"""
Test 'is not' condition with numeric value.
Verifies:
- Numeric inequality comparison is used
- as_float() is called on the metadata field
"""
filters = []
sequence = 0
condition = "is not"
metadata_name = "year"
value = 2000
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
# ==================== Null Condition Tests ====================
def test_empty_condition(self, retrieval):
"""
Test 'empty' condition (null check).
Verifies:
- Filters list is populated with IS NULL expression
- Value can be None for this condition
"""
filters = []
sequence = 0
condition = "empty"
metadata_name = "author"
value = None
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_not_empty_condition(self, retrieval):
"""
Test 'not empty' condition (not null check).
Verifies:
- Filters list is populated with IS NOT NULL expression
- Value can be None for this condition
"""
filters = []
sequence = 0
condition = "not empty"
metadata_name = "description"
value = None
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
# ==================== Numeric Comparison Tests ====================
def test_before_condition(self, retrieval):
"""
Test 'before' (<) condition.
Verifies:
- Filters list is populated with less than expression
- Numeric comparison is used
"""
filters = []
sequence = 0
condition = "before"
metadata_name = "year"
value = 2020
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_less_than_condition(self, retrieval):
"""
Test '<' condition.
Verifies:
- Same behavior as 'before' condition
- Less than expression is used
"""
filters = []
sequence = 0
condition = "<"
metadata_name = "price"
value = 100.0
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_after_condition(self, retrieval):
"""
Test 'after' (>) condition.
Verifies:
- Filters list is populated with greater than expression
- Numeric comparison is used
"""
filters = []
sequence = 0
condition = "after"
metadata_name = "year"
value = 2020
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_greater_than_condition(self, retrieval):
"""
Test '>' condition.
Verifies:
- Same behavior as 'after' condition
- Greater than expression is used
"""
filters = []
sequence = 0
condition = ">"
metadata_name = "rating"
value = 4.5
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_less_than_or_equal_condition_unicode(self, retrieval):
"""
Test '≤' condition.
Verifies:
- Filters list is populated with less than or equal expression
- Numeric comparison is used
"""
filters = []
sequence = 0
condition = ""
metadata_name = "price"
value = 50.0
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_less_than_or_equal_condition_ascii(self, retrieval):
"""
Test '<=' condition.
Verifies:
- Same behavior as '≤' condition
- Less than or equal expression is used
"""
filters = []
sequence = 0
condition = "<="
metadata_name = "year"
value = 2023
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_greater_than_or_equal_condition_unicode(self, retrieval):
"""
Test '≥' condition.
Verifies:
- Filters list is populated with greater than or equal expression
- Numeric comparison is used
"""
filters = []
sequence = 0
condition = ""
metadata_name = "rating"
value = 3.5
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_greater_than_or_equal_condition_ascii(self, retrieval):
"""
Test '>=' condition.
Verifies:
- Same behavior as '≥' condition
- Greater than or equal expression is used
"""
filters = []
sequence = 0
condition = ">="
metadata_name = "year"
value = 2000
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
# ==================== List/In Condition Tests ====================
def test_in_condition_with_comma_separated_string(self, retrieval):
"""
Test 'in' condition with comma-separated string value.
Verifies:
- String is split into list
- Whitespace is trimmed from each value
- IN expression is created
"""
filters = []
sequence = 0
condition = "in"
metadata_name = "category"
value = "tech, science, AI "
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_in_condition_with_list_value(self, retrieval):
"""
Test 'in' condition with list value.
Verifies:
- List is processed correctly
- None values are filtered out
- IN expression is created with valid values
"""
filters = []
sequence = 0
condition = "in"
metadata_name = "tags"
value = ["python", "javascript", None, "golang"]
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_in_condition_with_tuple_value(self, retrieval):
"""
Test 'in' condition with tuple value.
Verifies:
- Tuple is processed like a list
- IN expression is created
"""
filters = []
sequence = 0
condition = "in"
metadata_name = "category"
value = ("tech", "science", "ai")
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_in_condition_with_empty_string(self, retrieval):
"""
Test 'in' condition with empty string value.
Verifies:
- Empty string results in literal(False) filter
- No valid values to match
"""
filters = []
sequence = 0
condition = "in"
metadata_name = "category"
value = ""
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
# Verify it's a literal(False) expression
# This is a bit tricky to test without access to the actual expression
def test_in_condition_with_only_whitespace(self, retrieval):
"""
Test 'in' condition with whitespace-only string value.
Verifies:
- Whitespace-only string results in literal(False) filter
- All values are stripped and filtered out
"""
filters = []
sequence = 0
condition = "in"
metadata_name = "category"
value = " , , "
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_in_condition_with_single_string(self, retrieval):
"""
Test 'in' condition with single non-comma string.
Verifies:
- Single string is treated as single-item list
- IN expression is created with one value
"""
filters = []
sequence = 0
condition = "in"
metadata_name = "category"
value = "technology"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
# ==================== Edge Case Tests ====================
def test_none_value_with_non_empty_condition(self, retrieval):
"""
Test None value with conditions that require value.
Verifies:
- Original filters list is returned unchanged
- No filter is added for None values (except empty/not empty)
"""
filters = []
sequence = 0
condition = "contains"
metadata_name = "author"
value = None
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 0 # No filter added
def test_none_value_with_equals_condition(self, retrieval):
"""
Test None value with 'is' (=) condition.
Verifies:
- Original filters list is returned unchanged
- No filter is added for None values
"""
filters = []
sequence = 0
condition = "is"
metadata_name = "author"
value = None
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 0
def test_none_value_with_numeric_condition(self, retrieval):
"""
Test None value with numeric comparison condition.
Verifies:
- Original filters list is returned unchanged
- No filter is added for None values
"""
filters = []
sequence = 0
condition = ">"
metadata_name = "year"
value = None
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 0
def test_existing_filters_preserved(self, retrieval):
"""
Test that existing filters are preserved.
Verifies:
- Existing filters in the list are not removed
- New filters are appended to the list
"""
existing_filter = MagicMock()
filters = [existing_filter]
sequence = 0
condition = "contains"
metadata_name = "author"
value = "test"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 2
assert filters[0] == existing_filter
def test_multiple_filters_accumulated(self, retrieval):
"""
Test multiple calls to accumulate filters.
Verifies:
- Each call adds a new filter to the list
- All filters are preserved across calls
"""
filters = []
# First filter
retrieval.process_metadata_filter_func(0, "contains", "author", "John", filters)
assert len(filters) == 1
# Second filter
retrieval.process_metadata_filter_func(1, ">", "year", 2020, filters)
assert len(filters) == 2
# Third filter
retrieval.process_metadata_filter_func(2, "is", "category", "tech", filters)
assert len(filters) == 3
def test_unknown_condition(self, retrieval):
"""
Test unknown/unsupported condition.
Verifies:
- Original filters list is returned unchanged
- No filter is added for unknown conditions
"""
filters = []
sequence = 0
condition = "unknown_condition"
metadata_name = "author"
value = "test"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 0
def test_empty_string_value_with_contains(self, retrieval):
"""
Test empty string value with 'contains' condition.
Verifies:
- Filter is added even with empty string
- LIKE expression is created
"""
filters = []
sequence = 0
condition = "contains"
metadata_name = "author"
value = ""
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_special_characters_in_value(self, retrieval):
"""
Test special characters in value string.
Verifies:
- Special characters are handled in value
- LIKE expression is created correctly
"""
filters = []
sequence = 0
condition = "contains"
metadata_name = "title"
value = "C++ & Python's features"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_zero_value_with_numeric_condition(self, retrieval):
"""
Test zero value with numeric comparison condition.
Verifies:
- Zero is treated as valid value
- Numeric comparison is performed
"""
filters = []
sequence = 0
condition = ">"
metadata_name = "price"
value = 0
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_negative_value_with_numeric_condition(self, retrieval):
"""
Test negative value with numeric comparison condition.
Verifies:
- Negative numbers are handled correctly
- Numeric comparison is performed
"""
filters = []
sequence = 0
condition = "<"
metadata_name = "temperature"
value = -10.5
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_float_value_with_integer_comparison(self, retrieval):
"""
Test float value with numeric comparison condition.
Verifies:
- Float values work correctly
- Numeric comparison is performed
"""
filters = []
sequence = 0
condition = ">="
metadata_name = "rating"
value = 4.5
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
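The assertions above treat `process_metadata_filter_func` as a black box that appends at most one expression per call. Purely as an illustrative sketch of that contract, not Dify's actual implementation, a condition dispatcher over a SQLAlchemy JSON column might look like the following; the `DatasetDocument` model here is a hypothetical stand-in.

```python
from typing import Any

from sqlalchemy import JSON, Column, Integer, literal
from sqlalchemy.orm import declarative_base

Base = declarative_base()


class DatasetDocument(Base):  # hypothetical stand-in, not the real Dify model
    __tablename__ = "documents"
    id = Column(Integer, primary_key=True)
    doc_metadata = Column(JSON)


def process_metadata_filter_sketch(sequence: int, condition: str, metadata_name: str,
                                   value: Any | None, filters: list) -> list:
    """Append at most one SQLAlchemy expression per call, as the tests above expect."""
    field = DatasetDocument.doc_metadata[metadata_name]
    if condition == "empty":
        filters.append(field.is_(None))
    elif condition == "not empty":
        filters.append(field.isnot(None))
    elif value is None:
        return filters  # value-requiring conditions are skipped for None
    elif condition == "contains":
        filters.append(field.as_string().like(f"%{value}%"))
    elif condition in ("after", ">"):
        filters.append(field.as_float() > value)
    elif condition == "in":
        items = value if isinstance(value, (list, tuple)) else str(value).split(",")
        items = [str(v).strip() for v in items if v is not None and str(v).strip()]
        filters.append(field.as_string().in_(items) if items else literal(False))
    # ...remaining conditions (start with, end with, is/is not, </>, <=/>=) follow the same pattern
    return filters
```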

View File

@ -0,0 +1,472 @@
"""
Unit tests for SegmentService.get_segments method.
Tests the retrieval of document segments with pagination and filtering:
- Basic pagination (page, limit)
- Status filtering
- Keyword search
- Ordering by position and id (to avoid duplicate data)
"""
from unittest.mock import Mock, create_autospec, patch
import pytest
from models.dataset import DocumentSegment
class SegmentServiceTestDataFactory:
"""
Factory class for creating test data and mock objects for segment tests.
"""
@staticmethod
def create_segment_mock(
segment_id: str = "segment-123",
document_id: str = "doc-123",
tenant_id: str = "tenant-123",
dataset_id: str = "dataset-123",
position: int = 1,
content: str = "Test content",
status: str = "completed",
**kwargs,
) -> Mock:
"""
Create a mock document segment.
Args:
segment_id: Unique identifier for the segment
document_id: Parent document ID
tenant_id: Tenant ID the segment belongs to
dataset_id: Parent dataset ID
position: Position within the document
content: Segment text content
status: Indexing status
**kwargs: Additional attributes
Returns:
Mock: DocumentSegment mock object
"""
segment = create_autospec(DocumentSegment, instance=True)
segment.id = segment_id
segment.document_id = document_id
segment.tenant_id = tenant_id
segment.dataset_id = dataset_id
segment.position = position
segment.content = content
segment.status = status
for key, value in kwargs.items():
setattr(segment, key, value)
return segment
class TestSegmentServiceGetSegments:
"""
Comprehensive unit tests for SegmentService.get_segments method.
Tests cover:
- Basic pagination functionality
- Status list filtering
- Keyword search filtering
- Ordering (position + id for uniqueness)
- Empty results
- Combined filters
"""
@pytest.fixture
def mock_segment_service_dependencies(self):
"""
Common mock setup for segment service dependencies.
Patches:
- db: Database operations and pagination
- select: SQLAlchemy query builder
"""
with (
patch("services.dataset_service.db") as mock_db,
patch("services.dataset_service.select") as mock_select,
):
yield {
"db": mock_db,
"select": mock_select,
}
def test_get_segments_basic_pagination(self, mock_segment_service_dependencies):
"""
Test basic pagination functionality.
Verifies:
- Query is built with document_id and tenant_id filters
- Pagination uses correct page and limit parameters
- Returns segments and total count
"""
# Arrange
document_id = "doc-123"
tenant_id = "tenant-123"
page = 1
limit = 20
# Create mock segments
segment1 = SegmentServiceTestDataFactory.create_segment_mock(
segment_id="seg-1", position=1, content="First segment"
)
segment2 = SegmentServiceTestDataFactory.create_segment_mock(
segment_id="seg-2", position=2, content="Second segment"
)
# Mock pagination result
mock_paginated = Mock()
mock_paginated.items = [segment1, segment2]
mock_paginated.total = 2
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
# Mock select builder
mock_query = Mock()
mock_segment_service_dependencies["select"].return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
# Act
from services.dataset_service import SegmentService
items, total = SegmentService.get_segments(document_id=document_id, tenant_id=tenant_id, page=page, limit=limit)
# Assert
assert len(items) == 2
assert total == 2
assert items[0].id == "seg-1"
assert items[1].id == "seg-2"
mock_segment_service_dependencies["db"].paginate.assert_called_once()
call_kwargs = mock_segment_service_dependencies["db"].paginate.call_args[1]
assert call_kwargs["page"] == page
assert call_kwargs["per_page"] == limit
assert call_kwargs["max_per_page"] == 100
assert call_kwargs["error_out"] is False
def test_get_segments_with_status_filter(self, mock_segment_service_dependencies):
"""
Test filtering by status list.
Verifies:
- Status list filter is applied to query
- Only segments with matching status are returned
"""
# Arrange
document_id = "doc-123"
tenant_id = "tenant-123"
status_list = ["completed", "indexing"]
segment1 = SegmentServiceTestDataFactory.create_segment_mock(segment_id="seg-1", status="completed")
segment2 = SegmentServiceTestDataFactory.create_segment_mock(segment_id="seg-2", status="indexing")
mock_paginated = Mock()
mock_paginated.items = [segment1, segment2]
mock_paginated.total = 2
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
mock_query = Mock()
mock_segment_service_dependencies["select"].return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
# Act
from services.dataset_service import SegmentService
items, total = SegmentService.get_segments(
document_id=document_id, tenant_id=tenant_id, status_list=status_list
)
# Assert
assert len(items) == 2
assert total == 2
# Verify where was called multiple times (base filters + status filter)
assert mock_query.where.call_count >= 2
def test_get_segments_with_empty_status_list(self, mock_segment_service_dependencies):
"""
Test with empty status list.
Verifies:
- Empty status list is handled correctly
- No status filter is applied to avoid WHERE false condition
"""
# Arrange
document_id = "doc-123"
tenant_id = "tenant-123"
status_list = []
segment = SegmentServiceTestDataFactory.create_segment_mock(segment_id="seg-1")
mock_paginated = Mock()
mock_paginated.items = [segment]
mock_paginated.total = 1
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
mock_query = Mock()
mock_segment_service_dependencies["select"].return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
# Act
from services.dataset_service import SegmentService
items, total = SegmentService.get_segments(
document_id=document_id, tenant_id=tenant_id, status_list=status_list
)
# Assert
assert len(items) == 1
assert total == 1
# Should only be called once (base filters, no status filter)
assert mock_query.where.call_count == 1
def test_get_segments_with_keyword_search(self, mock_segment_service_dependencies):
"""
Test keyword search functionality.
Verifies:
- Keyword filter uses ilike for case-insensitive search
- Search pattern includes wildcards (%keyword%)
"""
# Arrange
document_id = "doc-123"
tenant_id = "tenant-123"
keyword = "search term"
segment = SegmentServiceTestDataFactory.create_segment_mock(
segment_id="seg-1", content="This contains search term"
)
mock_paginated = Mock()
mock_paginated.items = [segment]
mock_paginated.total = 1
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
mock_query = Mock()
mock_segment_service_dependencies["select"].return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
# Act
from services.dataset_service import SegmentService
items, total = SegmentService.get_segments(document_id=document_id, tenant_id=tenant_id, keyword=keyword)
# Assert
assert len(items) == 1
assert total == 1
# Verify where was called for base filters + keyword filter
assert mock_query.where.call_count == 2
def test_get_segments_ordering_by_position_and_id(self, mock_segment_service_dependencies):
"""
Test ordering by position and id.
Verifies:
- Results are ordered by position ASC
- Results are secondarily ordered by id ASC to ensure uniqueness
- This prevents duplicate data across pages when positions are not unique
"""
# Arrange
document_id = "doc-123"
tenant_id = "tenant-123"
# Create segments with same position but different ids
segment1 = SegmentServiceTestDataFactory.create_segment_mock(
segment_id="seg-1", position=1, content="Content 1"
)
segment2 = SegmentServiceTestDataFactory.create_segment_mock(
segment_id="seg-2", position=1, content="Content 2"
)
segment3 = SegmentServiceTestDataFactory.create_segment_mock(
segment_id="seg-3", position=2, content="Content 3"
)
mock_paginated = Mock()
mock_paginated.items = [segment1, segment2, segment3]
mock_paginated.total = 3
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
mock_query = Mock()
mock_segment_service_dependencies["select"].return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
# Act
from services.dataset_service import SegmentService
items, total = SegmentService.get_segments(document_id=document_id, tenant_id=tenant_id)
# Assert
assert len(items) == 3
assert total == 3
mock_query.order_by.assert_called_once()
def test_get_segments_empty_results(self, mock_segment_service_dependencies):
"""
Test when no segments match the criteria.
Verifies:
- Empty list is returned for items
- Total count is 0
"""
# Arrange
document_id = "non-existent-doc"
tenant_id = "tenant-123"
mock_paginated = Mock()
mock_paginated.items = []
mock_paginated.total = 0
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
mock_query = Mock()
mock_segment_service_dependencies["select"].return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
# Act
from services.dataset_service import SegmentService
items, total = SegmentService.get_segments(document_id=document_id, tenant_id=tenant_id)
# Assert
assert items == []
assert total == 0
def test_get_segments_combined_filters(self, mock_segment_service_dependencies):
"""
Test with multiple filters combined.
Verifies:
- All filters work together correctly
- Status list and keyword search both applied
"""
# Arrange
document_id = "doc-123"
tenant_id = "tenant-123"
status_list = ["completed"]
keyword = "important"
page = 2
limit = 10
segment = SegmentServiceTestDataFactory.create_segment_mock(
segment_id="seg-1",
status="completed",
content="This is important information",
)
mock_paginated = Mock()
mock_paginated.items = [segment]
mock_paginated.total = 1
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
mock_query = Mock()
mock_segment_service_dependencies["select"].return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
# Act
from services.dataset_service import SegmentService
items, total = SegmentService.get_segments(
document_id=document_id,
tenant_id=tenant_id,
status_list=status_list,
keyword=keyword,
page=page,
limit=limit,
)
# Assert
assert len(items) == 1
assert total == 1
# Verify filters: base + status + keyword
assert mock_query.where.call_count == 3
# Verify pagination parameters
call_kwargs = mock_segment_service_dependencies["db"].paginate.call_args[1]
assert call_kwargs["page"] == page
assert call_kwargs["per_page"] == limit
def test_get_segments_with_none_status_list(self, mock_segment_service_dependencies):
"""
Test with None status list.
Verifies:
- None status list is handled correctly
- No status filter is applied
"""
# Arrange
document_id = "doc-123"
tenant_id = "tenant-123"
segment = SegmentServiceTestDataFactory.create_segment_mock(segment_id="seg-1")
mock_paginated = Mock()
mock_paginated.items = [segment]
mock_paginated.total = 1
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
mock_query = Mock()
mock_segment_service_dependencies["select"].return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
# Act
from services.dataset_service import SegmentService
items, total = SegmentService.get_segments(
document_id=document_id,
tenant_id=tenant_id,
status_list=None,
)
# Assert
assert len(items) == 1
assert total == 1
# Should only be called once (base filters only, no status filter)
assert mock_query.where.call_count == 1
def test_get_segments_pagination_max_per_page_limit(self, mock_segment_service_dependencies):
"""
Test that max_per_page is correctly set to 100.
Verifies:
- max_per_page parameter is set to 100
- This prevents excessive page sizes
"""
# Arrange
document_id = "doc-123"
tenant_id = "tenant-123"
limit = 200 # Request more than max_per_page
mock_paginated = Mock()
mock_paginated.items = []
mock_paginated.total = 0
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
mock_query = Mock()
mock_segment_service_dependencies["select"].return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
# Act
from services.dataset_service import SegmentService
SegmentService.get_segments(
document_id=document_id,
tenant_id=tenant_id,
limit=limit,
)
# Assert
call_kwargs = mock_segment_service_dependencies["db"].paginate.call_args[1]
assert call_kwargs["max_per_page"] == 100
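Taken together, these tests imply a query shape for `SegmentService.get_segments`: base filters on document and tenant, optional status and keyword filters, a stable `position, id` ordering, and capped pagination. A rough sketch of that shape follows, assuming Flask-SQLAlchemy's `db.paginate` and passing the model and session in explicitly; it is illustrative rather than the real service code.

```python
from sqlalchemy import select


def get_segments_sketch(db, segment_model, document_id: str, tenant_id: str,
                        status_list=None, keyword=None, page: int = 1, limit: int = 20):
    """Query shape implied by the tests above; not the actual SegmentService.get_segments."""
    query = select(segment_model).where(
        segment_model.document_id == document_id,
        segment_model.tenant_id == tenant_id,
    )
    if status_list:  # None or [] adds no status filter, avoiding a WHERE false clause
        query = query.where(segment_model.status.in_(status_list))
    if keyword:
        query = query.where(segment_model.content.ilike(f"%{keyword}%"))
    # Secondary ordering by id keeps pages stable when positions are duplicated.
    query = query.order_by(segment_model.position.asc(), segment_model.id.asc())
    paginated = db.paginate(query, page=page, per_page=limit, max_per_page=100, error_out=False)
    return paginated.items, paginated.total
```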

View File

@ -0,0 +1,122 @@
import base64
from unittest.mock import Mock, patch
import pytest
from core.mcp.types import (
AudioContent,
BlobResourceContents,
CallToolResult,
EmbeddedResource,
ImageContent,
TextResourceContents,
)
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage
from core.tools.mcp_tool.tool import MCPTool
def _make_mcp_tool(output_schema: dict | None = None) -> MCPTool:
identity = ToolIdentity(
author="test",
name="test_mcp_tool",
label=I18nObject(en_US="Test MCP Tool", zh_Hans="测试MCP工具"),
provider="test_provider",
)
entity = ToolEntity(identity=identity, output_schema=output_schema or {})
runtime = Mock(spec=ToolRuntime)
runtime.credentials = {}
return MCPTool(
entity=entity,
runtime=runtime,
tenant_id="test_tenant",
icon="",
server_url="https://server.invalid",
provider_id="provider_1",
headers={},
)
class TestMCPToolInvoke:
@pytest.mark.parametrize(
("content_factory", "mime_type"),
[
(
lambda b64, mt: ImageContent(type="image", data=b64, mimeType=mt),
"image/png",
),
(
lambda b64, mt: AudioContent(type="audio", data=b64, mimeType=mt),
"audio/mpeg",
),
],
)
def test_invoke_image_or_audio_yields_blob(self, content_factory, mime_type) -> None:
tool = _make_mcp_tool()
raw = b"\x00\x01test-bytes\x02"
b64 = base64.b64encode(raw).decode()
content = content_factory(b64, mime_type)
result = CallToolResult(content=[content])
with patch.object(tool, "invoke_remote_mcp_tool", return_value=result):
messages = list(tool._invoke(user_id="test_user", tool_parameters={}))
assert len(messages) == 1
msg = messages[0]
assert msg.type == ToolInvokeMessage.MessageType.BLOB
assert isinstance(msg.message, ToolInvokeMessage.BlobMessage)
assert msg.message.blob == raw
assert msg.meta == {"mime_type": mime_type}
def test_invoke_embedded_text_resource_yields_text(self) -> None:
tool = _make_mcp_tool()
text_resource = TextResourceContents(uri="file://test.txt", mimeType="text/plain", text="hello world")
content = EmbeddedResource(type="resource", resource=text_resource)
result = CallToolResult(content=[content])
with patch.object(tool, "invoke_remote_mcp_tool", return_value=result):
messages = list(tool._invoke(user_id="test_user", tool_parameters={}))
assert len(messages) == 1
msg = messages[0]
assert msg.type == ToolInvokeMessage.MessageType.TEXT
assert isinstance(msg.message, ToolInvokeMessage.TextMessage)
assert msg.message.text == "hello world"
@pytest.mark.parametrize(
("mime_type", "expected_mime"),
[("application/pdf", "application/pdf"), (None, "application/octet-stream")],
)
def test_invoke_embedded_blob_resource_yields_blob(self, mime_type, expected_mime) -> None:
tool = _make_mcp_tool()
raw = b"binary-data"
b64 = base64.b64encode(raw).decode()
blob_resource = BlobResourceContents(uri="file://doc.bin", mimeType=mime_type, blob=b64)
content = EmbeddedResource(type="resource", resource=blob_resource)
result = CallToolResult(content=[content])
with patch.object(tool, "invoke_remote_mcp_tool", return_value=result):
messages = list(tool._invoke(user_id="test_user", tool_parameters={}))
assert len(messages) == 1
msg = messages[0]
assert msg.type == ToolInvokeMessage.MessageType.BLOB
assert isinstance(msg.message, ToolInvokeMessage.BlobMessage)
assert msg.message.blob == raw
assert msg.meta == {"mime_type": expected_mime}
def test_invoke_yields_variables_when_structured_content_and_schema(self) -> None:
tool = _make_mcp_tool(output_schema={"type": "object"})
result = CallToolResult(content=[], structuredContent={"a": 1, "b": "x"})
with patch.object(tool, "invoke_remote_mcp_tool", return_value=result):
messages = list(tool._invoke(user_id="test_user", tool_parameters={}))
# Expect two variable messages corresponding to keys a and b
assert len(messages) == 2
var_msgs = [m for m in messages if isinstance(m.message, ToolInvokeMessage.VariableMessage)]
assert {m.message.variable_name for m in var_msgs} == {"a", "b"}
# Validate values
values = {m.message.variable_name: m.message.variable_value for m in var_msgs}
assert values == {"a": 1, "b": "x"}
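These cases pin down how MCP content items should surface as tool messages: base64 image/audio data and blob resources become blobs (with `application/octet-stream` as the fallback MIME type), embedded text resources become text, and structured content becomes per-key variables. A minimal, hypothetical mapping for the per-item content cases is sketched below (not the real `MCPTool._invoke`; structured content would be handled separately).

```python
import base64


def map_mcp_content_sketch(content) -> tuple[str, object, dict]:
    """Return (message_kind, payload, meta) for one MCP content item, per the tests above."""
    kind = getattr(content, "type", None)
    if kind in ("image", "audio"):
        return "blob", base64.b64decode(content.data), {"mime_type": content.mimeType}
    if kind == "resource":
        resource = content.resource
        if hasattr(resource, "text"):
            return "text", resource.text, {}
        mime = resource.mimeType or "application/octet-stream"
        return "blob", base64.b64decode(resource.blob), {"mime_type": mime}
    return "text", getattr(content, "text", ""), {}
```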

View File

@ -1368,7 +1368,7 @@ wheels = [
[[package]]
name = "dify-api"
version = "1.11.1"
version = "1.11.2"
source = { virtual = "." }
dependencies = [
{ name = "aliyun-log-python-sdk" },
@ -3072,11 +3072,11 @@ wheels = [
[[package]]
name = "json-repair"
version = "0.54.1"
version = "0.54.3"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/00/46/d3a4d9a3dad39bb4a2ad16b8adb9fe2e8611b20b71197fe33daa6768e85d/json_repair-0.54.1.tar.gz", hash = "sha256:d010bc31f1fc66e7c36dc33bff5f8902674498ae5cb8e801ad455a53b455ad1d", size = 38555, upload-time = "2025-11-19T14:55:24.265Z" }
sdist = { url = "https://files.pythonhosted.org/packages/b5/86/48b12ac02032f121ac7e5f11a32143edca6c1e3d19ffc54d6fb9ca0aafd0/json_repair-0.54.3.tar.gz", hash = "sha256:e50feec9725e52ac91f12184609754684ac1656119dfbd31de09bdaf9a1d8bf6", size = 38626, upload-time = "2025-12-15T09:41:58.594Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/db/96/c9aad7ee949cc1bf15df91f347fbc2d3bd10b30b80c7df689ce6fe9332b5/json_repair-0.54.1-py3-none-any.whl", hash = "sha256:016160c5db5d5fe443164927bb58d2dfbba5f43ad85719fa9bc51c713a443ab1", size = 29311, upload-time = "2025-11-19T14:55:22.886Z" },
{ url = "https://files.pythonhosted.org/packages/e9/08/abe317237add63c3e62f18a981bccf92112b431835b43d844aedaf61f4a0/json_repair-0.54.3-py3-none-any.whl", hash = "sha256:4cdc132ee27d4780576f71bf27a113877046224a808bfc17392e079cb344fb81", size = 29357, upload-time = "2025-12-15T09:41:57.436Z" },
]
[[package]]

View File

@ -399,6 +399,7 @@ CONSOLE_CORS_ALLOW_ORIGINS=*
COOKIE_DOMAIN=
# When the frontend and backend run on different subdomains, set NEXT_PUBLIC_COOKIE_DOMAIN=1.
NEXT_PUBLIC_COOKIE_DOMAIN=
NEXT_PUBLIC_BATCH_CONCURRENCY=5
# ------------------------------
# File Storage Configuration

View File

@ -21,7 +21,7 @@ services:
# API service
api:
image: langgenius/dify-api:1.11.1
image: langgenius/dify-api:1.11.2
restart: always
environment:
# Use the shared environment variables.
@ -63,7 +63,7 @@ services:
# worker service
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
worker:
image: langgenius/dify-api:1.11.1
image: langgenius/dify-api:1.11.2
restart: always
environment:
# Use the shared environment variables.
@ -102,7 +102,7 @@ services:
# worker_beat service
# Celery beat for scheduling periodic tasks.
worker_beat:
image: langgenius/dify-api:1.11.1
image: langgenius/dify-api:1.11.2
restart: always
environment:
# Use the shared environment variables.
@ -132,7 +132,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:1.11.1
image: langgenius/dify-web:1.11.2
restart: always
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}

View File

@ -108,6 +108,7 @@ x-shared-env: &shared-api-worker-env
CONSOLE_CORS_ALLOW_ORIGINS: ${CONSOLE_CORS_ALLOW_ORIGINS:-*}
COOKIE_DOMAIN: ${COOKIE_DOMAIN:-}
NEXT_PUBLIC_COOKIE_DOMAIN: ${NEXT_PUBLIC_COOKIE_DOMAIN:-}
NEXT_PUBLIC_BATCH_CONCURRENCY: ${NEXT_PUBLIC_BATCH_CONCURRENCY:-5}
STORAGE_TYPE: ${STORAGE_TYPE:-opendal}
OPENDAL_SCHEME: ${OPENDAL_SCHEME:-fs}
OPENDAL_FS_ROOT: ${OPENDAL_FS_ROOT:-storage}
@ -692,7 +693,7 @@ services:
# API service
api:
image: langgenius/dify-api:1.11.1
image: langgenius/dify-api:1.11.2
restart: always
environment:
# Use the shared environment variables.
@ -734,7 +735,7 @@ services:
# worker service
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
worker:
image: langgenius/dify-api:1.11.1
image: langgenius/dify-api:1.11.2
restart: always
environment:
# Use the shared environment variables.
@ -773,7 +774,7 @@ services:
# worker_beat service
# Celery beat for scheduling periodic tasks.
worker_beat:
image: langgenius/dify-api:1.11.1
image: langgenius/dify-api:1.11.2
restart: always
environment:
# Use the shared environment variables.
@ -803,7 +804,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:1.11.1
image: langgenius/dify-web:1.11.2
restart: always
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}

View File

@ -54,17 +54,17 @@
"publish:npm": "./scripts/publish.sh"
},
"dependencies": {
"axios": "^1.3.5"
"axios": "^1.13.2"
},
"devDependencies": {
"@eslint/js": "^9.2.0",
"@types/node": "^20.11.30",
"@eslint/js": "^9.39.2",
"@types/node": "^25.0.3",
"@typescript-eslint/eslint-plugin": "^8.50.1",
"@typescript-eslint/parser": "^8.50.1",
"@vitest/coverage-v8": "1.6.1",
"eslint": "^9.2.0",
"@vitest/coverage-v8": "4.0.16",
"eslint": "^9.39.2",
"tsup": "^8.5.1",
"typescript": "^5.4.5",
"vitest": "^1.5.0"
"typescript": "^5.9.3",
"vitest": "^4.0.16"
}
}

File diff suppressed because it is too large

View File

@ -0,0 +1,2 @@
onlyBuiltDependencies:
- esbuild

View File

@ -73,3 +73,6 @@ NEXT_PUBLIC_MAX_TREE_DEPTH=50
# The API key of amplitude
NEXT_PUBLIC_AMPLITUDE_API_KEY=
# number of concurrency
NEXT_PUBLIC_BATCH_CONCURRENCY=5

View File

@ -1,6 +1,6 @@
import type { Plan, UsagePlanInfo } from '@/app/components/billing/type'
import type { ProviderContextState } from '@/context/provider-context'
import { merge, noop } from 'lodash-es'
import { merge, noop } from 'es-toolkit/compat'
import { defaultPlan } from '@/app/components/billing/config'
// Avoid being mocked in tests

View File

@ -4,7 +4,7 @@ import type { FC } from 'react'
import type { TriggerProps } from '@/app/components/base/date-and-time-picker/types'
import { RiCalendarLine } from '@remixicon/react'
import dayjs from 'dayjs'
import { noop } from 'lodash-es'
import { noop } from 'es-toolkit/compat'
import * as React from 'react'
import { useCallback } from 'react'
import Picker from '@/app/components/base/date-and-time-picker/date-picker'

View File

@ -1,5 +1,6 @@
import type { ReactNode } from 'react'
import * as React from 'react'
import { AppInitializer } from '@/app/components/app-initializer'
import AmplitudeProvider from '@/app/components/base/amplitude'
import GA, { GaType } from '@/app/components/base/ga'
import Zendesk from '@/app/components/base/zendesk'
@ -7,7 +8,6 @@ import GotoAnything from '@/app/components/goto-anything'
import Header from '@/app/components/header'
import HeaderWrapper from '@/app/components/header/header-wrapper'
import ReadmePanel from '@/app/components/plugins/readme-panel'
import SwrInitializer from '@/app/components/swr-initializer'
import { AppContextProvider } from '@/context/app-context'
import { EventEmitterContextProvider } from '@/context/event-emitter'
import { ModalContextProvider } from '@/context/modal-context'
@ -20,7 +20,7 @@ const Layout = ({ children }: { children: ReactNode }) => {
<>
<GA gaType={GaType.admin} />
<AmplitudeProvider />
<SwrInitializer>
<AppInitializer>
<AppContextProvider>
<EventEmitterContextProvider>
<ProviderContextProvider>
@ -38,7 +38,7 @@ const Layout = ({ children }: { children: ReactNode }) => {
</EventEmitterContextProvider>
</AppContextProvider>
<Zendesk />
</SwrInitializer>
</AppInitializer>
</>
)
}

View File

@ -1,6 +1,6 @@
'use client'
import { RiArrowLeftLine, RiLockPasswordLine } from '@remixicon/react'
import { noop } from 'lodash-es'
import { noop } from 'es-toolkit/compat'
import Link from 'next/link'
import { useRouter, useSearchParams } from 'next/navigation'
import { useState } from 'react'

View File

@ -1,4 +1,4 @@
import { noop } from 'lodash-es'
import { noop } from 'es-toolkit/compat'
import { useRouter, useSearchParams } from 'next/navigation'
import { useState } from 'react'
import { useTranslation } from 'react-i18next'

View File

@ -1,5 +1,5 @@
'use client'
import { noop } from 'lodash-es'
import { noop } from 'es-toolkit/compat'
import Link from 'next/link'
import { useRouter, useSearchParams } from 'next/navigation'
import { useCallback, useState } from 'react'

View File

@ -1,6 +1,6 @@
import type { ResponseError } from '@/service/fetch'
import { RiCloseLine } from '@remixicon/react'
import { noop } from 'lodash-es'
import { noop } from 'es-toolkit/compat'
import { useRouter } from 'next/navigation'
import * as React from 'react'
import { useState } from 'react'

View File

@ -1,9 +1,9 @@
import type { ReactNode } from 'react'
import * as React from 'react'
import { AppInitializer } from '@/app/components/app-initializer'
import AmplitudeProvider from '@/app/components/base/amplitude'
import GA, { GaType } from '@/app/components/base/ga'
import HeaderWrapper from '@/app/components/header/header-wrapper'
import SwrInitor from '@/app/components/swr-initializer'
import { AppContextProvider } from '@/context/app-context'
import { EventEmitterContextProvider } from '@/context/event-emitter'
import { ModalContextProvider } from '@/context/modal-context'
@ -15,7 +15,7 @@ const Layout = ({ children }: { children: ReactNode }) => {
<>
<GA gaType={GaType.admin} />
<AmplitudeProvider />
<SwrInitor>
<AppInitializer>
<AppContextProvider>
<EventEmitterContextProvider>
<ProviderContextProvider>
@ -30,7 +30,7 @@ const Layout = ({ children }: { children: ReactNode }) => {
</ProviderContextProvider>
</EventEmitterContextProvider>
</AppContextProvider>
</SwrInitor>
</AppInitializer>
</>
)
}

View File

@ -3,7 +3,6 @@
import type { ReactNode } from 'react'
import { usePathname, useRouter, useSearchParams } from 'next/navigation'
import { useCallback, useEffect, useState } from 'react'
import { SWRConfig } from 'swr'
import {
EDUCATION_VERIFY_URL_SEARCHPARAMS_ACTION,
EDUCATION_VERIFYING_LOCALSTORAGE_ITEM,
@ -11,12 +10,13 @@ import {
import { fetchSetupStatus } from '@/service/common'
import { resolvePostLoginRedirect } from '../signin/utils/post-login-redirect'
type SwrInitializerProps = {
type AppInitializerProps = {
children: ReactNode
}
const SwrInitializer = ({
export const AppInitializer = ({
children,
}: SwrInitializerProps) => {
}: AppInitializerProps) => {
const router = useRouter()
const searchParams = useSearchParams()
// Tokens are now stored in cookies, no need to check localStorage
@ -69,20 +69,5 @@ const SwrInitializer = ({
})()
}, [isSetupFinished, router, pathname, searchParams])
return init
? (
<SWRConfig value={{
shouldRetryOnError: false,
revalidateOnFocus: false,
dedupingInterval: 60000,
focusThrottleInterval: 5000,
provider: () => new Map(),
}}
>
{children}
</SWRConfig>
)
: null
return init ? children : null
}
export default SwrInitializer
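
The layout hunks earlier in this diff consume the renamed initializer through its new named export; a minimal illustration of that import change, using only paths and names that appear in the hunks:

```tsx
import type { ReactNode } from 'react'
// Before: default export from the SWR-oriented module
// import SwrInitializer from '@/app/components/swr-initializer'
// After: named export from the renamed module
import { AppInitializer } from '@/app/components/app-initializer'

// The wrapper now only gates rendering on its init state; the SWRConfig
// provider removed in the hunk above is no longer applied here.
const Layout = ({ children }: { children: ReactNode }) => (
  <AppInitializer>{children}</AppInitializer>
)

export default Layout
```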

View File

@ -1,7 +1,7 @@
'use client'
import type { FC } from 'react'
import { RiCloseLine } from '@remixicon/react'
import { noop } from 'lodash-es'
import { noop } from 'es-toolkit/compat'
import * as React from 'react'
import { useEffect, useState } from 'react'
import { useTranslation } from 'react-i18next'

View File

@ -4,7 +4,7 @@ import {
RiAddLine,
RiEditLine,
} from '@remixicon/react'
import { noop } from 'lodash-es'
import { noop } from 'es-toolkit/compat'
import * as React from 'react'
import { useTranslation } from 'react-i18next'
import { cn } from '@/utils/classnames'

View File

@ -4,8 +4,8 @@ import type { ExternalDataTool } from '@/models/common'
import type { PromptVariable } from '@/models/debug'
import type { GenRes } from '@/service/debug'
import { useBoolean } from 'ahooks'
import { noop } from 'es-toolkit/compat'
import { produce } from 'immer'
import { noop } from 'lodash-es'
import * as React from 'react'
import { useState } from 'react'
import { useTranslation } from 'react-i18next'

View File

@ -47,6 +47,12 @@ const getCheckboxDefaultSelectValue = (value: InputVar['default']) => {
const parseCheckboxSelectValue = (value: string) =>
value === CHECKBOX_DEFAULT_TRUE_VALUE
const normalizeSelectDefaultValue = (inputVar: InputVar) => {
if (inputVar.type === InputVarType.select && inputVar.default === '')
return { ...inputVar, default: undefined }
return inputVar
}
export type IConfigModalProps = {
isCreate?: boolean
payload?: InputVar
@ -67,7 +73,7 @@ const ConfigModal: FC<IConfigModalProps> = ({
}) => {
const { modelConfig } = useContext(ConfigContext)
const { t } = useTranslation()
const [tempPayload, setTempPayload] = useState<InputVar>(() => payload || getNewVarInWorkflow('') as any)
const [tempPayload, setTempPayload] = useState<InputVar>(() => normalizeSelectDefaultValue(payload || getNewVarInWorkflow('') as any))
const { type, label, variable, options, max_length } = tempPayload
const modalRef = useRef<HTMLDivElement>(null)
const appDetail = useAppStore(state => state.appDetail)
@ -182,6 +188,8 @@ const ConfigModal: FC<IConfigModalProps> = ({
const newPayload = produce(tempPayload, (draft) => {
draft.type = type
if (type === InputVarType.select)
draft.default = undefined
if ([InputVarType.singleFile, InputVarType.multiFiles].includes(type)) {
(Object.keys(DEFAULT_FILE_UPLOAD_SETTING)).forEach((key) => {
if (key !== 'max_length')
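
A small illustration of the select-default normalization added above — the type names below are simplified stand-ins for the real InputVar / InputVarType used in the hunk:

```typescript
// Simplified stand-in types for illustration only.
type InputVarLike = { type: string; default?: string }
const SELECT = 'select'

// Mirrors the helper above: an empty-string default on a select input is
// treated as "no default selected"; other input types pass through untouched.
const normalizeSelectDefaultValue = (inputVar: InputVarLike): InputVarLike => {
  if (inputVar.type === SELECT && inputVar.default === '')
    return { ...inputVar, default: undefined }
  return inputVar
}

normalizeSelectDefaultValue({ type: SELECT, default: '' })
// => { type: 'select', default: undefined }
normalizeSelectDefaultValue({ type: 'text-input', default: '' })
// => returned unchanged
```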

View File

@ -2,7 +2,7 @@
import type { FC } from 'react'
import type { ExternalDataTool } from '@/models/common'
import copy from 'copy-to-clipboard'
import { noop } from 'lodash-es'
import { noop } from 'es-toolkit/compat'
import * as React from 'react'
import { useTranslation } from 'react-i18next'
import { useContext } from 'use-context-selector'

View File

@ -52,7 +52,7 @@ vi.mock('../debug/hooks', () => ({
useFormattingChangedDispatcher: vi.fn(() => vi.fn()),
}))
vi.mock('lodash-es', () => ({
vi.mock('es-toolkit/compat', () => ({
intersectionBy: vi.fn((...arrays) => {
// Mock realistic intersection behavior based on metadata name
const validArrays = arrays.filter(Array.isArray)

View File

@ -8,8 +8,8 @@ import type {
MetadataFilteringModeEnum,
} from '@/app/components/workflow/nodes/knowledge-retrieval/types'
import type { DataSet } from '@/models/datasets'
import { intersectionBy } from 'es-toolkit/compat'
import { produce } from 'immer'
import { intersectionBy } from 'lodash-es'
import * as React from 'react'
import { useCallback, useMemo } from 'react'
import { useTranslation } from 'react-i18next'
@ -176,7 +176,7 @@ const DatasetConfig: FC = () => {
}))
}, [setDatasetConfigs, datasetConfigsRef])
const handleAddCondition = useCallback<HandleAddCondition>(({ name, type }) => {
const handleAddCondition = useCallback<HandleAddCondition>(({ id, name, type }) => {
let operator: ComparisonOperator = ComparisonOperator.is
if (type === MetadataFilteringVariableType.number)
@ -184,6 +184,7 @@ const DatasetConfig: FC = () => {
const newCondition = {
id: uuid4(),
metadata_id: id, // Save metadata.id for reliable reference
name,
comparison_operator: operator,
}

View File

@ -1,6 +1,7 @@
'use client'
import type { FC } from 'react'
import type { ModelParameterModalProps } from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal'
import type { ModelConfig } from '@/app/components/workflow/types'
import type {
DataSet,
@ -8,7 +9,6 @@ import type {
import type {
DatasetConfigs,
} from '@/models/debug'
import { noop } from 'lodash-es'
import { memo, useCallback, useEffect, useMemo } from 'react'
import { useTranslation } from 'react-i18next'
import Divider from '@/app/components/base/divider'
@ -33,17 +33,20 @@ type Props = {
selectedDatasets?: DataSet[]
isInWorkflow?: boolean
singleRetrievalModelConfig?: ModelConfig
onSingleRetrievalModelChange?: (config: ModelConfig) => void
onSingleRetrievalModelParamsChange?: (config: ModelConfig) => void
onSingleRetrievalModelChange?: ModelParameterModalProps['setModel']
onSingleRetrievalModelParamsChange?: ModelParameterModalProps['onCompletionParamsChange']
}
const noopModelChange: ModelParameterModalProps['setModel'] = () => {}
const noopParamsChange: ModelParameterModalProps['onCompletionParamsChange'] = () => {}
const ConfigContent: FC<Props> = ({
datasetConfigs,
onChange,
isInWorkflow,
singleRetrievalModelConfig: singleRetrievalConfig = {} as ModelConfig,
onSingleRetrievalModelChange = noop,
onSingleRetrievalModelParamsChange = noop,
onSingleRetrievalModelChange = noopModelChange,
onSingleRetrievalModelParamsChange = noopParamsChange,
selectedDatasets = [],
}) => {
const { t } = useTranslation()

Some files were not shown because too many files have changed in this diff