mirror of https://github.com/langgenius/dify.git
Merge remote-tracking branch 'upstream/main' into feat-agent-mask
commit 75b863f7e5
@@ -0,0 +1,8 @@
{
  "enabledPlugins": {
    "feature-dev@claude-plugins-official": true,
    "context7@claude-plugins-official": true,
    "typescript-lsp@claude-plugins-official": true,
    "pyright-lsp@claude-plugins-official": true
  }
}
@@ -1,19 +0,0 @@
{
  "permissions": {
    "allow": [],
    "deny": []
  },
  "env": {
    "__comment": "Environment variables for MCP servers. Override in .claude/settings.local.json with actual values.",
    "GITHUB_PERSONAL_ACCESS_TOKEN": "ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
  },
  "enabledMcpjsonServers": [
    "context7",
    "sequential-thinking",
    "github",
    "fetch",
    "playwright",
    "ide"
  ],
  "enableAllProjectMcpServers": true
}
@@ -0,0 +1,483 @@
---
name: component-refactoring
description: Refactor high-complexity React components in Dify frontend. Use when `pnpm analyze-component --json` shows complexity > 50 or lineCount > 300, when the user asks for code splitting, hook extraction, or complexity reduction, or when `pnpm analyze-component` warns to refactor before testing; avoid for simple/well-structured components, third-party wrappers, or when the user explicitly wants testing without refactoring.
---

# Dify Component Refactoring Skill

Refactor high-complexity React components in the Dify frontend codebase with the patterns and workflow below.

> **Complexity Threshold**: Components with complexity > 50 (measured by `pnpm analyze-component`) should be refactored before testing.

## Quick Reference

### Commands (run from `web/`)

Use paths relative to `web/` (e.g., `app/components/...`).
Use `refactor-component` for refactoring prompts and `analyze-component` for testing prompts and metrics.

```bash
cd web

# Generate refactoring prompt
pnpm refactor-component <path>

# Output refactoring analysis as JSON
pnpm refactor-component <path> --json

# Generate testing prompt (after refactoring)
pnpm analyze-component <path>

# Output testing analysis as JSON
pnpm analyze-component <path> --json
```

### Complexity Analysis

```bash
# Analyze component complexity
pnpm analyze-component <path> --json

# Key metrics to check:
# - complexity: normalized score 0-100 (target < 50)
# - maxComplexity: highest single function complexity
# - lineCount: total lines (target < 300)
```

### Complexity Score Interpretation

| Score | Level | Action |
|-------|-------|--------|
| 0-25 | 🟢 Simple | Ready for testing |
| 26-50 | 🟡 Medium | Consider minor refactoring |
| 51-75 | 🟠 Complex | **Refactor before testing** |
| 76-100 | 🔴 Very Complex | **Must refactor** |
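
The score bands above can be applied mechanically. A minimal sketch, assuming the JSON printed by `pnpm analyze-component --json` exposes the fields listed under "Key metrics" (the exact output shape is an assumption):

```typescript
// Sketch only: maps the score bands above to the recommended action.
type AnalyzeResult = {
  complexity: number
  maxComplexity: number
  lineCount: number
}

const recommendAction = ({ complexity, lineCount }: AnalyzeResult): string => {
  if (complexity > 75) return 'Must refactor'
  if (complexity > 50 || lineCount > 300) return 'Refactor before testing'
  if (complexity > 25) return 'Consider minor refactoring'
  return 'Ready for testing'
}
```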

## Core Refactoring Patterns

### Pattern 1: Extract Custom Hooks

**When**: Component has complex state management, multiple `useState`/`useEffect`, or business logic mixed with UI.

**Dify Convention**: Place hooks in a `hooks/` subdirectory or alongside the component as `use-<feature>.ts`.

```typescript
// ❌ Before: Complex state logic in component
const Configuration: FC = () => {
  const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
  const [datasetConfigs, setDatasetConfigs] = useState<DatasetConfigs>(...)
  const [completionParams, setCompletionParams] = useState<FormValue>({})

  // 50+ lines of state management logic...

  return <div>...</div>
}

// ✅ After: Extract to custom hook
// hooks/use-model-config.ts
export const useModelConfig = (appId: string) => {
  const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
  const [completionParams, setCompletionParams] = useState<FormValue>({})

  // Related state management logic here

  return { modelConfig, setModelConfig, completionParams, setCompletionParams }
}

// Component becomes cleaner
const Configuration: FC = () => {
  const { modelConfig, setModelConfig } = useModelConfig(appId)
  return <div>...</div>
}
```

**Dify Examples**:
- `web/app/components/app/configuration/hooks/use-advanced-prompt-config.ts`
- `web/app/components/app/configuration/debug/hooks.tsx`
- `web/app/components/workflow/hooks/use-workflow.ts`

### Pattern 2: Extract Sub-Components

**When**: Single component has multiple UI sections, conditional rendering blocks, or repeated patterns.

**Dify Convention**: Place sub-components in subdirectories or as separate files in the same directory.

```typescript
// ❌ Before: Monolithic JSX with multiple sections
const AppInfo = () => {
  return (
    <div>
      {/* 100 lines of header UI */}
      {/* 100 lines of operations UI */}
      {/* 100 lines of modals */}
    </div>
  )
}

// ✅ After: Split into focused components
// app-info/
// ├── index.tsx (orchestration only)
// ├── app-header.tsx (header UI)
// ├── app-operations.tsx (operations UI)
// └── app-modals.tsx (modal management)

const AppInfo = () => {
  const { showModal, setShowModal } = useAppInfoModals()

  return (
    <div>
      <AppHeader appDetail={appDetail} />
      <AppOperations onAction={handleAction} />
      <AppModals show={showModal} onClose={() => setShowModal(null)} />
    </div>
  )
}
```

**Dify Examples**:
- `web/app/components/app/configuration/` directory structure
- `web/app/components/workflow/nodes/` per-node organization

### Pattern 3: Simplify Conditional Logic

**When**: Deep nesting (> 3 levels), complex ternaries, or multiple `if/else` chains.

```typescript
// ❌ Before: Deeply nested conditionals
const Template = useMemo(() => {
  if (appDetail?.mode === AppModeEnum.CHAT) {
    switch (locale) {
      case LanguagesSupported[1]:
        return <TemplateChatZh />
      case LanguagesSupported[7]:
        return <TemplateChatJa />
      default:
        return <TemplateChatEn />
    }
  }
  if (appDetail?.mode === AppModeEnum.ADVANCED_CHAT) {
    // Another 15 lines...
  }
  // More conditions...
}, [appDetail, locale])

// ✅ After: Use lookup tables + early returns
const TEMPLATE_MAP = {
  [AppModeEnum.CHAT]: {
    [LanguagesSupported[1]]: TemplateChatZh,
    [LanguagesSupported[7]]: TemplateChatJa,
    default: TemplateChatEn,
  },
  [AppModeEnum.ADVANCED_CHAT]: {
    [LanguagesSupported[1]]: TemplateAdvancedChatZh,
    // ...
  },
}

const Template = useMemo(() => {
  if (!appDetail?.mode) return null

  const modeTemplates = TEMPLATE_MAP[appDetail.mode]
  if (!modeTemplates) return null

  const TemplateComponent = modeTemplates[locale] || modeTemplates.default
  return <TemplateComponent appDetail={appDetail} />
}, [appDetail, locale])
```

### Pattern 4: Extract API/Data Logic

**When**: Component directly handles API calls, data transformation, or complex async operations.

**Dify Convention**: Use `@tanstack/react-query` hooks from `web/service/use-*.ts` or create custom data hooks.

```typescript
// ❌ Before: API logic in component
const MCPServiceCard = () => {
  const [basicAppConfig, setBasicAppConfig] = useState({})

  useEffect(() => {
    if (isBasicApp && appId) {
      (async () => {
        const res = await fetchAppDetail({ url: '/apps', id: appId })
        setBasicAppConfig(res?.model_config || {})
      })()
    }
  }, [appId, isBasicApp])

  // More API-related logic...
}

// ✅ After: Extract to data hook using React Query
// use-app-config.ts
import { useQuery } from '@tanstack/react-query'
import { get } from '@/service/base'

const NAME_SPACE = 'appConfig'

export const useAppConfig = (appId: string, isBasicApp: boolean) => {
  return useQuery({
    enabled: isBasicApp && !!appId,
    queryKey: [NAME_SPACE, 'detail', appId],
    queryFn: () => get<AppDetailResponse>(`/apps/${appId}`),
    select: data => data?.model_config || {},
  })
}

// Component becomes cleaner
const MCPServiceCard = () => {
  const { data: config, isLoading } = useAppConfig(appId, isBasicApp)
  // UI only
}
```

**React Query Best Practices in Dify**:
- Define `NAME_SPACE` for query key organization
- Use `enabled` option for conditional fetching
- Use `select` for data transformation
- Export invalidation hooks: `useInvalidXxx` (see the sketch below)
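
A minimal sketch of the invalidation-hook convention, assuming a `useInvalid` helper along the lines of the one the hook extraction document later imports from `@/service/use-base` (the actual implementation may differ):

```typescript
import { useCallback } from 'react'
import { useQueryClient } from '@tanstack/react-query'

// Assumed shape of the shared helper in web/service/use-base.ts.
export const useInvalid = (queryKey: readonly unknown[]) => {
  const queryClient = useQueryClient()
  return useCallback(
    () => queryClient.invalidateQueries({ queryKey }),
    [queryClient, queryKey],
  )
}

// Per-domain export, following the useInvalidXxx naming convention.
const NAME_SPACE = 'appConfig'
export const useInvalidAppConfig = () => useInvalid([NAME_SPACE])
```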

**Dify Examples**:
- `web/service/use-workflow.ts`
- `web/service/use-common.ts`
- `web/service/knowledge/use-dataset.ts`
- `web/service/knowledge/use-document.ts`

### Pattern 5: Extract Modal/Dialog Management

**When**: Component manages multiple modals with complex open/close states.

**Dify Convention**: Modals should be extracted with their state management.

```typescript
// ❌ Before: Multiple modal states in component
const AppInfo = () => {
  const [showEditModal, setShowEditModal] = useState(false)
  const [showDuplicateModal, setShowDuplicateModal] = useState(false)
  const [showConfirmDelete, setShowConfirmDelete] = useState(false)
  const [showSwitchModal, setShowSwitchModal] = useState(false)
  const [showImportDSLModal, setShowImportDSLModal] = useState(false)
  // 5+ more modal states...
}

// ✅ After: Extract to modal management hook
type ModalType = 'edit' | 'duplicate' | 'delete' | 'switch' | 'import' | null

const useAppInfoModals = () => {
  const [activeModal, setActiveModal] = useState<ModalType>(null)

  const openModal = useCallback((type: ModalType) => setActiveModal(type), [])
  const closeModal = useCallback(() => setActiveModal(null), [])

  return {
    activeModal,
    openModal,
    closeModal,
    isOpen: (type: ModalType) => activeModal === type,
  }
}
```
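
A short usage sketch of the hook above (the `Button` and modal component names are placeholders, not actual Dify components):

```typescript
const AppInfo = () => {
  const { openModal, closeModal, isOpen } = useAppInfoModals()

  return (
    <>
      <Button onClick={() => openModal('edit')}>Edit</Button>
      {/* Each modal renders only when its type is the single active one */}
      {isOpen('edit') && <EditAppModal onClose={closeModal} />}
      {isOpen('duplicate') && <DuplicateAppModal onClose={closeModal} />}
    </>
  )
}
```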

### Pattern 6: Extract Form Logic

**When**: Complex form validation, submission handling, or field transformation.

**Dify Convention**: Use `@tanstack/react-form` patterns from `web/app/components/base/form/`.

```typescript
// ✅ Use existing form infrastructure
import { useAppForm } from '@/app/components/base/form'

const ConfigForm = () => {
  const form = useAppForm({
    defaultValues: { name: '', description: '' },
    onSubmit: handleSubmit,
  })

  return <form.Provider>...</form.Provider>
}
```

## Dify-Specific Refactoring Guidelines

### 1. Context Provider Extraction

**When**: Component provides complex context values with multiple states.

```typescript
// ❌ Before: Large context value object
const value = {
  appId, isAPIKeySet, isTrailFinished, mode, modelModeType,
  promptMode, isAdvancedMode, isAgent, isOpenAI, isFunctionCall,
  // 50+ more properties...
}
return <ConfigContext.Provider value={value}>...</ConfigContext.Provider>

// ✅ After: Split into domain-specific contexts
<ModelConfigProvider value={modelConfigValue}>
  <DatasetConfigProvider value={datasetConfigValue}>
    <UIConfigProvider value={uiConfigValue}>
      {children}
    </UIConfigProvider>
  </DatasetConfigProvider>
</ModelConfigProvider>
```

**Dify Reference**: `web/context/` directory structure

### 2. Workflow Node Components

**When**: Refactoring workflow node components (`web/app/components/workflow/nodes/`).

**Conventions**:
- Keep node logic in `use-interactions.ts`
- Extract panel UI to separate files
- Use `_base` components for common patterns

```
nodes/<node-type>/
├── index.tsx             # Node registration
├── node.tsx              # Node visual component
├── panel.tsx             # Configuration panel
├── use-interactions.ts   # Node-specific hooks
└── types.ts              # Type definitions
```

### 3. Configuration Components

**When**: Refactoring app configuration components.

**Conventions**:
- Separate config sections into subdirectories
- Use existing patterns from `web/app/components/app/configuration/`
- Keep feature toggles in dedicated components

### 4. Tool/Plugin Components

**When**: Refactoring tool-related components (`web/app/components/tools/`).

**Conventions**:
- Follow existing modal patterns
- Use service hooks from `web/service/use-tools.ts`
- Keep provider-specific logic isolated

## Refactoring Workflow

### Step 1: Generate Refactoring Prompt

```bash
pnpm refactor-component <path>
```

This command will:
- Analyze component complexity and features
- Identify specific refactoring actions needed
- Generate a prompt for an AI assistant (auto-copied to clipboard on macOS)
- Provide detailed requirements based on detected patterns

### Step 2: Analyze Details

```bash
pnpm analyze-component <path> --json
```

Identify:
- Total complexity score
- Max function complexity
- Line count
- Features detected (state, effects, API, etc.)

### Step 3: Plan

Create a refactoring plan based on detected features:

| Detected Feature | Refactoring Action |
|------------------|-------------------|
| `hasState: true` + `hasEffects: true` | Extract custom hook |
| `hasAPI: true` | Extract data/service hook |
| `hasEvents: true` (many) | Extract event handlers |
| `lineCount > 300` | Split into sub-components |
| `maxComplexity > 50` | Simplify conditional logic |

### Step 4: Execute Incrementally

1. **Extract one piece at a time**
2. **Run lint, type-check, and tests after each extraction**
3. **Verify functionality before next step**

```
For each extraction:
┌────────────────────────────────────────┐
│ 1. Extract code                        │
│ 2. Run: pnpm lint:fix                  │
│ 3. Run: pnpm type-check:tsgo           │
│ 4. Run: pnpm test                      │
│ 5. Test functionality manually         │
│ 6. PASS? → Next extraction             │
│    FAIL? → Fix before continuing       │
└────────────────────────────────────────┘
```

### Step 5: Verify

After refactoring:

```bash
# Re-run refactor command to verify improvements
pnpm refactor-component <path>

# If complexity < 25 and lines < 200, you'll see:
# ✅ COMPONENT IS WELL-STRUCTURED

# For detailed metrics:
pnpm analyze-component <path> --json

# Target metrics:
# - complexity < 50
# - lineCount < 300
# - maxComplexity < 30
```

## Common Mistakes to Avoid

### ❌ Over-Engineering

```typescript
// ❌ Too many tiny hooks
const useButtonText = () => useState('Click')
const useButtonDisabled = () => useState(false)
const useButtonLoading = () => useState(false)

// ✅ Cohesive hook with related state
const useButtonState = () => {
  const [text, setText] = useState('Click')
  const [disabled, setDisabled] = useState(false)
  const [loading, setLoading] = useState(false)
  return { text, setText, disabled, setDisabled, loading, setLoading }
}
```

### ❌ Breaking Existing Patterns

- Follow existing directory structures
- Maintain naming conventions
- Preserve export patterns for compatibility

### ❌ Premature Abstraction

- Only extract when there's clear complexity benefit
- Don't create abstractions for single-use code
- Keep refactored code in the same domain area

## References

### Dify Codebase Examples

- **Hook extraction**: `web/app/components/app/configuration/hooks/`
- **Component splitting**: `web/app/components/app/configuration/`
- **Service hooks**: `web/service/use-*.ts`
- **Workflow patterns**: `web/app/components/workflow/hooks/`
- **Form patterns**: `web/app/components/base/form/`

### Related Skills

- `frontend-testing` - For testing refactored components
- `web/testing/testing.md` - Testing specification

@@ -0,0 +1,493 @@

# Complexity Reduction Patterns

This document provides patterns for reducing cognitive complexity in Dify React components.

## Understanding Complexity

### SonarJS Cognitive Complexity

The `pnpm analyze-component` tool uses SonarJS cognitive complexity metrics:

- **Total Complexity**: Sum of all functions' complexity in the file
- **Max Complexity**: Highest single function complexity

### What Increases Complexity

| Pattern | Complexity Impact |
|---------|-------------------|
| `if/else` | +1 per branch |
| Nested conditions | +1 per nesting level |
| `switch/case` | +1 per switch statement |
| `for/while/do` | +1 per loop |
| `&&`/`\|\|` chains | +1 per operator sequence |
| Nested callbacks | +1 per nesting level |
| `try/catch` | +1 per catch |
| Ternary expressions | +1 per nesting |
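
A rough illustration of how these increments accumulate (the scores are approximate and the `Item` shape is made up for the example):

```typescript
type Item = { active: boolean; visible: boolean; name: string }

const labelActive = (items: Item[]): string[] => {
  const out: string[] = []
  for (const item of items) {          // +1 (loop)
    if (item.active && item.visible)   // +1 (if) +1 (nested one level) +1 (&& sequence)
      out.push(item.name)
  }
  return out                           // cognitive complexity ≈ 4
}
```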

## Pattern 1: Replace Conditionals with Lookup Tables

**Before** (complexity: ~15):

```typescript
const Template = useMemo(() => {
  if (appDetail?.mode === AppModeEnum.CHAT) {
    switch (locale) {
      case LanguagesSupported[1]:
        return <TemplateChatZh appDetail={appDetail} />
      case LanguagesSupported[7]:
        return <TemplateChatJa appDetail={appDetail} />
      default:
        return <TemplateChatEn appDetail={appDetail} />
    }
  }
  if (appDetail?.mode === AppModeEnum.ADVANCED_CHAT) {
    switch (locale) {
      case LanguagesSupported[1]:
        return <TemplateAdvancedChatZh appDetail={appDetail} />
      case LanguagesSupported[7]:
        return <TemplateAdvancedChatJa appDetail={appDetail} />
      default:
        return <TemplateAdvancedChatEn appDetail={appDetail} />
    }
  }
  if (appDetail?.mode === AppModeEnum.WORKFLOW) {
    // Similar pattern...
  }
  return null
}, [appDetail, locale])
```

**After** (complexity: ~3):

```typescript
// Define lookup table outside component
const TEMPLATE_MAP: Record<AppModeEnum, Record<string, FC<TemplateProps>>> = {
  [AppModeEnum.CHAT]: {
    [LanguagesSupported[1]]: TemplateChatZh,
    [LanguagesSupported[7]]: TemplateChatJa,
    default: TemplateChatEn,
  },
  [AppModeEnum.ADVANCED_CHAT]: {
    [LanguagesSupported[1]]: TemplateAdvancedChatZh,
    [LanguagesSupported[7]]: TemplateAdvancedChatJa,
    default: TemplateAdvancedChatEn,
  },
  [AppModeEnum.WORKFLOW]: {
    [LanguagesSupported[1]]: TemplateWorkflowZh,
    [LanguagesSupported[7]]: TemplateWorkflowJa,
    default: TemplateWorkflowEn,
  },
  // ...
}

// Clean component logic
const Template = useMemo(() => {
  if (!appDetail?.mode) return null

  const templates = TEMPLATE_MAP[appDetail.mode]
  if (!templates) return null

  const TemplateComponent = templates[locale] ?? templates.default
  return <TemplateComponent appDetail={appDetail} />
}, [appDetail, locale])
```

## Pattern 2: Use Early Returns

**Before** (complexity: ~10):

```typescript
const handleSubmit = () => {
  if (isValid) {
    if (hasChanges) {
      if (isConnected) {
        submitData()
      } else {
        showConnectionError()
      }
    } else {
      showNoChangesMessage()
    }
  } else {
    showValidationError()
  }
}
```

**After** (complexity: ~4):

```typescript
const handleSubmit = () => {
  if (!isValid) {
    showValidationError()
    return
  }

  if (!hasChanges) {
    showNoChangesMessage()
    return
  }

  if (!isConnected) {
    showConnectionError()
    return
  }

  submitData()
}
```

## Pattern 3: Extract Complex Conditions

**Before** (complexity: high):

```typescript
const canPublish = (() => {
  if (mode !== AppModeEnum.COMPLETION) {
    if (!isAdvancedMode)
      return true

    if (modelModeType === ModelModeType.completion) {
      if (!hasSetBlockStatus.history || !hasSetBlockStatus.query)
        return false
      return true
    }
    return true
  }
  return !promptEmpty
})()
```

**After** (complexity: lower):

```typescript
// Extract to named functions
const canPublishInCompletionMode = () => !promptEmpty

const canPublishInChatMode = () => {
  if (!isAdvancedMode) return true
  if (modelModeType !== ModelModeType.completion) return true
  return hasSetBlockStatus.history && hasSetBlockStatus.query
}

// Clean main logic
const canPublish = mode === AppModeEnum.COMPLETION
  ? canPublishInCompletionMode()
  : canPublishInChatMode()
```

## Pattern 4: Replace Chained Ternaries

**Before** (complexity: ~5):

```typescript
const statusText = serverActivated
  ? t('status.running')
  : serverPublished
    ? t('status.inactive')
    : appUnpublished
      ? t('status.unpublished')
      : t('status.notConfigured')
```

**After** (complexity: ~2):

```typescript
const getStatusText = () => {
  if (serverActivated) return t('status.running')
  if (serverPublished) return t('status.inactive')
  if (appUnpublished) return t('status.unpublished')
  return t('status.notConfigured')
}

const statusText = getStatusText()
```

Or use lookup:

```typescript
const STATUS_TEXT_MAP = {
  running: 'status.running',
  inactive: 'status.inactive',
  unpublished: 'status.unpublished',
  notConfigured: 'status.notConfigured',
} as const

const getStatusKey = (): keyof typeof STATUS_TEXT_MAP => {
  if (serverActivated) return 'running'
  if (serverPublished) return 'inactive'
  if (appUnpublished) return 'unpublished'
  return 'notConfigured'
}

const statusText = t(STATUS_TEXT_MAP[getStatusKey()])
```

## Pattern 5: Flatten Nested Loops

**Before** (complexity: high):

```typescript
const processData = (items: Item[]) => {
  const results: ProcessedItem[] = []

  for (const item of items) {
    if (item.isValid) {
      for (const child of item.children) {
        if (child.isActive) {
          for (const prop of child.properties) {
            if (prop.value !== null) {
              results.push({
                itemId: item.id,
                childId: child.id,
                propValue: prop.value,
              })
            }
          }
        }
      }
    }
  }

  return results
}
```

**After** (complexity: lower):

```typescript
// Use functional approach
const processData = (items: Item[]) => {
  return items
    .filter(item => item.isValid)
    .flatMap(item =>
      item.children
        .filter(child => child.isActive)
        .flatMap(child =>
          child.properties
            .filter(prop => prop.value !== null)
            .map(prop => ({
              itemId: item.id,
              childId: child.id,
              propValue: prop.value,
            }))
        )
    )
}
```

## Pattern 6: Extract Event Handler Logic

**Before** (complexity: high in component):

```typescript
const Component = () => {
  const handleSelect = (data: DataSet[]) => {
    if (isEqual(data.map(item => item.id), dataSets.map(item => item.id))) {
      hideSelectDataSet()
      return
    }

    formattingChangedDispatcher()
    let newDatasets = data
    if (data.find(item => !item.name)) {
      const newSelected = produce(data, (draft) => {
        data.forEach((item, index) => {
          if (!item.name) {
            const newItem = dataSets.find(i => i.id === item.id)
            if (newItem)
              draft[index] = newItem
          }
        })
      })
      setDataSets(newSelected)
      newDatasets = newSelected
    }
    else {
      setDataSets(data)
    }
    hideSelectDataSet()

    // 40 more lines of logic...
  }

  return <div>...</div>
}
```

**After** (complexity: lower):

```typescript
// Extract to hook or utility
const useDatasetSelection = (dataSets: DataSet[]) => {
  const normalizeSelection = (data: DataSet[]) => {
    const hasUnloadedItem = data.some(item => !item.name)
    if (!hasUnloadedItem) return data

    return produce(data, (draft) => {
      data.forEach((item, index) => {
        if (!item.name) {
          const existing = dataSets.find(i => i.id === item.id)
          if (existing) draft[index] = existing
        }
      })
    })
  }

  const hasSelectionChanged = (newData: DataSet[]) => {
    return !isEqual(
      newData.map(item => item.id),
      dataSets.map(item => item.id)
    )
  }

  return { normalizeSelection, hasSelectionChanged }
}

// Component becomes cleaner
const Component = () => {
  const { normalizeSelection, hasSelectionChanged } = useDatasetSelection(dataSets)

  const handleSelect = (data: DataSet[]) => {
    if (!hasSelectionChanged(data)) {
      hideSelectDataSet()
      return
    }

    formattingChangedDispatcher()
    const normalized = normalizeSelection(data)
    setDataSets(normalized)
    hideSelectDataSet()
  }

  return <div>...</div>
}
```

## Pattern 7: Reduce Boolean Logic Complexity

**Before** (complexity: ~8):

```typescript
const toggleDisabled = hasInsufficientPermissions
  || appUnpublished
  || missingStartNode
  || triggerModeDisabled
  || (isAdvancedApp && !currentWorkflow?.graph)
  || (isBasicApp && !basicAppConfig.updated_at)
```

**After** (complexity: ~3):

```typescript
// Extract meaningful boolean functions
const isAppReady = () => {
  if (isAdvancedApp) return !!currentWorkflow?.graph
  return !!basicAppConfig.updated_at
}

const hasRequiredPermissions = () => {
  return isCurrentWorkspaceEditor && !hasInsufficientPermissions
}

const canToggle = () => {
  if (!hasRequiredPermissions()) return false
  if (!isAppReady()) return false
  if (missingStartNode) return false
  if (triggerModeDisabled) return false
  return true
}

const toggleDisabled = !canToggle()
```

## Pattern 8: Simplify useMemo/useCallback Dependencies

**Before** (complexity: multiple recalculations):

```typescript
const payload = useMemo(() => {
  let parameters: Parameter[] = []
  let outputParameters: OutputParameter[] = []

  if (!published) {
    parameters = (inputs || []).map(item => ({
      name: item.variable,
      description: '',
      form: 'llm',
      required: item.required,
      type: item.type,
    }))
    outputParameters = (outputs || []).map(item => ({
      name: item.variable,
      description: '',
      type: item.value_type,
    }))
  }
  else if (detail && detail.tool) {
    parameters = (inputs || []).map(item => ({
      // Complex transformation...
    }))
    outputParameters = (outputs || []).map(item => ({
      // Complex transformation...
    }))
  }

  return {
    icon: detail?.icon || icon,
    label: detail?.label || name,
    // ...more fields
  }
}, [detail, published, workflowAppId, icon, name, description, inputs, outputs])
```

**After** (complexity: separated concerns):

```typescript
// Separate transformations
const useParameterTransform = (inputs: InputVar[], detail?: ToolDetail, published?: boolean) => {
  return useMemo(() => {
    if (!published) {
      return inputs.map(item => ({
        name: item.variable,
        description: '',
        form: 'llm',
        required: item.required,
        type: item.type,
      }))
    }

    if (!detail?.tool) return []

    return inputs.map(item => ({
      name: item.variable,
      required: item.required,
      type: item.type === 'paragraph' ? 'string' : item.type,
      description: detail.tool.parameters.find(p => p.name === item.variable)?.llm_description || '',
      form: detail.tool.parameters.find(p => p.name === item.variable)?.form || 'llm',
    }))
  }, [inputs, detail, published])
}

// Component uses hook
const parameters = useParameterTransform(inputs, detail, published)
const outputParameters = useOutputTransform(outputs, detail, published)

const payload = useMemo(() => ({
  icon: detail?.icon || icon,
  label: detail?.label || name,
  parameters,
  outputParameters,
  // ...
}), [detail, icon, name, parameters, outputParameters])
```

## Target Metrics After Refactoring

| Metric | Target |
|--------|--------|
| Total Complexity | < 50 |
| Max Function Complexity | < 30 |
| Function Length | < 30 lines |
| Nesting Depth | ≤ 3 levels |
| Conditional Chains | ≤ 3 conditions |

@@ -0,0 +1,477 @@

# Component Splitting Patterns

This document provides detailed guidance on splitting large components into smaller, focused components in Dify.

## When to Split Components

Split a component when you identify:

1. **Multiple UI sections** - Distinct visual areas with minimal coupling that can be composed independently
2. **Conditional rendering blocks** - Large `{condition && <JSX />}` blocks
3. **Repeated patterns** - Similar UI structures used multiple times
4. **300+ lines** - Component exceeds manageable size
5. **Modal clusters** - Multiple modals rendered in one component

## Splitting Strategies

### Strategy 1: Section-Based Splitting

Identify visual sections and extract each as a component.

```typescript
// ❌ Before: Monolithic component (500+ lines)
const ConfigurationPage = () => {
  return (
    <div>
      {/* Header Section - 50 lines */}
      <div className="header">
        <h1>{t('configuration.title')}</h1>
        <div className="actions">
          {isAdvancedMode && <Badge>Advanced</Badge>}
          <ModelParameterModal ... />
          <AppPublisher ... />
        </div>
      </div>

      {/* Config Section - 200 lines */}
      <div className="config">
        <Config />
      </div>

      {/* Debug Section - 150 lines */}
      <div className="debug">
        <Debug ... />
      </div>

      {/* Modals Section - 100 lines */}
      {showSelectDataSet && <SelectDataSet ... />}
      {showHistoryModal && <EditHistoryModal ... />}
      {showUseGPT4Confirm && <Confirm ... />}
    </div>
  )
}

// ✅ After: Split into focused components
// configuration/
// ├── index.tsx (orchestration only)
// ├── configuration-header.tsx
// ├── configuration-content.tsx
// ├── configuration-debug.tsx
// └── configuration-modals.tsx

// configuration-header.tsx
interface ConfigurationHeaderProps {
  isAdvancedMode: boolean
  onPublish: () => void
}

const ConfigurationHeader: FC<ConfigurationHeaderProps> = ({
  isAdvancedMode,
  onPublish,
}) => {
  const { t } = useTranslation()

  return (
    <div className="header">
      <h1>{t('configuration.title')}</h1>
      <div className="actions">
        {isAdvancedMode && <Badge>Advanced</Badge>}
        <ModelParameterModal ... />
        <AppPublisher onPublish={onPublish} />
      </div>
    </div>
  )
}

// index.tsx (orchestration only)
const ConfigurationPage = () => {
  const { modelConfig, setModelConfig } = useModelConfig()
  const { activeModal, openModal, closeModal } = useModalState()

  return (
    <div>
      <ConfigurationHeader
        isAdvancedMode={isAdvancedMode}
        onPublish={handlePublish}
      />
      <ConfigurationContent
        modelConfig={modelConfig}
        onConfigChange={setModelConfig}
      />
      {!isMobile && (
        <ConfigurationDebug
          inputs={inputs}
          onSetting={handleSetting}
        />
      )}
      <ConfigurationModals
        activeModal={activeModal}
        onClose={closeModal}
      />
    </div>
  )
}
```

### Strategy 2: Conditional Block Extraction

Extract large conditional rendering blocks.

```typescript
// ❌ Before: Large conditional blocks
const AppInfo = () => {
  return (
    <div>
      {expand ? (
        <div className="expanded">
          {/* 100 lines of expanded view */}
        </div>
      ) : (
        <div className="collapsed">
          {/* 50 lines of collapsed view */}
        </div>
      )}
    </div>
  )
}

// ✅ After: Separate view components
const AppInfoExpanded: FC<AppInfoViewProps> = ({ appDetail, onAction }) => {
  return (
    <div className="expanded">
      {/* Clean, focused expanded view */}
    </div>
  )
}

const AppInfoCollapsed: FC<AppInfoViewProps> = ({ appDetail, onAction }) => {
  return (
    <div className="collapsed">
      {/* Clean, focused collapsed view */}
    </div>
  )
}

const AppInfo = () => {
  return (
    <div>
      {expand
        ? <AppInfoExpanded appDetail={appDetail} onAction={handleAction} />
        : <AppInfoCollapsed appDetail={appDetail} onAction={handleAction} />
      }
    </div>
  )
}
```

### Strategy 3: Modal Extraction

Extract modals with their trigger logic.

```typescript
// ❌ Before: Multiple modals in one component
const AppInfo = () => {
  const [showEdit, setShowEdit] = useState(false)
  const [showDuplicate, setShowDuplicate] = useState(false)
  const [showDelete, setShowDelete] = useState(false)
  const [showSwitch, setShowSwitch] = useState(false)

  const onEdit = async (data) => { /* 20 lines */ }
  const onDuplicate = async (data) => { /* 20 lines */ }
  const onDelete = async () => { /* 15 lines */ }

  return (
    <div>
      {/* Main content */}

      {showEdit && <EditModal onConfirm={onEdit} onClose={() => setShowEdit(false)} />}
      {showDuplicate && <DuplicateModal onConfirm={onDuplicate} onClose={() => setShowDuplicate(false)} />}
      {showDelete && <DeleteConfirm onConfirm={onDelete} onClose={() => setShowDelete(false)} />}
      {showSwitch && <SwitchModal ... />}
    </div>
  )
}

// ✅ After: Modal manager component
// app-info-modals.tsx
type ModalType = 'edit' | 'duplicate' | 'delete' | 'switch' | null

interface AppInfoModalsProps {
  appDetail: AppDetail
  activeModal: ModalType
  onClose: () => void
  onSuccess: () => void
}

const AppInfoModals: FC<AppInfoModalsProps> = ({
  appDetail,
  activeModal,
  onClose,
  onSuccess,
}) => {
  const handleEdit = async (data) => { /* logic */ }
  const handleDuplicate = async (data) => { /* logic */ }
  const handleDelete = async () => { /* logic */ }

  return (
    <>
      {activeModal === 'edit' && (
        <EditModal
          appDetail={appDetail}
          onConfirm={handleEdit}
          onClose={onClose}
        />
      )}
      {activeModal === 'duplicate' && (
        <DuplicateModal
          appDetail={appDetail}
          onConfirm={handleDuplicate}
          onClose={onClose}
        />
      )}
      {activeModal === 'delete' && (
        <DeleteConfirm
          onConfirm={handleDelete}
          onClose={onClose}
        />
      )}
      {activeModal === 'switch' && (
        <SwitchModal
          appDetail={appDetail}
          onClose={onClose}
        />
      )}
    </>
  )
}

// Parent component
const AppInfo = () => {
  const { activeModal, openModal, closeModal } = useModalState()

  return (
    <div>
      {/* Main content with openModal triggers */}
      <Button onClick={() => openModal('edit')}>Edit</Button>

      <AppInfoModals
        appDetail={appDetail}
        activeModal={activeModal}
        onClose={closeModal}
        onSuccess={handleSuccess}
      />
    </div>
  )
}
```

### Strategy 4: List Item Extraction

Extract repeated item rendering.

```typescript
// ❌ Before: Inline item rendering
const OperationsList = () => {
  return (
    <div>
      {operations.map(op => (
        <div key={op.id} className="operation-item">
          <span className="icon">{op.icon}</span>
          <span className="title">{op.title}</span>
          <span className="description">{op.description}</span>
          <button onClick={() => op.onClick()}>
            {op.actionLabel}
          </button>
          {op.badge && <Badge>{op.badge}</Badge>}
          {/* More complex rendering... */}
        </div>
      ))}
    </div>
  )
}

// ✅ After: Extracted item component
interface OperationItemProps {
  operation: Operation
  onAction: (id: string) => void
}

const OperationItem: FC<OperationItemProps> = ({ operation, onAction }) => {
  return (
    <div className="operation-item">
      <span className="icon">{operation.icon}</span>
      <span className="title">{operation.title}</span>
      <span className="description">{operation.description}</span>
      <button onClick={() => onAction(operation.id)}>
        {operation.actionLabel}
      </button>
      {operation.badge && <Badge>{operation.badge}</Badge>}
    </div>
  )
}

const OperationsList = () => {
  const handleAction = useCallback((id: string) => {
    const op = operations.find(o => o.id === id)
    op?.onClick()
  }, [operations])

  return (
    <div>
      {operations.map(op => (
        <OperationItem
          key={op.id}
          operation={op}
          onAction={handleAction}
        />
      ))}
    </div>
  )
}
```

## Directory Structure Patterns

### Pattern A: Flat Structure (Simple Components)

For components with 2-3 sub-components:

```
component-name/
├── index.tsx             # Main component
├── sub-component-a.tsx
├── sub-component-b.tsx
└── types.ts              # Shared types
```

### Pattern B: Nested Structure (Complex Components)

For components with many sub-components:

```
component-name/
├── index.tsx             # Main orchestration
├── types.ts              # Shared types
├── hooks/
│   ├── use-feature-a.ts
│   └── use-feature-b.ts
├── components/
│   ├── header/
│   │   └── index.tsx
│   ├── content/
│   │   └── index.tsx
│   └── modals/
│       └── index.tsx
└── utils/
    └── helpers.ts
```

### Pattern C: Feature-Based Structure (Dify Standard)

Following Dify's existing patterns:

```
configuration/
├── index.tsx             # Main page component
├── base/                 # Base/shared components
│   ├── feature-panel/
│   ├── group-name/
│   └── operation-btn/
├── config/               # Config section
│   ├── index.tsx
│   ├── agent/
│   └── automatic/
├── dataset-config/       # Dataset section
│   ├── index.tsx
│   ├── card-item/
│   └── params-config/
├── debug/                # Debug section
│   ├── index.tsx
│   └── hooks.tsx
└── hooks/                # Shared hooks
    └── use-advanced-prompt-config.ts
```

## Props Design

### Minimal Props Principle

Pass only what's needed:

```typescript
// ❌ Bad: Passing entire objects when only some fields needed
<ConfigHeader appDetail={appDetail} modelConfig={modelConfig} />

// ✅ Good: Destructure to minimum required
<ConfigHeader
  appName={appDetail.name}
  isAdvancedMode={modelConfig.isAdvanced}
  onPublish={handlePublish}
/>
```

### Callback Props Pattern

Use callbacks for child-to-parent communication:

```typescript
// Parent
const Parent = () => {
  const [value, setValue] = useState('')

  return (
    <Child
      value={value}
      onChange={setValue}
      onSubmit={handleSubmit}
    />
  )
}

// Child
interface ChildProps {
  value: string
  onChange: (value: string) => void
  onSubmit: () => void
}

const Child: FC<ChildProps> = ({ value, onChange, onSubmit }) => {
  return (
    <div>
      <input value={value} onChange={e => onChange(e.target.value)} />
      <button onClick={onSubmit}>Submit</button>
    </div>
  )
}
```

### Render Props for Flexibility

When sub-components need parent context:

```typescript
interface ListProps<T> {
  items: T[]
  renderItem: (item: T, index: number) => React.ReactNode
  renderEmpty?: () => React.ReactNode
}

function List<T>({ items, renderItem, renderEmpty }: ListProps<T>) {
  if (items.length === 0 && renderEmpty) {
    return <>{renderEmpty()}</>
  }

  return (
    <div>
      {items.map((item, index) => renderItem(item, index))}
    </div>
  )
}

// Usage
<List
  items={operations}
  renderItem={(op, i) => <OperationItem key={i} operation={op} />}
  renderEmpty={() => <EmptyState message="No operations" />}
/>
```

@@ -0,0 +1,317 @@

# Hook Extraction Patterns

This document provides detailed guidance on extracting custom hooks from complex components in Dify.

## When to Extract Hooks

Extract a custom hook when you identify:

1. **Coupled state groups** - Multiple `useState` hooks that are always used together
2. **Complex effects** - `useEffect` with multiple dependencies or cleanup logic
3. **Business logic** - Data transformations, validations, or calculations
4. **Reusable patterns** - Logic that appears in multiple components

## Extraction Process

### Step 1: Identify State Groups

Look for state variables that are logically related:

```typescript
// ❌ These belong together - extract to hook
const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
const [completionParams, setCompletionParams] = useState<FormValue>({})
const [modelModeType, setModelModeType] = useState<ModelModeType>(...)

// These are model-related state values that should live in useModelConfig()
```

### Step 2: Identify Related Effects

Find effects that modify the grouped state:

```typescript
// ❌ These effects belong with the state above
useEffect(() => {
  if (hasFetchedDetail && !modelModeType) {
    const mode = currModel?.model_properties.mode
    if (mode) {
      const newModelConfig = produce(modelConfig, (draft) => {
        draft.mode = mode
      })
      setModelConfig(newModelConfig)
    }
  }
}, [textGenerationModelList, hasFetchedDetail, modelModeType, currModel])
```

### Step 3: Create the Hook

```typescript
// hooks/use-model-config.ts
import type { FormValue } from '@/app/components/header/account-setting/model-provider-page/declarations'
import type { ModelConfig } from '@/models/debug'
import { produce } from 'immer'
import { useEffect, useState } from 'react'
import { ModelModeType } from '@/types/app'

interface UseModelConfigParams {
  initialConfig?: Partial<ModelConfig>
  currModel?: { model_properties?: { mode?: ModelModeType } }
  hasFetchedDetail: boolean
}

interface UseModelConfigReturn {
  modelConfig: ModelConfig
  setModelConfig: (config: ModelConfig) => void
  completionParams: FormValue
  setCompletionParams: (params: FormValue) => void
  modelModeType: ModelModeType
}

export const useModelConfig = ({
  initialConfig,
  currModel,
  hasFetchedDetail,
}: UseModelConfigParams): UseModelConfigReturn => {
  const [modelConfig, setModelConfig] = useState<ModelConfig>({
    provider: 'langgenius/openai/openai',
    model_id: 'gpt-3.5-turbo',
    mode: ModelModeType.unset,
    // ... default values
    ...initialConfig,
  })

  const [completionParams, setCompletionParams] = useState<FormValue>({})

  const modelModeType = modelConfig.mode

  // Fill in the model mode missing from old app data.
  // A functional update keeps modelConfig out of the dependency array.
  useEffect(() => {
    if (hasFetchedDetail && !modelModeType) {
      const mode = currModel?.model_properties?.mode
      if (mode) {
        setModelConfig(prev => produce(prev, (draft) => {
          draft.mode = mode
        }))
      }
    }
  }, [hasFetchedDetail, modelModeType, currModel])

  return {
    modelConfig,
    setModelConfig,
    completionParams,
    setCompletionParams,
    modelModeType,
  }
}
```

### Step 4: Update Component

```typescript
// Before: 50+ lines of state management
const Configuration: FC = () => {
  const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
  // ... lots of related state and effects
}

// After: Clean component
const Configuration: FC = () => {
  const {
    modelConfig,
    setModelConfig,
    completionParams,
    setCompletionParams,
    modelModeType,
  } = useModelConfig({
    currModel,
    hasFetchedDetail,
  })

  // Component now focuses on UI
}
```

## Naming Conventions

### Hook Names

- Use `use` prefix: `useModelConfig`, `useDatasetConfig`
- Be specific: `useAdvancedPromptConfig` not `usePrompt`
- Include domain: `useWorkflowVariables`, `useMCPServer`

### File Names

- Kebab-case: `use-model-config.ts`
- Place in `hooks/` subdirectory when multiple hooks exist
- Place alongside component for single-use hooks

### Return Type Names

- Suffix with `Return`: `UseModelConfigReturn`
- Suffix params with `Params`: `UseModelConfigParams`
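
Putting the three conventions together in one sketch (the hook, types, and import path here are illustrative assumptions, not actual Dify code):

```typescript
// file: hooks/use-dataset-config.ts
import { useState } from 'react'
import type { DatasetConfigs } from '@/models/datasets' // assumed import path

interface UseDatasetConfigParams {
  initialConfigs: DatasetConfigs
}

interface UseDatasetConfigReturn {
  datasetConfigs: DatasetConfigs
  setDatasetConfigs: (configs: DatasetConfigs) => void
}

export const useDatasetConfig = ({
  initialConfigs,
}: UseDatasetConfigParams): UseDatasetConfigReturn => {
  const [datasetConfigs, setDatasetConfigs] = useState(initialConfigs)
  return { datasetConfigs, setDatasetConfigs }
}
```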
|
||||||
|
|
||||||
|
## Common Hook Patterns in Dify

### 1. Data Fetching Hook (React Query)

```typescript
// Pattern: Use @tanstack/react-query for data fetching
import { useQuery } from '@tanstack/react-query'
import { get } from '@/service/base'
import { useInvalid } from '@/service/use-base'

const NAME_SPACE = 'appConfig'

// Query keys for cache management
export const appConfigQueryKeys = {
  detail: (appId: string) => [NAME_SPACE, 'detail', appId] as const,
}

// Main data hook
export const useAppConfig = (appId: string) => {
  return useQuery({
    enabled: !!appId,
    queryKey: appConfigQueryKeys.detail(appId),
    queryFn: () => get<AppDetailResponse>(`/apps/${appId}`),
    select: data => data?.model_config || null,
  })
}

// Invalidation hook for refreshing data
export const useInvalidAppConfig = () => {
  return useInvalid([NAME_SPACE])
}

// Usage in component
const Component = () => {
  const { data: config, isLoading, error, refetch } = useAppConfig(appId)
  const invalidAppConfig = useInvalidAppConfig()

  const handleRefresh = () => {
    invalidAppConfig() // Invalidates cache and triggers refetch
  }

  return <div>...</div>
}
```
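
In tests, a component that calls a React Query hook like this must be rendered inside a `QueryClientProvider`. A minimal wrapper sketch (the option defaults shown are illustrative):

```typescript
import type { ReactElement } from 'react'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import { render } from '@testing-library/react'

// Fresh client per test so cached data never leaks between cases;
// retries are disabled so error states surface immediately.
const createTestQueryClient = () => new QueryClient({
  defaultOptions: { queries: { retry: false } },
})

const renderWithQueryClient = (ui: ReactElement) =>
  render(
    <QueryClientProvider client={createTestQueryClient()}>{ui}</QueryClientProvider>,
  )
```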

### 2. Form State Hook

```typescript
// Pattern: Form state + validation + submission
export const useConfigForm = (initialValues: ConfigFormValues) => {
  const [values, setValues] = useState(initialValues)
  const [errors, setErrors] = useState<Record<string, string>>({})
  const [isSubmitting, setIsSubmitting] = useState(false)

  const validate = useCallback(() => {
    const newErrors: Record<string, string> = {}
    if (!values.name) newErrors.name = 'Name is required'
    setErrors(newErrors)
    return Object.keys(newErrors).length === 0
  }, [values])

  const handleChange = useCallback((field: string, value: any) => {
    setValues(prev => ({ ...prev, [field]: value }))
  }, [])

  const handleSubmit = useCallback(async (onSubmit: (values: ConfigFormValues) => Promise<void>) => {
    if (!validate()) return
    setIsSubmitting(true)
    try {
      await onSubmit(values)
    } finally {
      setIsSubmitting(false)
    }
  }, [values, validate])

  return { values, errors, isSubmitting, handleChange, handleSubmit }
}
```
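
Wired into a component, usage looks like this (a sketch; `ConfigFormValues` and `saveConfig` are placeholders for the surrounding app code):

```typescript
const ConfigForm: FC = () => {
  const { values, errors, isSubmitting, handleChange, handleSubmit } = useConfigForm({ name: '' })

  return (
    <form
      onSubmit={(e) => {
        e.preventDefault()
        handleSubmit(async v => { await saveConfig(v) }) // saveConfig is a placeholder
      }}
    >
      <input value={values.name} onChange={e => handleChange('name', e.target.value)} />
      {errors.name && <span>{errors.name}</span>}
      <button type="submit" disabled={isSubmitting}>Save</button>
    </form>
  )
}
```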

### 3. Modal State Hook

```typescript
// Pattern: Multiple modal management
type ModalType = 'edit' | 'delete' | 'duplicate' | null

export const useModalState = () => {
  const [activeModal, setActiveModal] = useState<ModalType>(null)
  const [modalData, setModalData] = useState<any>(null)

  const openModal = useCallback((type: ModalType, data?: any) => {
    setActiveModal(type)
    setModalData(data)
  }, [])

  const closeModal = useCallback(() => {
    setActiveModal(null)
    setModalData(null)
  }, [])

  return {
    activeModal,
    modalData,
    openModal,
    closeModal,
    isOpen: useCallback((type: ModalType) => activeModal === type, [activeModal]),
  }
}
```
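
Consumers then branch on `isOpen` instead of keeping one boolean per modal (a sketch; `Item` and `EditModal` are placeholder names):

```typescript
const Toolbar: FC<{ item: Item }> = ({ item }) => {
  const { openModal, closeModal, isOpen, modalData } = useModalState()

  return (
    <>
      <button onClick={() => openModal('edit', item)}>Edit</button>
      {/* Only the active modal renders; closeModal clears both type and data */}
      {isOpen('edit') && <EditModal data={modalData} onClose={closeModal} />}
    </>
  )
}
```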

### 4. Toggle/Boolean Hook

```typescript
// Pattern: Boolean state with convenience methods
export const useToggle = (initialValue = false) => {
  const [value, setValue] = useState(initialValue)

  const toggle = useCallback(() => setValue(v => !v), [])
  const setTrue = useCallback(() => setValue(true), [])
  const setFalse = useCallback(() => setValue(false), [])

  return [value, { toggle, setTrue, setFalse, set: setValue }] as const
}

// Usage
const [isExpanded, { toggle, setTrue: expand, setFalse: collapse }] = useToggle()
```

## Testing Extracted Hooks

After extraction, test hooks in isolation:

```typescript
// use-model-config.spec.ts
import { renderHook, act } from '@testing-library/react'
import { ModelModeType } from '@/types/app'
import { useModelConfig } from './use-model-config'

describe('useModelConfig', () => {
  it('should initialize with default values', () => {
    const { result } = renderHook(() => useModelConfig({
      hasFetchedDetail: false,
    }))

    expect(result.current.modelConfig.provider).toBe('langgenius/openai/openai')
    expect(result.current.modelModeType).toBe(ModelModeType.unset)
  })

  it('should update model config', () => {
    const { result } = renderHook(() => useModelConfig({
      hasFetchedDetail: true,
    }))

    act(() => {
      result.current.setModelConfig({
        ...result.current.modelConfig,
        model_id: 'gpt-4',
      })
    })

    expect(result.current.modelConfig.model_id).toBe('gpt-4')
  })
})
```
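
The mode-backfill effect deserves a case of its own; a sketch for the same spec file, assuming `ModelModeType.unset` is the falsy default so the effect fires:

```typescript
it('should backfill mode from currModel for old app data', () => {
  const { result } = renderHook(() => useModelConfig({
    hasFetchedDetail: true,
    currModel: { model_properties: { mode: ModelModeType.chat } },
  }))

  // The effect copies the mode from currModel into modelConfig
  expect(result.current.modelModeType).toBe(ModelModeType.chat)
})
```
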
@ -1,13 +1,13 @@
 ---
 name: frontend-testing
-description: Generate Jest + React Testing Library tests for Dify frontend components, hooks, and utilities. Triggers on testing, spec files, coverage, Jest, RTL, unit tests, integration tests, or write/review test requests.
+description: Generate Vitest + React Testing Library tests for Dify frontend components, hooks, and utilities. Triggers on testing, spec files, coverage, Vitest, RTL, unit tests, integration tests, or write/review test requests.
 ---

 # Dify Frontend Testing Skill

 This skill enables Claude to generate high-quality, comprehensive frontend tests for the Dify project following established conventions and best practices.

-> **⚠️ Authoritative Source**: This skill is derived from `web/testing/testing.md`. When in doubt, always refer to that document as the canonical specification.
+> **⚠️ Authoritative Source**: This skill is derived from `web/testing/testing.md`. Use Vitest mock/timer APIs (`vi.*`).

 ## When to Apply This Skill

@ -15,7 +15,7 @@ Apply this skill when the user:
 - Asks to **write tests** for a component, hook, or utility
 - Asks to **review existing tests** for completeness
-- Mentions **Jest**, **React Testing Library**, **RTL**, or **spec files**
+- Mentions **Vitest**, **React Testing Library**, **RTL**, or **spec files**
 - Requests **test coverage** improvement
 - Uses `pnpm analyze-component` output as context
 - Mentions **testing**, **unit tests**, or **integration tests** for frontend code

@ -33,9 +33,9 @@ Apply this skill when the user:
 | Tool | Version | Purpose |
 |------|---------|---------|
-| Jest | 29.7 | Test runner |
+| Vitest | 4.0.16 | Test runner |
 | React Testing Library | 16.0 | Component testing |
-| happy-dom | - | Test environment |
+| jsdom | - | Test environment |
 | nock | 14.0 | HTTP mocking |
 | TypeScript | 5.x | Type safety |

@ -46,13 +46,13 @@ Apply this skill when the user:
 pnpm test

 # Watch mode
-pnpm test -- --watch
+pnpm test:watch

 # Run specific file
-pnpm test -- path/to/file.spec.tsx
+pnpm test path/to/file.spec.tsx

 # Generate coverage report
-pnpm test -- --coverage
+pnpm test:coverage

 # Analyze component complexity
 pnpm analyze-component <path>

@ -77,9 +77,9 @@ import Component from './index'
 // import { ChildComponent } from './child-component'

 // ✅ Mock external dependencies only
-jest.mock('@/service/api')
-jest.mock('next/navigation', () => ({
-  useRouter: () => ({ push: jest.fn() }),
+vi.mock('@/service/api')
+vi.mock('next/navigation', () => ({
+  useRouter: () => ({ push: vi.fn() }),
   usePathname: () => '/test',
 }))

@ -88,7 +88,7 @@ let mockSharedState = false

 describe('ComponentName', () => {
   beforeEach(() => {
-    jest.clearAllMocks() // ✅ Reset mocks BEFORE each test
+    vi.clearAllMocks() // ✅ Reset mocks BEFORE each test
     mockSharedState = false // ✅ Reset shared state
   })

@ -117,7 +117,7 @@ describe('ComponentName', () => {
   // User Interactions
   describe('User Interactions', () => {
     it('should handle click events', () => {
-      const handleClick = jest.fn()
+      const handleClick = vi.fn()
       render(<Component onClick={handleClick} />)

       fireEvent.click(screen.getByRole('button'))

@ -155,7 +155,7 @@ describe('ComponentName', () => {
 For each file:
 ┌────────────────────────────────────────┐
 │ 1. Write test                          │
-│ 2. Run: pnpm test -- <file>.spec.tsx   │
+│ 2. Run: pnpm test <file>.spec.tsx      │
 │ 3. PASS? → Mark complete, next file    │
 │    FAIL? → Fix first, then continue    │
 └────────────────────────────────────────┘

@ -316,7 +316,7 @@ For more detailed information, refer to:

 ### Project Configuration

-- `web/jest.config.ts` - Jest configuration
-- `web/jest.setup.ts` - Test environment setup
-- `web/testing/analyze-component.js` - Component analysis tool
-- `web/__mocks__/react-i18next.ts` - Shared i18n mock (auto-loaded by Jest, no explicit mock needed; override locally only for custom translations)
+- `web/vitest.config.ts` - Vitest configuration
+- `web/vitest.setup.ts` - Test environment setup
+- `web/scripts/analyze-component.js` - Component analysis tool
+- Modules are not mocked automatically. Global mocks live in `web/vitest.setup.ts` (for example `react-i18next`, `next/image`); mock other modules like `ky` or `mime` locally in test files.

@ -23,14 +23,14 @@ import userEvent from '@testing-library/user-event'
 // ============================================================================
 // Mocks
 // ============================================================================
-// WHY: Mocks must be hoisted to top of file (Jest requirement).
+// WHY: Mocks must be hoisted to top of file (Vitest requirement).
 // They run BEFORE imports, so keep them before component imports.

 // i18n (automatically mocked)
-// WHY: Shared mock at web/__mocks__/react-i18next.ts is auto-loaded by Jest
+// WHY: Global mock in web/vitest.setup.ts is auto-loaded by Vitest setup
 // No explicit mock needed - it returns translation keys as-is
 // Override only if custom translations are required:
-// jest.mock('react-i18next', () => ({
+// vi.mock('react-i18next', () => ({
 //   useTranslation: () => ({
 //     t: (key: string) => {
 //       const customTranslations: Record<string, string> = {

@ -43,17 +43,17 @@ import userEvent from '@testing-library/user-event'

 // Router (if component uses useRouter, usePathname, useSearchParams)
 // WHY: Isolates tests from Next.js routing, enables testing navigation behavior
-// const mockPush = jest.fn()
-// jest.mock('next/navigation', () => ({
+// const mockPush = vi.fn()
+// vi.mock('next/navigation', () => ({
 //   useRouter: () => ({ push: mockPush }),
 //   usePathname: () => '/test-path',
 // }))

 // API services (if component fetches data)
 // WHY: Prevents real network calls, enables testing all states (loading/success/error)
-// jest.mock('@/service/api')
+// vi.mock('@/service/api')
 // import * as api from '@/service/api'
-// const mockedApi = api as jest.Mocked<typeof api>
+// const mockedApi = vi.mocked(api)

 // Shared mock state (for portal/dropdown components)
 // WHY: Portal components like PortalToFollowElem need shared state between

@ -98,7 +98,7 @@ describe('ComponentName', () => {
   // - Prevents mock call history from leaking between tests
   // - MUST be beforeEach (not afterEach) to reset BEFORE assertions like toHaveBeenCalledTimes
   beforeEach(() => {
-    jest.clearAllMocks()
+    vi.clearAllMocks()
     // Reset shared mock state if used (CRITICAL for portal/dropdown tests)
     // mockOpenState = false
   })

@ -155,7 +155,7 @@ describe('ComponentName', () => {
     // - userEvent simulates real user behavior (focus, hover, then click)
     // - fireEvent is lower-level, doesn't trigger all browser events
     // const user = userEvent.setup()
-    // const handleClick = jest.fn()
+    // const handleClick = vi.fn()
     // render(<ComponentName onClick={handleClick} />)
     //
     // await user.click(screen.getByRole('button'))

@ -165,7 +165,7 @@ describe('ComponentName', () => {

   it('should call onChange when value changes', async () => {
     // const user = userEvent.setup()
-    // const handleChange = jest.fn()
+    // const handleChange = vi.fn()
     // render(<ComponentName onChange={handleChange} />)
     //
     // await user.type(screen.getByRole('textbox'), 'new value')

@ -198,7 +198,7 @@ describe('ComponentName', () => {
   })

   // --------------------------------------------------------------------------
-  // Async Operations (if component fetches data - useSWR, useQuery, fetch)
+  // Async Operations (if component fetches data - useQuery, fetch)
   // --------------------------------------------------------------------------
   // WHY: Async operations have 3 states users experience: loading, success, error
   describe('Async Operations', () => {

@ -15,9 +15,9 @@ import { renderHook, act, waitFor } from '@testing-library/react'
 // ============================================================================

 // API services (if hook fetches data)
-// jest.mock('@/service/api')
+// vi.mock('@/service/api')
 // import * as api from '@/service/api'
-// const mockedApi = api as jest.Mocked<typeof api>
+// const mockedApi = vi.mocked(api)

 // ============================================================================
 // Test Helpers

@ -38,7 +38,7 @@ import { renderHook, act, waitFor } from '@testing-library/react'

 describe('useHookName', () => {
   beforeEach(() => {
-    jest.clearAllMocks()
+    vi.clearAllMocks()
   })

   // --------------------------------------------------------------------------

@ -145,7 +145,7 @@ describe('useHookName', () => {
   // --------------------------------------------------------------------------
   describe('Side Effects', () => {
     it('should call callback when value changes', () => {
-      // const callback = jest.fn()
+      // const callback = vi.fn()
       // const { result } = renderHook(() => useHookName({ onChange: callback }))
       //
       // act(() => {

@ -156,9 +156,9 @@ describe('useHookName', () => {
     })

     it('should cleanup on unmount', () => {
-      // const cleanup = jest.fn()
-      // jest.spyOn(window, 'addEventListener')
-      // jest.spyOn(window, 'removeEventListener')
+      // const cleanup = vi.fn()
+      // vi.spyOn(window, 'addEventListener')
+      // vi.spyOn(window, 'removeEventListener')
       //
       // const { unmount } = renderHook(() => useHookName())
       //

@ -49,7 +49,7 @@ import userEvent from '@testing-library/user-event'

 it('should submit form', async () => {
   const user = userEvent.setup()
-  const onSubmit = jest.fn()
+  const onSubmit = vi.fn()

   render(<Form onSubmit={onSubmit} />)

@ -77,15 +77,15 @@ it('should submit form', async () => {
 ```typescript
 describe('Debounced Search', () => {
   beforeEach(() => {
-    jest.useFakeTimers()
+    vi.useFakeTimers()
   })

   afterEach(() => {
-    jest.useRealTimers()
+    vi.useRealTimers()
   })

   it('should debounce search input', async () => {
-    const onSearch = jest.fn()
+    const onSearch = vi.fn()
     render(<SearchInput onSearch={onSearch} debounceMs={300} />)

     // Type in the input

@ -95,7 +95,7 @@ describe('Debounced Search', () => {
     expect(onSearch).not.toHaveBeenCalled()

     // Advance timers
-    jest.advanceTimersByTime(300)
+    vi.advanceTimersByTime(300)

     // Now search is called
     expect(onSearch).toHaveBeenCalledWith('query')

@ -107,8 +107,8 @@ describe('Debounced Search', () => {

 ```typescript
 it('should retry on failure', async () => {
-  jest.useFakeTimers()
-  const fetchData = jest.fn()
+  vi.useFakeTimers()
+  const fetchData = vi.fn()
     .mockRejectedValueOnce(new Error('Network error'))
     .mockResolvedValueOnce({ data: 'success' })

@ -120,7 +120,7 @@ it('should retry on failure', async () => {
   })

   // Advance timer for retry
-  jest.advanceTimersByTime(1000)
+  vi.advanceTimersByTime(1000)

   // Second call succeeds
   await waitFor(() => {

@ -128,7 +128,7 @@ it('should retry on failure', async () => {
     expect(screen.getByText('success')).toBeInTheDocument()
   })

-  jest.useRealTimers()
+  vi.useRealTimers()
 })
 ```

@ -136,19 +136,19 @@ it('should retry on failure', async () => {

 ```typescript
 // Run all pending timers
-jest.runAllTimers()
+vi.runAllTimers()

 // Run only pending timers (not new ones created during execution)
-jest.runOnlyPendingTimers()
+vi.runOnlyPendingTimers()

 // Advance by specific time
-jest.advanceTimersByTime(1000)
+vi.advanceTimersByTime(1000)

 // Get current fake time
-jest.now()
+Date.now()

 // Clear all timers
-jest.clearAllTimers()
+vi.clearAllTimers()
 ```

 ## API Testing Patterns

@ -158,7 +158,7 @@ jest.clearAllTimers()
 ```typescript
 describe('DataFetcher', () => {
   beforeEach(() => {
-    jest.clearAllMocks()
+    vi.clearAllMocks()
   })

   it('should show loading state', () => {

@ -241,7 +241,7 @@ it('should submit form and show success', async () => {

 ```typescript
 it('should fetch data on mount', async () => {
-  const fetchData = jest.fn().mockResolvedValue({ data: 'test' })
+  const fetchData = vi.fn().mockResolvedValue({ data: 'test' })

   render(<ComponentWithEffect fetchData={fetchData} />)

@ -255,7 +255,7 @@ it('should fetch data on mount', async () => {

 ```typescript
 it('should refetch when id changes', async () => {
-  const fetchData = jest.fn().mockResolvedValue({ data: 'test' })
+  const fetchData = vi.fn().mockResolvedValue({ data: 'test' })

   const { rerender } = render(<ComponentWithEffect id="1" fetchData={fetchData} />)

@ -276,8 +276,8 @@ it('should refetch when id changes', async () => {

 ```typescript
 it('should cleanup subscription on unmount', () => {
-  const subscribe = jest.fn()
-  const unsubscribe = jest.fn()
+  const subscribe = vi.fn()
+  const unsubscribe = vi.fn()
   subscribe.mockReturnValue(unsubscribe)

   const { unmount } = render(<SubscriptionComponent subscribe={subscribe} />)

@ -332,14 +332,14 @@ expect(description).toBeInTheDocument()

 ```typescript
 // Bad - fake timers don't work well with real Promises
-jest.useFakeTimers()
+vi.useFakeTimers()
 await waitFor(() => {
   expect(screen.getByText('Data')).toBeInTheDocument()
 }) // May timeout!

 // Good - use runAllTimers or advanceTimersByTime
-jest.useFakeTimers()
+vi.useFakeTimers()
 render(<Component />)
-jest.runAllTimers()
+vi.runAllTimers()
 expect(screen.getByText('Data')).toBeInTheDocument()
 ```

@ -74,9 +74,9 @@ Use this checklist when generating or reviewing tests for Dify frontend components
 ### Mocks

 - [ ] **DO NOT mock base components** (`@/app/components/base/*`)
-- [ ] `jest.clearAllMocks()` in `beforeEach` (not `afterEach`)
+- [ ] `vi.clearAllMocks()` in `beforeEach` (not `afterEach`)
 - [ ] Shared mock state reset in `beforeEach`
-- [ ] i18n uses shared mock (auto-loaded); only override locally for custom translations
+- [ ] i18n uses global mock (auto-loaded in `web/vitest.setup.ts`); only override locally for custom translations
 - [ ] Router mocks match actual Next.js API
 - [ ] Mocks reflect actual component conditional behavior
 - [ ] Only mock: API services, complex context providers, third-party libs

@ -114,15 +114,15 @@ For the current file being tested:

 **Run these checks after EACH test file, not just at the end:**

-- [ ] Run `pnpm test -- path/to/file.spec.tsx` - **MUST PASS before next file**
+- [ ] Run `pnpm test path/to/file.spec.tsx` - **MUST PASS before next file**
 - [ ] Fix any failures immediately
 - [ ] Mark file as complete in todo list
 - [ ] Only then proceed to next file

 ### After All Files Complete

-- [ ] Run full directory test: `pnpm test -- path/to/directory/`
-- [ ] Check coverage report: `pnpm test -- --coverage`
+- [ ] Run full directory test: `pnpm test path/to/directory/`
+- [ ] Check coverage report: `pnpm test:coverage`
 - [ ] Run `pnpm lint:fix` on all test files
 - [ ] Run `pnpm type-check:tsgo`

@ -132,10 +132,10 @@ For the current file being tested:

 ```typescript
 // ❌ Mock doesn't match actual behavior
-jest.mock('./Component', () => () => <div>Mocked</div>)
+vi.mock('./Component', () => () => <div>Mocked</div>)

 // ✅ Mock matches actual conditional logic
-jest.mock('./Component', () => ({ isOpen }: any) =>
+vi.mock('./Component', () => ({ isOpen }: any) =>
   isOpen ? <div>Content</div> : null
 )
 ```

@ -145,7 +145,7 @@ jest.mock('./Component', () => ({ isOpen }: any) =>
 ```typescript
 // ❌ Shared state not reset
 let mockState = false
-jest.mock('./useHook', () => () => mockState)
+vi.mock('./useHook', () => () => mockState)

 // ✅ Reset in beforeEach
 beforeEach(() => {

@ -186,16 +186,16 @@ Always test these scenarios:

 ```bash
 # Run specific test
-pnpm test -- path/to/file.spec.tsx
+pnpm test path/to/file.spec.tsx

 # Run with coverage
-pnpm test -- --coverage path/to/file.spec.tsx
+pnpm test:coverage path/to/file.spec.tsx

 # Watch mode
-pnpm test -- --watch path/to/file.spec.tsx
+pnpm test:watch path/to/file.spec.tsx

 # Update snapshots (use sparingly)
-pnpm test -- -u path/to/file.spec.tsx
+pnpm test -u path/to/file.spec.tsx

 # Analyze component
 pnpm analyze-component path/to/component.tsx

@ -126,7 +126,7 @@ describe('Counter', () => {
 describe('ControlledInput', () => {
   it('should call onChange with new value', async () => {
     const user = userEvent.setup()
-    const handleChange = jest.fn()
+    const handleChange = vi.fn()

     render(<ControlledInput value="" onChange={handleChange} />)

@ -136,7 +136,7 @@ describe('ControlledInput', () => {
   })

   it('should display controlled value', () => {
-    render(<ControlledInput value="controlled" onChange={jest.fn()} />)
+    render(<ControlledInput value="controlled" onChange={vi.fn()} />)

     expect(screen.getByRole('textbox')).toHaveValue('controlled')
   })

@ -195,7 +195,7 @@ describe('ItemList', () => {

   it('should handle item selection', async () => {
     const user = userEvent.setup()
-    const onSelect = jest.fn()
+    const onSelect = vi.fn()

     render(<ItemList items={items} onSelect={onSelect} />)

@ -217,20 +217,20 @@ describe('ItemList', () => {
 ```typescript
 describe('Modal', () => {
   it('should not render when closed', () => {
-    render(<Modal isOpen={false} onClose={jest.fn()} />)
+    render(<Modal isOpen={false} onClose={vi.fn()} />)

     expect(screen.queryByRole('dialog')).not.toBeInTheDocument()
   })

   it('should render when open', () => {
-    render(<Modal isOpen={true} onClose={jest.fn()} />)
+    render(<Modal isOpen={true} onClose={vi.fn()} />)

     expect(screen.getByRole('dialog')).toBeInTheDocument()
   })

   it('should call onClose when clicking overlay', async () => {
     const user = userEvent.setup()
-    const handleClose = jest.fn()
+    const handleClose = vi.fn()

     render(<Modal isOpen={true} onClose={handleClose} />)

@ -241,7 +241,7 @@ describe('Modal', () => {

   it('should call onClose when pressing Escape', async () => {
     const user = userEvent.setup()
-    const handleClose = jest.fn()
+    const handleClose = vi.fn()

     render(<Modal isOpen={true} onClose={handleClose} />)

@ -254,7 +254,7 @@ describe('Modal', () => {
     const user = userEvent.setup()

     render(
-      <Modal isOpen={true} onClose={jest.fn()}>
+      <Modal isOpen={true} onClose={vi.fn()}>
         <button>First</button>
         <button>Second</button>
       </Modal>

@ -279,7 +279,7 @@ describe('Modal', () => {
 describe('LoginForm', () => {
   it('should submit valid form', async () => {
     const user = userEvent.setup()
-    const onSubmit = jest.fn()
+    const onSubmit = vi.fn()

     render(<LoginForm onSubmit={onSubmit} />)

@ -296,7 +296,7 @@ describe('LoginForm', () => {
   it('should show validation errors', async () => {
     const user = userEvent.setup()

-    render(<LoginForm onSubmit={jest.fn()} />)
+    render(<LoginForm onSubmit={vi.fn()} />)

     // Submit empty form
     await user.click(screen.getByRole('button', { name: /sign in/i }))

@ -308,7 +308,7 @@ describe('LoginForm', () => {
   it('should validate email format', async () => {
     const user = userEvent.setup()

-    render(<LoginForm onSubmit={jest.fn()} />)
+    render(<LoginForm onSubmit={vi.fn()} />)

     await user.type(screen.getByLabelText(/email/i), 'invalid-email')
     await user.click(screen.getByRole('button', { name: /sign in/i }))

@ -318,7 +318,7 @@ describe('LoginForm', () => {

   it('should disable submit button while submitting', async () => {
     const user = userEvent.setup()
-    const onSubmit = jest.fn(() => new Promise(resolve => setTimeout(resolve, 100)))
+    const onSubmit = vi.fn(() => new Promise(resolve => setTimeout(resolve, 100)))

     render(<LoginForm onSubmit={onSubmit} />)

@ -407,7 +407,7 @@ it('test 1', () => {

 // Good - cleanup is automatic with RTL, but reset mocks
 beforeEach(() => {
-  jest.clearAllMocks()
+  vi.clearAllMocks()
 })
 ```

@ -23,7 +23,7 @@ import NodeConfigPanel from './node-config-panel'
 import { createMockNode, createMockWorkflowContext } from '@/__mocks__/workflow'

 // Mock workflow context
-jest.mock('@/app/components/workflow/hooks', () => ({
+vi.mock('@/app/components/workflow/hooks', () => ({
   useWorkflowStore: () => mockWorkflowStore,
   useNodesInteractions: () => mockNodesInteractions,
 }))

@ -31,21 +31,21 @@ jest.mock('@/app/components/workflow/hooks', () => ({
 let mockWorkflowStore = {
   nodes: [],
   edges: [],
-  updateNode: jest.fn(),
+  updateNode: vi.fn(),
 }

 let mockNodesInteractions = {
-  handleNodeSelect: jest.fn(),
-  handleNodeDelete: jest.fn(),
+  handleNodeSelect: vi.fn(),
+  handleNodeDelete: vi.fn(),
 }

 describe('NodeConfigPanel', () => {
   beforeEach(() => {
-    jest.clearAllMocks()
+    vi.clearAllMocks()
     mockWorkflowStore = {
       nodes: [],
       edges: [],
-      updateNode: jest.fn(),
+      updateNode: vi.fn(),
     }
   })

@ -161,23 +161,23 @@ import { render, screen, fireEvent, waitFor } from '@testing-library/react'
 import userEvent from '@testing-library/user-event'
 import DocumentUploader from './document-uploader'

-jest.mock('@/service/datasets', () => ({
-  uploadDocument: jest.fn(),
-  parseDocument: jest.fn(),
+vi.mock('@/service/datasets', () => ({
+  uploadDocument: vi.fn(),
+  parseDocument: vi.fn(),
 }))

 import * as datasetService from '@/service/datasets'
-const mockedService = datasetService as jest.Mocked<typeof datasetService>
+const mockedService = vi.mocked(datasetService)

 describe('DocumentUploader', () => {
   beforeEach(() => {
-    jest.clearAllMocks()
+    vi.clearAllMocks()
   })

   describe('File Upload', () => {
     it('should accept valid file types', async () => {
       const user = userEvent.setup()
-      const onUpload = jest.fn()
+      const onUpload = vi.fn()
       mockedService.uploadDocument.mockResolvedValue({ id: 'doc-1' })

       render(<DocumentUploader onUpload={onUpload} />)

@ -326,14 +326,14 @@ describe('DocumentList', () => {
   describe('Search & Filtering', () => {
     it('should filter by search query', async () => {
       const user = userEvent.setup()
-      jest.useFakeTimers()
+      vi.useFakeTimers()

       render(<DocumentList datasetId="ds-1" />)

       await user.type(screen.getByPlaceholderText(/search/i), 'test query')

       // Debounce
-      jest.advanceTimersByTime(300)
+      vi.advanceTimersByTime(300)

       await waitFor(() => {
         expect(mockedService.getDocuments).toHaveBeenCalledWith(

@ -342,7 +342,7 @@ describe('DocumentList', () => {
         )
       })

-      jest.useRealTimers()
+      vi.useRealTimers()
     })
   })
 })

@ -367,13 +367,13 @@ import { render, screen, fireEvent, waitFor } from '@testing-library/react'
 import userEvent from '@testing-library/user-event'
 import AppConfigForm from './app-config-form'

-jest.mock('@/service/apps', () => ({
-  updateAppConfig: jest.fn(),
-  getAppConfig: jest.fn(),
+vi.mock('@/service/apps', () => ({
+  updateAppConfig: vi.fn(),
+  getAppConfig: vi.fn(),
 }))

 import * as appService from '@/service/apps'
-const mockedService = appService as jest.Mocked<typeof appService>
+const mockedService = vi.mocked(appService)

 describe('AppConfigForm', () => {
   const defaultConfig = {

@ -384,7 +384,7 @@ describe('AppConfigForm', () => {
   }

   beforeEach(() => {
-    jest.clearAllMocks()
+    vi.clearAllMocks()
     mockedService.getAppConfig.mockResolvedValue(defaultConfig)
   })

@ -19,8 +19,8 @@

 ```typescript
 // ❌ WRONG: Don't mock base components
-jest.mock('@/app/components/base/loading', () => () => <div>Loading</div>)
-jest.mock('@/app/components/base/button', () => ({ children }: any) => <button>{children}</button>)
+vi.mock('@/app/components/base/loading', () => () => <div>Loading</div>)
+vi.mock('@/app/components/base/button', () => ({ children }: any) => <button>{children}</button>)

 // ✅ CORRECT: Import and use real base components
 import Loading from '@/app/components/base/loading'

@ -41,20 +41,23 @@ Only mock these categories:

 | Location | Purpose |
 |----------|---------|
-| `web/__mocks__/` | Reusable mocks shared across multiple test files |
-| Test file | Test-specific mocks, inline with `jest.mock()` |
+| `web/vitest.setup.ts` | Global mocks shared by all tests (for example `react-i18next`, `next/image`) |
+| `web/__mocks__/` | Reusable mock factories shared across multiple test files |
+| Test file | Test-specific mocks, inline with `vi.mock()` |
+
+Modules are not mocked automatically. Use `vi.mock` in test files, or add global mocks in `web/vitest.setup.ts`.

 ## Essential Mocks

-### 1. i18n (Auto-loaded via Shared Mock)
+### 1. i18n (Auto-loaded via Global Mock)

-A shared mock is available at `web/__mocks__/react-i18next.ts` and is auto-loaded by Jest.
+A global mock is defined in `web/vitest.setup.ts` and is auto-loaded by Vitest setup.
 **No explicit mock needed** for most tests - it returns translation keys as-is.

 For tests requiring custom translations, override the mock:

 ```typescript
-jest.mock('react-i18next', () => ({
+vi.mock('react-i18next', () => ({
   useTranslation: () => ({
     t: (key: string) => {
       const translations: Record<string, string> = {

@ -69,15 +72,15 @@ jest.mock('react-i18next', () => ({
 ### 2. Next.js Router

 ```typescript
-const mockPush = jest.fn()
-const mockReplace = jest.fn()
+const mockPush = vi.fn()
+const mockReplace = vi.fn()

-jest.mock('next/navigation', () => ({
+vi.mock('next/navigation', () => ({
   useRouter: () => ({
     push: mockPush,
     replace: mockReplace,
-    back: jest.fn(),
-    prefetch: jest.fn(),
+    back: vi.fn(),
+    prefetch: vi.fn(),
   }),
   usePathname: () => '/current-path',
   useSearchParams: () => new URLSearchParams('?key=value'),

@ -85,7 +88,7 @@ jest.mock('next/navigation', () => ({

 describe('Component', () => {
   beforeEach(() => {
-    jest.clearAllMocks()
+    vi.clearAllMocks()
   })

   it('should navigate on click', () => {

@ -102,7 +105,7 @@ describe('Component', () => {
 // ⚠️ Important: Use shared state for components that depend on each other
 let mockPortalOpenState = false

-jest.mock('@/app/components/base/portal-to-follow-elem', () => ({
+vi.mock('@/app/components/base/portal-to-follow-elem', () => ({
   PortalToFollowElem: ({ children, open, ...props }: any) => {
     mockPortalOpenState = open || false // Update shared state
     return <div data-testid="portal" data-open={open}>{children}</div>

@ -119,7 +122,7 @@ jest.mock('@/app/components/base/portal-to-follow-elem', () => ({

 describe('Component', () => {
   beforeEach(() => {
-    jest.clearAllMocks()
+    vi.clearAllMocks()
     mockPortalOpenState = false // ✅ Reset shared state
   })
 })

@ -130,13 +133,13 @@ describe('Component', () => {
 ```typescript
 import * as api from '@/service/api'

-jest.mock('@/service/api')
+vi.mock('@/service/api')

-const mockedApi = api as jest.Mocked<typeof api>
+const mockedApi = vi.mocked(api)

 describe('Component', () => {
   beforeEach(() => {
-    jest.clearAllMocks()
+    vi.clearAllMocks()

     // Setup default mock implementation
     mockedApi.fetchData.mockResolvedValue({ data: [] })

@ -239,32 +242,9 @@ describe('Component with Context', () => {
 })
 ```

-### 7. SWR / React Query
+### 7. React Query

 ```typescript
-// SWR
-jest.mock('swr', () => ({
-  __esModule: true,
-  default: jest.fn(),
-}))
-
-import useSWR from 'swr'
-const mockedUseSWR = useSWR as jest.Mock
-
-describe('Component with SWR', () => {
-  it('should show loading state', () => {
-    mockedUseSWR.mockReturnValue({
-      data: undefined,
-      error: undefined,
-      isLoading: true,
-    })
-
-    render(<Component />)
-    expect(screen.getByText(/loading/i)).toBeInTheDocument()
-  })
-})
-
-// React Query
 import { QueryClient, QueryClientProvider } from '@tanstack/react-query'

 const createTestQueryClient = () => new QueryClient({

@ -35,7 +35,7 @@ When testing a **single component, hook, or utility**:
 2. Run `pnpm analyze-component <path>` (if available)
 3. Check complexity score and features detected
 4. Write the test file
-5. Run test: `pnpm test -- <file>.spec.tsx`
+5. Run test: `pnpm test <file>.spec.tsx`
 6. Fix any failures
 7. Verify coverage meets goals (100% function, >95% branch)
 ```

@ -80,7 +80,7 @@ Process files in this recommended order:
 ```
 ┌─────────────────────────────────────────────┐
 │ 1. Write test file                          │
-│ 2. Run: pnpm test -- <file>.spec.tsx        │
+│ 2. Run: pnpm test <file>.spec.tsx           │
 │ 3. If FAIL → Fix immediately, re-run        │
 │ 4. If PASS → Mark complete in todo list     │
 │ 5. ONLY THEN proceed to next file           │

@ -95,10 +95,10 @@ After all individual tests pass:

 ```bash
 # Run all tests in the directory together
-pnpm test -- path/to/directory/
+pnpm test path/to/directory/

 # Check coverage
-pnpm test -- --coverage path/to/directory/
+pnpm test:coverage path/to/directory/
 ```

 ## Component Complexity Guidelines

@ -201,9 +201,9 @@ Run pnpm test ← Multiple failures, hard to debug
 ```
 # GOOD: Incremental with verification
 Write component-a.spec.tsx
-Run pnpm test -- component-a.spec.tsx ✅
+Run pnpm test component-a.spec.tsx ✅
 Write component-b.spec.tsx
-Run pnpm test -- component-b.spec.tsx ✅
+Run pnpm test component-b.spec.tsx ✅
 ...continue...
 ```

@ -6,6 +6,9 @@
     "context": "..",
     "dockerfile": "Dockerfile"
   },
+  "mounts": [
+    "source=dify-dev-tmp,target=/tmp,type=volume"
+  ],
   "features": {
     "ghcr.io/devcontainers/features/node:1": {
       "nodeGypDependencies": true,

@ -34,19 +37,13 @@
   },
   "postStartCommand": "./.devcontainer/post_start_command.sh",
   "postCreateCommand": "./.devcontainer/post_create_command.sh"

   // Features to add to the dev container. More info: https://containers.dev/features.
   // "features": {},

   // Use 'forwardPorts' to make a list of ports inside the container available locally.
   // "forwardPorts": [],

   // Use 'postCreateCommand' to run commands after the container is created.
   // "postCreateCommand": "python --version",

   // Configure tool-specific properties.
   // "customizations": {},

   // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root.
-  // "remoteUser": "root"
 }

@ -1,6 +1,7 @@
 #!/bin/bash
 WORKSPACE_ROOT=$(pwd)

+export COREPACK_ENABLE_DOWNLOAD_PROMPT=0
 corepack enable
 cd web && pnpm install
 pipx install uv

@ -7,244 +7,243 @@
 * @crazywoola @laipz8200 @Yeuoly

 # CODEOWNERS file
-.github/CODEOWNERS @laipz8200 @crazywoola
+/.github/CODEOWNERS @laipz8200 @crazywoola

 # Docs
-docs/ @crazywoola
+/docs/ @crazywoola

 # Backend (default owner, more specific rules below will override)
-api/ @QuantumGhost
+/api/ @QuantumGhost

 # Backend - MCP
-api/core/mcp/ @Nov1c444
-api/core/entities/mcp_provider.py @Nov1c444
-api/services/tools/mcp_tools_manage_service.py @Nov1c444
-api/controllers/mcp/ @Nov1c444
-api/controllers/console/app/mcp_server.py @Nov1c444
-api/tests/**/*mcp* @Nov1c444
+/api/core/mcp/ @Nov1c444
+/api/core/entities/mcp_provider.py @Nov1c444
+/api/services/tools/mcp_tools_manage_service.py @Nov1c444
+/api/controllers/mcp/ @Nov1c444
+/api/controllers/console/app/mcp_server.py @Nov1c444
+/api/tests/**/*mcp* @Nov1c444

 # Backend - Workflow - Engine (Core graph execution engine)
-api/core/workflow/graph_engine/ @laipz8200 @QuantumGhost
-api/core/workflow/runtime/ @laipz8200 @QuantumGhost
-api/core/workflow/graph/ @laipz8200 @QuantumGhost
-api/core/workflow/graph_events/ @laipz8200 @QuantumGhost
-api/core/workflow/node_events/ @laipz8200 @QuantumGhost
-api/core/model_runtime/ @laipz8200 @QuantumGhost
+/api/core/workflow/graph_engine/ @laipz8200 @QuantumGhost
+/api/core/workflow/runtime/ @laipz8200 @QuantumGhost
+/api/core/workflow/graph/ @laipz8200 @QuantumGhost
+/api/core/workflow/graph_events/ @laipz8200 @QuantumGhost
+/api/core/workflow/node_events/ @laipz8200 @QuantumGhost
+/api/core/model_runtime/ @laipz8200 @QuantumGhost

 # Backend - Workflow - Nodes (Agent, Iteration, Loop, LLM)
-api/core/workflow/nodes/agent/ @Nov1c444
-api/core/workflow/nodes/iteration/ @Nov1c444
-api/core/workflow/nodes/loop/ @Nov1c444
-api/core/workflow/nodes/llm/ @Nov1c444
+/api/core/workflow/nodes/agent/ @Nov1c444
+/api/core/workflow/nodes/iteration/ @Nov1c444
+/api/core/workflow/nodes/loop/ @Nov1c444
+/api/core/workflow/nodes/llm/ @Nov1c444

 # Backend - RAG (Retrieval Augmented Generation)
-api/core/rag/ @JohnJyong
-api/services/rag_pipeline/ @JohnJyong
-api/services/dataset_service.py @JohnJyong
-api/services/knowledge_service.py @JohnJyong
-api/services/external_knowledge_service.py @JohnJyong
-api/services/hit_testing_service.py @JohnJyong
-api/services/metadata_service.py @JohnJyong
-api/services/vector_service.py @JohnJyong
-api/services/entities/knowledge_entities/ @JohnJyong
-api/services/entities/external_knowledge_entities/ @JohnJyong
-api/controllers/console/datasets/ @JohnJyong
-api/controllers/service_api/dataset/ @JohnJyong
-api/models/dataset.py @JohnJyong
-api/tasks/rag_pipeline/ @JohnJyong
-api/tasks/add_document_to_index_task.py @JohnJyong
-api/tasks/batch_clean_document_task.py @JohnJyong
-api/tasks/clean_document_task.py @JohnJyong
-api/tasks/clean_notion_document_task.py @JohnJyong
-api/tasks/document_indexing_task.py @JohnJyong
-api/tasks/document_indexing_sync_task.py @JohnJyong
-api/tasks/document_indexing_update_task.py @JohnJyong
-api/tasks/duplicate_document_indexing_task.py @JohnJyong
-api/tasks/recover_document_indexing_task.py @JohnJyong
-api/tasks/remove_document_from_index_task.py @JohnJyong
-api/tasks/retry_document_indexing_task.py @JohnJyong
-api/tasks/sync_website_document_indexing_task.py @JohnJyong
-api/tasks/batch_create_segment_to_index_task.py @JohnJyong
-api/tasks/create_segment_to_index_task.py @JohnJyong
-api/tasks/delete_segment_from_index_task.py @JohnJyong
-api/tasks/disable_segment_from_index_task.py @JohnJyong
-api/tasks/disable_segments_from_index_task.py @JohnJyong
-api/tasks/enable_segment_to_index_task.py @JohnJyong
-api/tasks/enable_segments_to_index_task.py @JohnJyong
-api/tasks/clean_dataset_task.py @JohnJyong
-api/tasks/deal_dataset_index_update_task.py @JohnJyong
-api/tasks/deal_dataset_vector_index_task.py @JohnJyong
+/api/core/rag/ @JohnJyong
+/api/services/rag_pipeline/ @JohnJyong
+/api/services/dataset_service.py @JohnJyong
+/api/services/knowledge_service.py @JohnJyong
+/api/services/external_knowledge_service.py @JohnJyong
+/api/services/hit_testing_service.py @JohnJyong
+/api/services/metadata_service.py @JohnJyong
+/api/services/vector_service.py @JohnJyong
+/api/services/entities/knowledge_entities/ @JohnJyong
+/api/services/entities/external_knowledge_entities/ @JohnJyong
+/api/controllers/console/datasets/ @JohnJyong
+/api/controllers/service_api/dataset/ @JohnJyong
+/api/models/dataset.py @JohnJyong
+/api/tasks/rag_pipeline/ @JohnJyong
+/api/tasks/add_document_to_index_task.py @JohnJyong
+/api/tasks/batch_clean_document_task.py @JohnJyong
+/api/tasks/clean_document_task.py @JohnJyong
+/api/tasks/clean_notion_document_task.py @JohnJyong
+/api/tasks/document_indexing_task.py @JohnJyong
+/api/tasks/document_indexing_sync_task.py @JohnJyong
+/api/tasks/document_indexing_update_task.py @JohnJyong
+/api/tasks/duplicate_document_indexing_task.py @JohnJyong
+/api/tasks/recover_document_indexing_task.py @JohnJyong
+/api/tasks/remove_document_from_index_task.py @JohnJyong
+/api/tasks/retry_document_indexing_task.py @JohnJyong
+/api/tasks/sync_website_document_indexing_task.py @JohnJyong
+/api/tasks/batch_create_segment_to_index_task.py @JohnJyong
+/api/tasks/create_segment_to_index_task.py @JohnJyong
+/api/tasks/delete_segment_from_index_task.py @JohnJyong
+/api/tasks/disable_segment_from_index_task.py @JohnJyong
+/api/tasks/disable_segments_from_index_task.py @JohnJyong
+/api/tasks/enable_segment_to_index_task.py @JohnJyong
+/api/tasks/enable_segments_to_index_task.py @JohnJyong
+/api/tasks/clean_dataset_task.py @JohnJyong
+/api/tasks/deal_dataset_index_update_task.py @JohnJyong
+/api/tasks/deal_dataset_vector_index_task.py @JohnJyong

 # Backend - Plugins
-api/core/plugin/ @Mairuis @Yeuoly @Stream29
-api/services/plugin/ @Mairuis @Yeuoly @Stream29
-api/controllers/console/workspace/plugin.py @Mairuis @Yeuoly @Stream29
-api/controllers/inner_api/plugin/ @Mairuis @Yeuoly @Stream29
-api/tasks/process_tenant_plugin_autoupgrade_check_task.py @Mairuis @Yeuoly @Stream29
+/api/core/plugin/ @Mairuis @Yeuoly @Stream29
+/api/services/plugin/ @Mairuis @Yeuoly @Stream29
+/api/controllers/console/workspace/plugin.py @Mairuis @Yeuoly @Stream29
+/api/controllers/inner_api/plugin/ @Mairuis @Yeuoly @Stream29
+/api/tasks/process_tenant_plugin_autoupgrade_check_task.py @Mairuis @Yeuoly @Stream29

 # Backend - Trigger/Schedule/Webhook
-api/controllers/trigger/ @Mairuis @Yeuoly
-api/controllers/console/app/workflow_trigger.py @Mairuis @Yeuoly
-api/controllers/console/workspace/trigger_providers.py @Mairuis @Yeuoly
-api/core/trigger/ @Mairuis @Yeuoly
-api/core/app/layers/trigger_post_layer.py @Mairuis @Yeuoly
-api/services/trigger/ @Mairuis @Yeuoly
-api/models/trigger.py @Mairuis @Yeuoly
-api/fields/workflow_trigger_fields.py @Mairuis @Yeuoly
-api/repositories/workflow_trigger_log_repository.py @Mairuis @Yeuoly
-api/repositories/sqlalchemy_workflow_trigger_log_repository.py @Mairuis @Yeuoly
-api/libs/schedule_utils.py @Mairuis @Yeuoly
-api/services/workflow/scheduler.py @Mairuis @Yeuoly
-api/schedule/trigger_provider_refresh_task.py @Mairuis @Yeuoly
-api/schedule/workflow_schedule_task.py @Mairuis @Yeuoly
-api/tasks/trigger_processing_tasks.py @Mairuis @Yeuoly
-api/tasks/trigger_subscription_refresh_tasks.py @Mairuis @Yeuoly
-api/tasks/workflow_schedule_tasks.py @Mairuis @Yeuoly
-api/tasks/workflow_cfs_scheduler/ @Mairuis @Yeuoly
-api/events/event_handlers/sync_plugin_trigger_when_app_created.py @Mairuis @Yeuoly
-api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py @Mairuis @Yeuoly
-api/events/event_handlers/sync_workflow_schedule_when_app_published.py @Mairuis @Yeuoly
-api/events/event_handlers/sync_webhook_when_app_created.py @Mairuis @Yeuoly
+/api/controllers/trigger/ @Mairuis @Yeuoly
+/api/controllers/console/app/workflow_trigger.py @Mairuis @Yeuoly
+/api/controllers/console/workspace/trigger_providers.py @Mairuis @Yeuoly
+/api/core/trigger/ @Mairuis @Yeuoly
+/api/core/app/layers/trigger_post_layer.py @Mairuis @Yeuoly
+/api/services/trigger/ @Mairuis @Yeuoly
+/api/models/trigger.py @Mairuis @Yeuoly
+/api/fields/workflow_trigger_fields.py @Mairuis @Yeuoly
+/api/repositories/workflow_trigger_log_repository.py @Mairuis @Yeuoly
+/api/repositories/sqlalchemy_workflow_trigger_log_repository.py @Mairuis @Yeuoly
+/api/libs/schedule_utils.py @Mairuis @Yeuoly
+/api/services/workflow/scheduler.py @Mairuis @Yeuoly
+/api/schedule/trigger_provider_refresh_task.py @Mairuis @Yeuoly
+/api/schedule/workflow_schedule_task.py @Mairuis @Yeuoly
+/api/tasks/trigger_processing_tasks.py @Mairuis @Yeuoly
+/api/tasks/trigger_subscription_refresh_tasks.py @Mairuis @Yeuoly
+/api/tasks/workflow_schedule_tasks.py @Mairuis @Yeuoly
+/api/tasks/workflow_cfs_scheduler/ @Mairuis @Yeuoly
+/api/events/event_handlers/sync_plugin_trigger_when_app_created.py @Mairuis @Yeuoly
+/api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py @Mairuis @Yeuoly
+/api/events/event_handlers/sync_workflow_schedule_when_app_published.py @Mairuis @Yeuoly
+/api/events/event_handlers/sync_webhook_when_app_created.py @Mairuis @Yeuoly

 # Backend - Async Workflow
-api/services/async_workflow_service.py @Mairuis @Yeuoly
-api/tasks/async_workflow_tasks.py @Mairuis @Yeuoly
+/api/services/async_workflow_service.py @Mairuis @Yeuoly
+/api/tasks/async_workflow_tasks.py @Mairuis @Yeuoly

 # Backend - Billing
-api/services/billing_service.py @hj24 @zyssyz123
-api/controllers/console/billing/ @hj24 @zyssyz123
+/api/services/billing_service.py @hj24 @zyssyz123
+/api/controllers/console/billing/ @hj24 @zyssyz123

 # Backend - Enterprise
-api/configs/enterprise/ @GarfieldDai @GareArc
-api/services/enterprise/ @GarfieldDai @GareArc
-api/services/feature_service.py @GarfieldDai @GareArc
-api/controllers/console/feature.py @GarfieldDai @GareArc
-api/controllers/web/feature.py @GarfieldDai @GareArc
+/api/configs/enterprise/ @GarfieldDai @GareArc
+/api/services/enterprise/ @GarfieldDai @GareArc
+/api/services/feature_service.py @GarfieldDai @GareArc
+/api/controllers/console/feature.py @GarfieldDai @GareArc
+/api/controllers/web/feature.py @GarfieldDai @GareArc

 # Backend - Database Migrations
-api/migrations/ @snakevash @laipz8200 @MRZHUH
+/api/migrations/ @snakevash @laipz8200 @MRZHUH

 # Backend - Vector DB Middleware
-api/configs/middleware/vdb/* @JohnJyong
+/api/configs/middleware/vdb/* @JohnJyong

 # Frontend
-web/ @iamjoel
+/web/ @iamjoel

 # Frontend - Web Tests
-.github/workflows/web-tests.yml @iamjoel
+/.github/workflows/web-tests.yml @iamjoel

 # Frontend - App - Orchestration
-web/app/components/workflow/ @iamjoel @zxhlyh
-web/app/components/workflow-app/ @iamjoel @zxhlyh
-web/app/components/app/configuration/ @iamjoel @zxhlyh
-web/app/components/app/app-publisher/ @iamjoel @zxhlyh
+/web/app/components/workflow/ @iamjoel @zxhlyh
+/web/app/components/workflow-app/ @iamjoel @zxhlyh
+/web/app/components/app/configuration/ @iamjoel @zxhlyh
+/web/app/components/app/app-publisher/ @iamjoel @zxhlyh

 # Frontend - WebApp - Chat
-web/app/components/base/chat/ @iamjoel @zxhlyh
+/web/app/components/base/chat/ @iamjoel @zxhlyh

 # Frontend - WebApp - Completion
-web/app/components/share/text-generation/ @iamjoel @zxhlyh
+/web/app/components/share/text-generation/ @iamjoel @zxhlyh

 # Frontend - App - List and Creation
-web/app/components/apps/ @JzoNgKVO @iamjoel
-web/app/components/app/create-app-dialog/ @JzoNgKVO @iamjoel
-web/app/components/app/create-app-modal/ @JzoNgKVO @iamjoel
-web/app/components/app/create-from-dsl-modal/ @JzoNgKVO @iamjoel
+/web/app/components/apps/ @JzoNgKVO @iamjoel
+/web/app/components/app/create-app-dialog/ @JzoNgKVO @iamjoel
+/web/app/components/app/create-app-modal/ @JzoNgKVO @iamjoel
+/web/app/components/app/create-from-dsl-modal/ @JzoNgKVO @iamjoel

 # Frontend - App - API Documentation
-web/app/components/develop/ @JzoNgKVO @iamjoel
+/web/app/components/develop/ @JzoNgKVO @iamjoel

 # Frontend - App - Logs and Annotations
-web/app/components/app/workflow-log/ @JzoNgKVO @iamjoel
-web/app/components/app/log/ @JzoNgKVO @iamjoel
-web/app/components/app/log-annotation/ @JzoNgKVO @iamjoel
-web/app/components/app/annotation/ @JzoNgKVO @iamjoel
+/web/app/components/app/workflow-log/ @JzoNgKVO @iamjoel
+/web/app/components/app/log/ @JzoNgKVO @iamjoel
+/web/app/components/app/log-annotation/ @JzoNgKVO @iamjoel
+/web/app/components/app/annotation/ @JzoNgKVO @iamjoel

 # Frontend - App - Monitoring
-web/app/(commonLayout)/app/(appDetailLayout)/\[appId\]/overview/ @JzoNgKVO @iamjoel
-web/app/components/app/overview/ @JzoNgKVO @iamjoel
+/web/app/(commonLayout)/app/(appDetailLayout)/\[appId\]/overview/ @JzoNgKVO @iamjoel
+/web/app/components/app/overview/ @JzoNgKVO @iamjoel

 # Frontend - App - Settings
-web/app/components/app-sidebar/ @JzoNgKVO @iamjoel
+/web/app/components/app-sidebar/ @JzoNgKVO @iamjoel

 # Frontend - RAG - Hit Testing
-web/app/components/datasets/hit-testing/ @JzoNgKVO @iamjoel
+/web/app/components/datasets/hit-testing/ @JzoNgKVO @iamjoel

 # Frontend - RAG - List and Creation
-web/app/components/datasets/list/ @iamjoel @WTW0313
-web/app/components/datasets/create/ @iamjoel @WTW0313
-web/app/components/datasets/create-from-pipeline/ @iamjoel @WTW0313
-web/app/components/datasets/external-knowledge-base/ @iamjoel @WTW0313
+/web/app/components/datasets/list/ @iamjoel @WTW0313
+/web/app/components/datasets/create/ @iamjoel @WTW0313
+/web/app/components/datasets/create-from-pipeline/ @iamjoel @WTW0313
+/web/app/components/datasets/external-knowledge-base/ @iamjoel @WTW0313

 # Frontend - RAG - Orchestration (general rule first, specific rules below override)
-web/app/components/rag-pipeline/ @iamjoel @WTW0313
-web/app/components/rag-pipeline/components/rag-pipeline-main.tsx @iamjoel @zxhlyh
-web/app/components/rag-pipeline/store/ @iamjoel @zxhlyh
+/web/app/components/rag-pipeline/ @iamjoel @WTW0313
+/web/app/components/rag-pipeline/components/rag-pipeline-main.tsx @iamjoel @zxhlyh
+/web/app/components/rag-pipeline/store/ @iamjoel @zxhlyh

 # Frontend - RAG - Documents List
-web/app/components/datasets/documents/list.tsx @iamjoel @WTW0313
-web/app/components/datasets/documents/create-from-pipeline/ @iamjoel @WTW0313
+/web/app/components/datasets/documents/list.tsx @iamjoel @WTW0313
+/web/app/components/datasets/documents/create-from-pipeline/ @iamjoel @WTW0313

 # Frontend - RAG - Segments List
-web/app/components/datasets/documents/detail/ @iamjoel @WTW0313
+/web/app/components/datasets/documents/detail/ @iamjoel @WTW0313

 # Frontend - RAG - Settings
-web/app/components/datasets/settings/ @iamjoel @WTW0313
+/web/app/components/datasets/settings/ @iamjoel @WTW0313

 # Frontend - Ecosystem - Plugins
-web/app/components/plugins/ @iamjoel @zhsama
+/web/app/components/plugins/ @iamjoel @zhsama

 # Frontend - Ecosystem - Tools
-web/app/components/tools/ @iamjoel @Yessenia-d
+/web/app/components/tools/ @iamjoel @Yessenia-d

 # Frontend - Ecosystem - MarketPlace
-web/app/components/plugins/marketplace/ @iamjoel @Yessenia-d
+/web/app/components/plugins/marketplace/ @iamjoel @Yessenia-d

 # Frontend - Login and Registration
-web/app/signin/ @douxc @iamjoel
-web/app/signup/ @douxc @iamjoel
-web/app/reset-password/ @douxc @iamjoel
-web/app/install/ @douxc @iamjoel
-web/app/init/ @douxc @iamjoel
-web/app/forgot-password/ @douxc @iamjoel
-web/app/account/ @douxc @iamjoel
+/web/app/signin/ @douxc @iamjoel
+/web/app/signup/ @douxc @iamjoel
+/web/app/reset-password/ @douxc @iamjoel
+/web/app/install/ @douxc @iamjoel
+/web/app/init/ @douxc @iamjoel
+/web/app/forgot-password/ @douxc @iamjoel
+/web/app/account/ @douxc @iamjoel

 # Frontend - Service Authentication
-web/service/base.ts @douxc @iamjoel
+/web/service/base.ts @douxc @iamjoel

 # Frontend - WebApp Authentication and Access Control
-web/app/(shareLayout)/components/ @douxc @iamjoel
-web/app/(shareLayout)/webapp-signin/ @douxc @iamjoel
-web/app/(shareLayout)/webapp-reset-password/ @douxc @iamjoel
-web/app/components/app/app-access-control/ @douxc @iamjoel
+/web/app/(shareLayout)/components/ @douxc @iamjoel
+/web/app/(shareLayout)/webapp-signin/ @douxc @iamjoel
+/web/app/(shareLayout)/webapp-reset-password/ @douxc @iamjoel
+/web/app/components/app/app-access-control/ @douxc @iamjoel

 # Frontend - Explore Page
-web/app/components/explore/ @CodingOnStar @iamjoel
+/web/app/components/explore/ @CodingOnStar @iamjoel

 # Frontend - Personal Settings
-web/app/components/header/account-setting/ @CodingOnStar @iamjoel
-web/app/components/header/account-dropdown/ @CodingOnStar @iamjoel
+/web/app/components/header/account-setting/ @CodingOnStar @iamjoel
+/web/app/components/header/account-dropdown/ @CodingOnStar @iamjoel

 # Frontend - Analytics
-web/app/components/base/ga/ @CodingOnStar @iamjoel
+/web/app/components/base/ga/ @CodingOnStar @iamjoel

 # Frontend - Base Components
-web/app/components/base/ @iamjoel @zxhlyh
+/web/app/components/base/ @iamjoel @zxhlyh

 # Frontend - Utils and Hooks
-web/utils/classnames.ts @iamjoel @zxhlyh
-web/utils/time.ts @iamjoel @zxhlyh
-web/utils/format.ts @iamjoel @zxhlyh
-web/utils/clipboard.ts @iamjoel @zxhlyh
-web/hooks/use-document-title.ts @iamjoel @zxhlyh
+/web/utils/classnames.ts @iamjoel @zxhlyh
+/web/utils/time.ts @iamjoel @zxhlyh
+/web/utils/format.ts @iamjoel @zxhlyh
+/web/utils/clipboard.ts @iamjoel @zxhlyh
+/web/hooks/use-document-title.ts @iamjoel @zxhlyh

 # Frontend - Billing and Education
-web/app/components/billing/ @iamjoel @zxhlyh
-web/app/education-apply/ @iamjoel @zxhlyh
+/web/app/components/billing/ @iamjoel @zxhlyh
+/web/app/education-apply/ @iamjoel @zxhlyh

 # Frontend - Workspace
-web/app/components/header/account-dropdown/workplace-selector/ @iamjoel @zxhlyh
+/web/app/components/header/account-dropdown/workplace-selector/ @iamjoel @zxhlyh

 # Docker
-docker/* @laipz8200
+/docker/* @laipz8200
@ -22,12 +22,12 @@ jobs:
     steps:
       - name: Checkout code
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
         with:
           persist-credentials: false

       - name: Setup UV and Python
-        uses: astral-sh/setup-uv@v6
+        uses: astral-sh/setup-uv@v7
         with:
           enable-cache: true
           python-version: ${{ matrix.python-version }}
@ -57,7 +57,7 @@ jobs:
         run: sh .github/workflows/expose_service_ports.sh

       - name: Set up Sandbox
-        uses: hoverkraft-tech/compose-action@v2.0.2
+        uses: hoverkraft-tech/compose-action@v2
         with:
           compose-file: |
             docker/docker-compose.middleware.yaml
@ -12,12 +12,28 @@ jobs:
     if: github.repository == 'langgenius/dify'
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@v6
+
+      - name: Check Docker Compose inputs
+        id: docker-compose-changes
+        uses: tj-actions/changed-files@v46
+        with:
+          files: |
+            docker/generate_docker_compose
+            docker/.env.example
+            docker/docker-compose-template.yaml
+            docker/docker-compose.yaml
       - uses: actions/setup-python@v5
         with:
           python-version: "3.11"

-      - uses: astral-sh/setup-uv@v6
+      - uses: astral-sh/setup-uv@v7
+
+      - name: Generate Docker Compose
+        if: steps.docker-compose-changes.outputs.any_changed == 'true'
+        run: |
+          cd docker
+          ./generate_docker_compose

       - run: |
           cd api
@ -68,25 +84,4 @@ jobs:
         run: |
           uvx --python 3.13 mdformat . --exclude ".claude/skills/**/SKILL.md"

-      - name: Install pnpm
-        uses: pnpm/action-setup@v4
-        with:
-          package_json_file: web/package.json
-          run_install: false
-
-      - name: Setup NodeJS
-        uses: actions/setup-node@v4
-        with:
-          node-version: 22
-          cache: pnpm
-          cache-dependency-path: ./web/pnpm-lock.yaml
-
-      - name: Web dependencies
-        working-directory: ./web
-        run: pnpm install --frozen-lockfile
-
-      - name: oxlint
-        working-directory: ./web
-        run: pnpm exec oxlint --config .oxlintrc.json --fix .
-
       - uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27
@ -90,7 +90,7 @@ jobs:
           touch "/tmp/digests/${sanitized_digest}"

       - name: Upload digest
-        uses: actions/upload-artifact@v4
+        uses: actions/upload-artifact@v6
         with:
           name: digests-${{ matrix.context }}-${{ env.PLATFORM_PAIR }}
           path: /tmp/digests/*
@ -13,13 +13,13 @@ jobs:
     steps:
       - name: Checkout code
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
         with:
           fetch-depth: 0
           persist-credentials: false

       - name: Setup UV and Python
-        uses: astral-sh/setup-uv@v6
+        uses: astral-sh/setup-uv@v7
         with:
           enable-cache: true
           python-version: "3.12"
@ -63,13 +63,13 @@ jobs:
     steps:
       - name: Checkout code
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
         with:
           fetch-depth: 0
           persist-credentials: false

       - name: Setup UV and Python
-        uses: astral-sh/setup-uv@v6
+        uses: astral-sh/setup-uv@v7
         with:
           enable-cache: true
           python-version: "3.12"
@ -27,7 +27,7 @@ jobs:
       vdb-changed: ${{ steps.changes.outputs.vdb }}
       migration-changed: ${{ steps.changes.outputs.migration }}
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@v6
       - uses: dorny/paths-filter@v3
         id: changes
         with:
@ -38,6 +38,7 @@ jobs:
             - '.github/workflows/api-tests.yml'
             web:
               - 'web/**'
+              - '.github/workflows/web-tests.yml'
             vdb:
               - 'api/core/rag/datasource/**'
               - 'docker/**'
@ -19,13 +19,13 @@ jobs:
     steps:
       - name: Checkout code
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
         with:
           persist-credentials: false

       - name: Check changed files
         id: changed-files
-        uses: tj-actions/changed-files@v46
+        uses: tj-actions/changed-files@v47
         with:
           files: |
             api/**
@ -33,7 +33,7 @@ jobs:

       - name: Setup UV and Python
         if: steps.changed-files.outputs.any_changed == 'true'
-        uses: astral-sh/setup-uv@v6
+        uses: astral-sh/setup-uv@v7
         with:
           enable-cache: false
           python-version: "3.12"
@ -68,15 +68,17 @@ jobs:
     steps:
       - name: Checkout code
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
         with:
           persist-credentials: false

       - name: Check changed files
         id: changed-files
-        uses: tj-actions/changed-files@v46
+        uses: tj-actions/changed-files@v47
         with:
-          files: web/**
+          files: |
+            web/**
+            .github/workflows/style.yml

       - name: Install pnpm
         uses: pnpm/action-setup@v4
@ -85,7 +87,7 @@ jobs:
           run_install: false

       - name: Setup NodeJS
-        uses: actions/setup-node@v4
+        uses: actions/setup-node@v6
         if: steps.changed-files.outputs.any_changed == 'true'
         with:
           node-version: 22
@ -108,50 +110,20 @@ jobs:
         working-directory: ./web
         run: pnpm run type-check:tsgo

-  docker-compose-template:
-    name: Docker Compose Template
-    runs-on: ubuntu-latest
-
-    steps:
-      - name: Checkout code
-        uses: actions/checkout@v4
-        with:
-          persist-credentials: false
-
-      - name: Check changed files
-        id: changed-files
-        uses: tj-actions/changed-files@v46
-        with:
-          files: |
-            docker/generate_docker_compose
-            docker/.env.example
-            docker/docker-compose-template.yaml
-            docker/docker-compose.yaml
-
-      - name: Generate Docker Compose
-        if: steps.changed-files.outputs.any_changed == 'true'
-        run: |
-          cd docker
-          ./generate_docker_compose
-
-      - name: Check for changes
-        if: steps.changed-files.outputs.any_changed == 'true'
-        run: git diff --exit-code
-
   superlinter:
     name: SuperLinter
     runs-on: ubuntu-latest

     steps:
       - name: Checkout code
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
         with:
           fetch-depth: 0
           persist-credentials: false

       - name: Check changed files
         id: changed-files
-        uses: tj-actions/changed-files@v46
+        uses: tj-actions/changed-files@v47
         with:
           files: |
             **.sh
@ -25,12 +25,12 @@ jobs:
         working-directory: sdks/nodejs-client

     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@v6
        with:
          persist-credentials: false

      - name: Use Node.js ${{ matrix.node-version }}
-        uses: actions/setup-node@v4
+        uses: actions/setup-node@v6
        with:
          node-version: ${{ matrix.node-version }}
          cache: ''
@ -1,10 +1,10 @@
-name: Check i18n Files and Create PR
+name: Translate i18n Files Based on English

 on:
   push:
     branches: [main]
     paths:
-      - 'web/i18n/en-US/*.ts'
+      - 'web/i18n/en-US/*.json'

 permissions:
   contents: write
@ -18,7 +18,7 @@ jobs:
       run:
         working-directory: web
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@v6
        with:
          fetch-depth: 0
          token: ${{ secrets.GITHUB_TOKEN }}
@ -28,13 +28,13 @@ jobs:
        run: |
          git fetch origin "${{ github.event.before }}" || true
          git fetch origin "${{ github.sha }}" || true
-          changed_files=$(git diff --name-only "${{ github.event.before }}" "${{ github.sha }}" -- 'i18n/en-US/*.ts')
+          changed_files=$(git diff --name-only "${{ github.event.before }}" "${{ github.sha }}" -- 'i18n/en-US/*.json')
          echo "Changed files: $changed_files"
          if [ -n "$changed_files" ]; then
            echo "FILES_CHANGED=true" >> $GITHUB_ENV
            file_args=""
            for file in $changed_files; do
-              filename=$(basename "$file" .ts)
+              filename=$(basename "$file" .json)
              file_args="$file_args --file $filename"
            done
            echo "FILE_ARGS=$file_args" >> $GITHUB_ENV
@ -51,7 +51,7 @@ jobs:

      - name: Set up Node.js
        if: env.FILES_CHANGED == 'true'
-        uses: actions/setup-node@v4
+        uses: actions/setup-node@v6
        with:
          node-version: 'lts/*'
          cache: pnpm
@ -67,25 +67,19 @@ jobs:
        working-directory: ./web
        run: pnpm run auto-gen-i18n ${{ env.FILE_ARGS }}

-      - name: Generate i18n type definitions
-        if: env.FILES_CHANGED == 'true'
-        working-directory: ./web
-        run: pnpm run gen:i18n-types
-
      - name: Create Pull Request
        if: env.FILES_CHANGED == 'true'
        uses: peter-evans/create-pull-request@v6
        with:
          token: ${{ secrets.GITHUB_TOKEN }}
          commit-message: 'chore(i18n): update translations based on en-US changes'
-          title: 'chore(i18n): translate i18n files and update type definitions'
+          title: 'chore(i18n): translate i18n files based on en-US changes'
          body: |
-            This PR was automatically created to update i18n files and TypeScript type definitions based on changes in en-US locale.
+            This PR was automatically created to update i18n translation files based on changes in en-US locale.

            **Triggered by:** ${{ github.sha }}

            **Changes included:**
            - Updated translation files for all locales
-            - Regenerated TypeScript type definitions for type safety
          branch: chore/automated-i18n-updates-${{ github.sha }}
          delete-branch: true
@ -19,19 +19,19 @@ jobs:
     steps:
       - name: Checkout code
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
         with:
           persist-credentials: false

       - name: Free Disk Space
-        uses: endersonmenezes/free-disk-space@v2
+        uses: endersonmenezes/free-disk-space@v3
         with:
           remove_dotnet: true
           remove_haskell: true
           remove_tool_cache: true

       - name: Setup UV and Python
-        uses: astral-sh/setup-uv@v6
+        uses: astral-sh/setup-uv@v7
         with:
           enable-cache: true
           python-version: ${{ matrix.python-version }}
@ -18,7 +18,7 @@ jobs:
     steps:
       - name: Checkout code
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
         with:
           persist-credentials: false
@ -29,33 +29,17 @@ jobs:
           run_install: false

       - name: Setup Node.js
-        uses: actions/setup-node@v4
+        uses: actions/setup-node@v6
         with:
           node-version: 22
           cache: pnpm
           cache-dependency-path: ./web/pnpm-lock.yaml

-      - name: Restore Jest cache
-        uses: actions/cache@v4
-        with:
-          path: web/.cache/jest
-          key: ${{ runner.os }}-jest-${{ hashFiles('web/pnpm-lock.yaml') }}
-          restore-keys: |
-            ${{ runner.os }}-jest-
-
       - name: Install dependencies
         run: pnpm install --frozen-lockfile

-      - name: Check i18n types synchronization
-        run: pnpm run check:i18n-types
-
       - name: Run tests
-        run: |
-          pnpm exec jest \
-            --ci \
-            --maxWorkers=100% \
-            --coverage \
-            --passWithNoTests
+        run: pnpm test:coverage

       - name: Coverage Summary
         if: always()
@ -69,7 +53,7 @@ jobs:
           if [ ! -f "$COVERAGE_FILE" ] && [ ! -f "$COVERAGE_SUMMARY_FILE" ]; then
             echo "has_coverage=false" >> "$GITHUB_OUTPUT"
             echo "### 🚨 Test Coverage Report :test_tube:" >> "$GITHUB_STEP_SUMMARY"
-            echo "Coverage data not found. Ensure Jest runs with coverage enabled." >> "$GITHUB_STEP_SUMMARY"
+            echo "Coverage data not found. Ensure Vitest runs with coverage enabled." >> "$GITHUB_STEP_SUMMARY"
             exit 0
           fi
@ -365,7 +349,7 @@ jobs:
               .join(' | ')} |`;

             console.log('');
-            console.log('<details><summary>Jest coverage table</summary>');
+            console.log('<details><summary>Vitest coverage table</summary>');
             console.log('');
             console.log(headerRow);
             console.log(dividerRow);
@ -376,7 +360,7 @@ jobs:

       - name: Upload Coverage Artifact
         if: steps.coverage-summary.outputs.has_coverage == 'true'
-        uses: actions/upload-artifact@v4
+        uses: actions/upload-artifact@v6
         with:
           name: web-coverage-report
           path: web/coverage
@ -139,7 +139,6 @@ pyrightconfig.json
 .idea/'

 .DS_Store
-web/.vscode/settings.json

 # Intellij IDEA Files
 .idea/*
@ -196,6 +195,7 @@ docker/nginx/ssl/*
 !docker/nginx/ssl/.gitkeep
 docker/middleware.env
 docker/docker-compose.override.yaml
+docker/env-backup/*

 sdks/python-client/build
 sdks/python-client/dist
@ -205,7 +205,6 @@ sdks/python-client/dify_client.egg-info
 !.vscode/launch.json.template
 !.vscode/README.md
 api/.vscode
-web/.vscode

 # vscode Code History Extension
 .history
@ -220,15 +219,6 @@ plugins.jsonl
 # mise
 mise.toml

-# Next.js build output
-.next/
-
-# PWA generated files
-web/public/sw.js
-web/public/sw.js.map
-web/public/workbox-*.js
-web/public/workbox-*.js.map
-web/public/fallback-*.js

 # AI Assistant
 .roo/
.mcp.json
@ -1,34 +0,0 @@
-{
-  "mcpServers": {
-    "context7": {
-      "type": "http",
-      "url": "https://mcp.context7.com/mcp"
-    },
-    "sequential-thinking": {
-      "type": "stdio",
-      "command": "npx",
-      "args": ["-y", "@modelcontextprotocol/server-sequential-thinking"],
-      "env": {}
-    },
-    "github": {
-      "type": "stdio",
-      "command": "npx",
-      "args": ["-y", "@modelcontextprotocol/server-github"],
-      "env": {
-        "GITHUB_PERSONAL_ACCESS_TOKEN": "${GITHUB_PERSONAL_ACCESS_TOKEN}"
-      }
-    },
-    "fetch": {
-      "type": "stdio",
-      "command": "uvx",
-      "args": ["mcp-server-fetch"],
-      "env": {}
-    },
-    "playwright": {
-      "type": "stdio",
-      "command": "npx",
-      "args": ["-y", "@playwright/mcp@latest"],
-      "env": {}
-    }
-  }
-}
@ -116,6 +116,7 @@ ALIYUN_OSS_AUTH_VERSION=v1
 ALIYUN_OSS_REGION=your-region
 # Don't start with '/'. OSS doesn't support leading slash in object names.
 ALIYUN_OSS_PATH=your-path
+ALIYUN_CLOUDBOX_ID=your-cloudbox-id

 # Google Storage configuration
 GOOGLE_STORAGE_BUCKET_NAME=your-bucket-name
@ -127,12 +128,14 @@ TENCENT_COS_SECRET_KEY=your-secret-key
 TENCENT_COS_SECRET_ID=your-secret-id
 TENCENT_COS_REGION=your-region
 TENCENT_COS_SCHEME=your-scheme
+TENCENT_COS_CUSTOM_DOMAIN=your-custom-domain

 # Huawei OBS Storage Configuration
 HUAWEI_OBS_BUCKET_NAME=your-bucket-name
 HUAWEI_OBS_SECRET_KEY=your-secret-key
 HUAWEI_OBS_ACCESS_KEY=your-access-key
 HUAWEI_OBS_SERVER=your-server-url
+HUAWEI_OBS_PATH_STYLE=false

 # Baidu OBS Storage Configuration
 BAIDU_OBS_BUCKET_NAME=your-bucket-name
@ -690,7 +693,6 @@ ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE=5
 ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR=20
 # Maximum number of concurrent annotation import tasks per tenant
 ANNOTATION_IMPORT_MAX_CONCURRENT=5
-
 # Sandbox expired records clean configuration
 SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21
 SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000
@ -41,3 +41,8 @@ class AliyunOSSStorageConfig(BaseSettings):
         description="Base path within the bucket to store objects (e.g., 'my-app-data/')",
         default=None,
     )
+
+    ALIYUN_CLOUDBOX_ID: str | None = Field(
+        description="Cloudbox id for aliyun cloudbox service",
+        default=None,
+    )
@ -26,3 +26,8 @@ class HuaweiCloudOBSStorageConfig(BaseSettings):
         description="Endpoint URL for Huawei Cloud OBS (e.g., 'https://obs.cn-north-4.myhuaweicloud.com')",
         default=None,
     )
+
+    HUAWEI_OBS_PATH_STYLE: bool = Field(
+        description="Flag to indicate whether to use path-style URLs for OBS requests",
+        default=False,
+    )
@ -31,3 +31,8 @@ class TencentCloudCOSStorageConfig(BaseSettings):
         description="Protocol scheme for COS requests: 'https' (recommended) or 'http'",
         default=None,
     )
+
+    TENCENT_COS_CUSTOM_DOMAIN: str | None = Field(
+        description="Tencent Cloud COS custom domain setting",
+        default=None,
+    )
@ -0,0 +1,57 @@
+import os
+from email.message import Message
+from urllib.parse import quote
+
+from flask import Response
+
+HTML_MIME_TYPES = frozenset({"text/html", "application/xhtml+xml"})
+HTML_EXTENSIONS = frozenset({"html", "htm"})
+
+
+def _normalize_mime_type(mime_type: str | None) -> str:
+    if not mime_type:
+        return ""
+    message = Message()
+    message["Content-Type"] = mime_type
+    return message.get_content_type().strip().lower()
+
+
+def _is_html_extension(extension: str | None) -> bool:
+    if not extension:
+        return False
+    return extension.lstrip(".").lower() in HTML_EXTENSIONS
+
+
+def is_html_content(mime_type: str | None, filename: str | None, extension: str | None = None) -> bool:
+    normalized_mime_type = _normalize_mime_type(mime_type)
+    if normalized_mime_type in HTML_MIME_TYPES:
+        return True
+
+    if _is_html_extension(extension):
+        return True
+
+    if filename:
+        return _is_html_extension(os.path.splitext(filename)[1])
+
+    return False
+
+
+def enforce_download_for_html(
+    response: Response,
+    *,
+    mime_type: str | None,
+    filename: str | None,
+    extension: str | None = None,
+) -> bool:
+    if not is_html_content(mime_type, filename, extension):
+        return False
+
+    if filename:
+        encoded_filename = quote(filename)
+        response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
+    else:
+        response.headers["Content-Disposition"] = "attachment"
+
+    response.headers["Content-Type"] = "application/octet-stream"
+    response.headers["X-Content-Type-Options"] = "nosniff"
+    return True
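The new module only rewrites response headers when HTML content is detected, so user-uploaded HTML is downloaded instead of being rendered in the application's origin. A minimal usage sketch follows; the module path `libs.html_utils` and the route are assumptions for illustration, only `enforce_download_for_html` itself comes from the diff.

```python
# A minimal sketch, assuming the helper lives at libs.html_utils and that
# files are served with flask.send_file; the route and paths are hypothetical.
from flask import Flask, Response, send_file

from libs.html_utils import enforce_download_for_html  # assumed module path

app = Flask(__name__)


@app.route("/files/<path:name>")
def serve_upload(name: str) -> Response:
    # send_file guesses the MIME type from the file extension.
    response: Response = send_file(f"uploads/{name}", download_name=name)
    # For text/html, application/xhtml+xml, or .html/.htm names this swaps
    # Content-Type for application/octet-stream and adds an attachment
    # Content-Disposition, so the browser downloads rather than renders.
    enforce_download_for_html(response, mime_type=response.mimetype, filename=name)
    return response
```

The helper returns True when it rewrote the headers, which lets callers log or branch on the enforcement; this sketch simply ignores the return value.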
@ -7,9 +7,9 @@ from controllers.console import console_ns
 from controllers.console.error import AlreadyActivateError
 from extensions.ext_database import db
 from libs.datetime_utils import naive_utc_now
-from libs.helper import EmailStr, extract_remote_ip, timezone
+from libs.helper import EmailStr, timezone
 from models import AccountStatus
-from services.account_service import AccountService, RegisterService
+from services.account_service import RegisterService

 DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@ -93,7 +93,6 @@ class ActivateApi(Resource):
             "ActivationResponse",
             {
                 "result": fields.String(description="Operation result"),
-                "data": fields.Raw(description="Login token data"),
             },
         ),
     )
@ -117,6 +116,4 @@ class ActivateApi(Resource):
         account.initialized_at = naive_utc_now()
         db.session.commit()

-        token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
-
-        return {"result": "success", "data": token_pair.model_dump()}
+        return {"result": "success"}
@ -1,8 +1,9 @@
 import base64
+from typing import Literal

 from flask import request
 from flask_restx import Resource, fields
-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, Field
 from werkzeug.exceptions import BadRequest

 from controllers.console import console_ns
@ -15,22 +16,8 @@ DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"


 class SubscriptionQuery(BaseModel):
-    plan: str = Field(..., description="Subscription plan")
-    interval: str = Field(..., description="Billing interval")
-
-    @field_validator("plan")
-    @classmethod
-    def validate_plan(cls, value: str) -> str:
-        if value not in [CloudPlan.PROFESSIONAL, CloudPlan.TEAM]:
-            raise ValueError("Invalid plan")
-        return value
-
-    @field_validator("interval")
-    @classmethod
-    def validate_interval(cls, value: str) -> str:
-        if value not in {"month", "year"}:
-            raise ValueError("Invalid interval")
-        return value
+    plan: Literal[CloudPlan.PROFESSIONAL, CloudPlan.TEAM] = Field(..., description="Subscription plan")
+    interval: Literal["month", "year"] = Field(..., description="Billing interval")


 class PartnerTenantsPayload(BaseModel):
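The `Literal` annotations let Pydantic enforce the allowed values itself, so both hand-written `field_validator` hooks can go. A standalone sketch of the resulting behavior, using string stand-ins for the `CloudPlan` members (the real enum values may differ):

```python
# Standalone sketch of the validation behavior after the change above.
# SubscriptionQuerySketch is a stand-in, not the actual Dify model.
from typing import Literal

from pydantic import BaseModel, Field, ValidationError


class SubscriptionQuerySketch(BaseModel):
    plan: Literal["professional", "team"] = Field(..., description="Subscription plan")
    interval: Literal["month", "year"] = Field(..., description="Billing interval")


SubscriptionQuerySketch(plan="team", interval="month")  # accepted

try:
    SubscriptionQuerySketch(plan="enterprise", interval="month")
except ValidationError as exc:
    # Pydantic rejects values outside the Literal set, replacing the
    # imperative checks that raised ValueError by hand.
    print(exc.errors()[0]["type"])  # literal_error
```

A side benefit: Literal constraints also surface in the generated JSON schema, which the imperative validators never did.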
@ -572,7 +572,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
             datasource_type=DatasourceType.NOTION,
             notion_info=NotionInfo.model_validate(
                 {
-                    "credential_id": data_source_info["credential_id"],
+                    "credential_id": data_source_info.get("credential_id"),
                     "notion_workspace_id": data_source_info["notion_workspace_id"],
                     "notion_obj_id": data_source_info["notion_page_id"],
                     "notion_page_type": data_source_info["type"],
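The switch from indexing to `.get()` only matters when the key can be absent; a minimal illustration with made-up data:

```python
# Minimal illustration of the one-line change above; the dict is made up.
legacy = {"notion_workspace_id": "ws1", "notion_page_id": "pg1", "type": "page"}

# legacy["credential_id"] would raise KeyError for records created before
# credential_id existed; .get() yields None instead, which the downstream
# model can treat as "no credential bound".
credential_id = legacy.get("credential_id")
assert credential_id is None
```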
@ -1,6 +1,5 @@
 import logging
 from typing import Literal
-from uuid import UUID

 from flask import request
 from flask_restx import marshal_with
@ -26,6 +25,7 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
 from core.model_runtime.errors.invoke import InvokeError
 from fields.message_fields import message_infinite_scroll_pagination_fields
 from libs import helper
+from libs.helper import UUIDStrOrEmpty
 from libs.login import current_account_with_tenant
 from models.model import AppMode
 from services.app_generate_service import AppGenerateService
@ -44,8 +44,8 @@ logger = logging.getLogger(__name__)


 class MessageListQuery(BaseModel):
-    conversation_id: UUID
-    first_id: UUID | None = None
+    conversation_id: UUIDStrOrEmpty
+    first_id: UUIDStrOrEmpty | None = None
     limit: int = Field(default=20, ge=1, le=100)
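`UUIDStrOrEmpty` is a Dify helper type; judging by the usage it validates a UUID-formatted string while keeping the field a plain `str`. A stand-in sketch of that idea — the alias definition below is an assumption, not the code in `libs/helper.py`:

```python
# Stand-in sketch of what a UUIDStrOrEmpty-style alias could look like.
# The real definition may differ; this only illustrates why str-based UUID
# fields suit query-string parsing better than uuid.UUID (request.args
# values arrive as plain strings, and an empty value should be tolerated).
from typing import Annotated
from uuid import UUID

from pydantic import AfterValidator, BaseModel, Field


def _validate_uuid_or_empty(value: str) -> str:
    if value:
        UUID(value)  # raises ValueError on malformed input
    return value


UUIDStrOrEmptySketch = Annotated[str, AfterValidator(_validate_uuid_or_empty)]


class MessageListQuerySketch(BaseModel):
    conversation_id: UUIDStrOrEmptySketch
    first_id: UUIDStrOrEmptySketch | None = None
    limit: int = Field(default=20, ge=1, le=100)


# Query-string values stay strings end to end:
MessageListQuerySketch.model_validate({"conversation_id": "8d5c221e-7c72-4d6c-9b2a-1e1a9f0c2b34"})
```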
@ -1,5 +1,3 @@
-from uuid import UUID
-
 from flask import request
 from flask_restx import fields, marshal_with
 from pydantic import BaseModel, Field
@ -10,19 +8,19 @@ from controllers.console import console_ns
 from controllers.console.explore.error import NotCompletionAppError
 from controllers.console.explore.wraps import InstalledAppResource
 from fields.conversation_fields import message_file_fields
-from libs.helper import TimestampField
+from libs.helper import TimestampField, UUIDStrOrEmpty
 from libs.login import current_account_with_tenant
 from services.errors.message import MessageNotExistsError
 from services.saved_message_service import SavedMessageService


 class SavedMessageListQuery(BaseModel):
-    last_id: UUID | None = None
+    last_id: UUIDStrOrEmpty | None = None
     limit: int = Field(default=20, ge=1, le=100)


 class SavedMessageCreatePayload(BaseModel):
-    message_id: UUID
+    message_id: UUIDStrOrEmpty


 register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload)
@@ -1,14 +1,32 @@
-from flask_restx import Resource, fields, marshal_with, reqparse
+from flask import request
+from flask_restx import Resource, fields, marshal_with
+from pydantic import BaseModel, Field

 from constants import HIDDEN_VALUE
-from controllers.console import console_ns
-from controllers.console.wraps import account_initialization_required, setup_required
 from fields.api_based_extension_fields import api_based_extension_fields
 from libs.login import current_account_with_tenant, login_required
 from models.api_based_extension import APIBasedExtension
 from services.api_based_extension_service import APIBasedExtensionService
 from services.code_based_extension_service import CodeBasedExtensionService

+from ..common.schema import register_schema_models
+from . import console_ns
+from .wraps import account_initialization_required, setup_required
+
+
+class CodeBasedExtensionQuery(BaseModel):
+    module: str
+
+
+class APIBasedExtensionPayload(BaseModel):
+    name: str = Field(description="Extension name")
+    api_endpoint: str = Field(description="API endpoint URL")
+    api_key: str = Field(description="API key for authentication")
+
+
+register_schema_models(console_ns, APIBasedExtensionPayload)
+

 api_based_extension_model = console_ns.model("ApiBasedExtensionModel", api_based_extension_fields)

 api_based_extension_list_model = fields.List(fields.Nested(api_based_extension_model))
@@ -18,11 +36,7 @@ api_based_extension_list_model = fields.List(fields.Nested(api_based_extension_m
 class CodeBasedExtensionAPI(Resource):
     @console_ns.doc("get_code_based_extension")
     @console_ns.doc(description="Get code-based extension data by module name")
-    @console_ns.expect(
-        console_ns.parser().add_argument(
-            "module", type=str, required=True, location="args", help="Extension module name"
-        )
-    )
+    @console_ns.doc(params={"module": "Extension module name"})
     @console_ns.response(
         200,
         "Success",
@@ -35,10 +49,9 @@ class CodeBasedExtensionAPI(Resource):
     @login_required
     @account_initialization_required
     def get(self):
-        parser = reqparse.RequestParser().add_argument("module", type=str, required=True, location="args")
-        args = parser.parse_args()
+        query = CodeBasedExtensionQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore

-        return {"module": args["module"], "data": CodeBasedExtensionService.get_code_based_extension(args["module"])}
+        return {"module": query.module, "data": CodeBasedExtensionService.get_code_based_extension(query.module)}


 @console_ns.route("/api-based-extension")
@@ -56,30 +69,21 @@ class APIBasedExtensionAPI(Resource):

     @console_ns.doc("create_api_based_extension")
     @console_ns.doc(description="Create a new API-based extension")
-    @console_ns.expect(
-        console_ns.model(
-            "CreateAPIBasedExtensionRequest",
-            {
-                "name": fields.String(required=True, description="Extension name"),
-                "api_endpoint": fields.String(required=True, description="API endpoint URL"),
-                "api_key": fields.String(required=True, description="API key for authentication"),
-            },
-        )
-    )
+    @console_ns.expect(console_ns.models[APIBasedExtensionPayload.__name__])
     @console_ns.response(201, "Extension created successfully", api_based_extension_model)
     @setup_required
     @login_required
     @account_initialization_required
     @marshal_with(api_based_extension_model)
     def post(self):
-        args = console_ns.payload
+        payload = APIBasedExtensionPayload.model_validate(console_ns.payload or {})
         _, current_tenant_id = current_account_with_tenant()

         extension_data = APIBasedExtension(
             tenant_id=current_tenant_id,
-            name=args["name"],
-            api_endpoint=args["api_endpoint"],
-            api_key=args["api_key"],
+            name=payload.name,
+            api_endpoint=payload.api_endpoint,
+            api_key=payload.api_key,
         )

         return APIBasedExtensionService.save(extension_data)
@@ -104,16 +108,7 @@ class APIBasedExtensionDetailAPI(Resource):
     @console_ns.doc("update_api_based_extension")
     @console_ns.doc(description="Update API-based extension")
     @console_ns.doc(params={"id": "Extension ID"})
-    @console_ns.expect(
-        console_ns.model(
-            "UpdateAPIBasedExtensionRequest",
-            {
-                "name": fields.String(required=True, description="Extension name"),
-                "api_endpoint": fields.String(required=True, description="API endpoint URL"),
-                "api_key": fields.String(required=True, description="API key for authentication"),
-            },
-        )
-    )
+    @console_ns.expect(console_ns.models[APIBasedExtensionPayload.__name__])
     @console_ns.response(200, "Extension updated successfully", api_based_extension_model)
     @setup_required
     @login_required
@@ -125,13 +120,13 @@ class APIBasedExtensionDetailAPI(Resource):

         extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)

-        args = console_ns.payload
+        payload = APIBasedExtensionPayload.model_validate(console_ns.payload or {})

-        extension_data_from_db.name = args["name"]
-        extension_data_from_db.api_endpoint = args["api_endpoint"]
+        extension_data_from_db.name = payload.name
+        extension_data_from_db.api_endpoint = payload.api_endpoint

-        if args["api_key"] != HIDDEN_VALUE:
-            extension_data_from_db.api_key = args["api_key"]
+        if payload.api_key != HIDDEN_VALUE:
+            extension_data_from_db.api_key = payload.api_key

         return APIBasedExtensionService.save(extension_data_from_db)
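The pattern repeated throughout this file: inline `reqparse` parsers and hand-built `console_ns.model(...)` request schemas collapse into one pydantic model that serves both as the OpenAPI schema (via `register_schema_models`) and the runtime validator. A standalone sketch of the validation half; `APIBasedExtensionPayload` is the model from the hunk above, while the namespace wiring is Dify-specific and omitted here:

```python
from pydantic import BaseModel, Field, ValidationError


class APIBasedExtensionPayload(BaseModel):
    name: str = Field(description="Extension name")
    api_endpoint: str = Field(description="API endpoint URL")
    api_key: str = Field(description="API key for authentication")


# One call replaces a RequestParser().add_argument(...) chain:
payload = APIBasedExtensionPayload.model_validate(
    {"name": "demo", "api_endpoint": "https://example.com/hook", "api_key": "k"}
)
print(payload.name)

# Missing or mistyped fields fail fast with a structured error, which is
# what the `or {}` guard feeds when the request body is empty:
try:
    APIBasedExtensionPayload.model_validate({})
except ValidationError as e:
    print(e.error_count(), "validation errors")
```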
@@ -1,6 +1,8 @@
-from flask_restx import Resource, reqparse
+from flask_restx import Resource
+from pydantic import BaseModel
 from werkzeug.exceptions import Forbidden

+from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from controllers.console.wraps import account_initialization_required, setup_required
 from core.model_runtime.entities.model_entities import ModelType
@@ -10,10 +12,20 @@ from models import TenantAccountRole
 from services.model_load_balancing_service import ModelLoadBalancingService


+class LoadBalancingCredentialPayload(BaseModel):
+    model: str
+    model_type: ModelType
+    credentials: dict[str, object]
+
+
+register_schema_models(console_ns, LoadBalancingCredentialPayload)
+
+
 @console_ns.route(
     "/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/credentials-validate"
 )
 class LoadBalancingCredentialsValidateApi(Resource):
+    @console_ns.expect(console_ns.models[LoadBalancingCredentialPayload.__name__])
     @setup_required
     @login_required
     @account_initialization_required
@@ -24,20 +36,7 @@ class LoadBalancingCredentialsValidateApi(Resource):

         tenant_id = current_tenant_id

-        parser = (
-            reqparse.RequestParser()
-            .add_argument("model", type=str, required=True, nullable=False, location="json")
-            .add_argument(
-                "model_type",
-                type=str,
-                required=True,
-                nullable=False,
-                choices=[mt.value for mt in ModelType],
-                location="json",
-            )
-            .add_argument("credentials", type=dict, required=True, nullable=False, location="json")
-        )
-        args = parser.parse_args()
+        payload = LoadBalancingCredentialPayload.model_validate(console_ns.payload or {})

         # validate model load balancing credentials
         model_load_balancing_service = ModelLoadBalancingService()
@@ -49,9 +48,9 @@ class LoadBalancingCredentialsValidateApi(Resource):
             model_load_balancing_service.validate_load_balancing_credentials(
                 tenant_id=tenant_id,
                 provider=provider,
-                model=args["model"],
-                model_type=args["model_type"],
-                credentials=args["credentials"],
+                model=payload.model,
+                model_type=payload.model_type,
+                credentials=payload.credentials,
             )
         except CredentialsValidateFailedError as ex:
             result = False
@@ -69,6 +68,7 @@ class LoadBalancingCredentialsValidateApi(Resource):
     "/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate"
 )
 class LoadBalancingConfigCredentialsValidateApi(Resource):
+    @console_ns.expect(console_ns.models[LoadBalancingCredentialPayload.__name__])
     @setup_required
     @login_required
     @account_initialization_required
@@ -79,20 +79,7 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):

         tenant_id = current_tenant_id

-        parser = (
-            reqparse.RequestParser()
-            .add_argument("model", type=str, required=True, nullable=False, location="json")
-            .add_argument(
-                "model_type",
-                type=str,
-                required=True,
-                nullable=False,
-                choices=[mt.value for mt in ModelType],
-                location="json",
-            )
-            .add_argument("credentials", type=dict, required=True, nullable=False, location="json")
-        )
-        args = parser.parse_args()
+        payload = LoadBalancingCredentialPayload.model_validate(console_ns.payload or {})

         # validate model load balancing config credentials
         model_load_balancing_service = ModelLoadBalancingService()
@@ -104,9 +91,9 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
             model_load_balancing_service.validate_load_balancing_credentials(
                 tenant_id=tenant_id,
                 provider=provider,
-                model=args["model"],
-                model_type=args["model_type"],
-                credentials=args["credentials"],
+                model=payload.model,
+                model_type=payload.model_type,
+                credentials=payload.credentials,
                 config_id=config_id,
             )
         except CredentialsValidateFailedError as ex:
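One detail worth noting in this hunk: typing `model_type` as the `ModelType` enum makes pydantic perform the membership check that previously needed `choices=[mt.value for mt in ModelType]`. A self-contained sketch; the enum here is a stand-in, the real one lives in `core.model_runtime.entities.model_entities`:

```python
from enum import Enum

from pydantic import BaseModel, ConfigDict, ValidationError


class ModelType(str, Enum):  # stand-in for the real enum
    LLM = "llm"
    TEXT_EMBEDDING = "text-embedding"


class LoadBalancingCredentialPayload(BaseModel):
    # Silence pydantic v2's "model_" protected-namespace warning for these field names.
    model_config = ConfigDict(protected_namespaces=())

    model: str
    model_type: ModelType
    credentials: dict[str, object]


p = LoadBalancingCredentialPayload.model_validate(
    {"model": "gpt-x", "model_type": "llm", "credentials": {}}
)
print(p.model_type)  # ModelType.LLM -- the raw string is coerced to the enum member

try:
    LoadBalancingCredentialPayload.model_validate(
        {"model": "gpt-x", "model_type": "bogus", "credentials": {}}
    )
except ValidationError:
    print("invalid model_type rejected")  # the old `choices=` check, for free
```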
@@ -1,5 +1,6 @@
 import io
-from typing import Literal
+from collections.abc import Mapping
+from typing import Any, Literal

 from flask import request, send_file
 from flask_restx import Resource
@@ -141,6 +142,15 @@ class ParserDynamicOptions(BaseModel):
     provider_type: Literal["tool", "trigger"]


+class ParserDynamicOptionsWithCredentials(BaseModel):
+    plugin_id: str
+    provider: str
+    action: str
+    parameter: str
+    credential_id: str
+    credentials: Mapping[str, Any]
+
+
 class PluginPermissionSettingsPayload(BaseModel):
     install_permission: TenantPluginPermission.InstallPermission = TenantPluginPermission.InstallPermission.EVERYONE
     debug_permission: TenantPluginPermission.DebugPermission = TenantPluginPermission.DebugPermission.EVERYONE
@@ -183,6 +193,7 @@ reg(ParserGithubUpgrade)
 reg(ParserUninstall)
 reg(ParserPermissionChange)
 reg(ParserDynamicOptions)
+reg(ParserDynamicOptionsWithCredentials)
 reg(ParserPreferencesChange)
 reg(ParserExcludePlugin)
 reg(ParserReadme)
@@ -657,6 +668,37 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
         return jsonable_encoder({"options": options})


+@console_ns.route("/workspaces/current/plugin/parameters/dynamic-options-with-credentials")
+class PluginFetchDynamicSelectOptionsWithCredentialsApi(Resource):
+    @console_ns.expect(console_ns.models[ParserDynamicOptionsWithCredentials.__name__])
+    @setup_required
+    @login_required
+    @is_admin_or_owner_required
+    @account_initialization_required
+    def post(self):
+        """Fetch dynamic options using credentials directly (for edit mode)."""
+        current_user, tenant_id = current_account_with_tenant()
+        user_id = current_user.id
+
+        args = ParserDynamicOptionsWithCredentials.model_validate(console_ns.payload)
+
+        try:
+            options = PluginParameterService.get_dynamic_select_options_with_credentials(
+                tenant_id=tenant_id,
+                user_id=user_id,
+                plugin_id=args.plugin_id,
+                provider=args.provider,
+                action=args.action,
+                parameter=args.parameter,
+                credential_id=args.credential_id,
+                credentials=args.credentials,
+            )
+        except PluginDaemonClientSideError as e:
+            raise ValueError(e)
+
+        return jsonable_encoder({"options": options})
+
+
 @console_ns.route("/workspaces/current/plugin/preferences/change")
 class PluginChangePreferencesApi(Resource):
     @console_ns.expect(console_ns.models[ParserPreferencesChange.__name__])
@@ -1,4 +1,5 @@
 import io
+import logging
 from urllib.parse import urlparse

 from flask import make_response, redirect, request, send_file
@@ -17,6 +18,7 @@ from controllers.console.wraps import (
     is_admin_or_owner_required,
     setup_required,
 )
+from core.db.session_factory import session_factory
 from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
 from core.helper.tool_provider_cache import ToolProviderListCache
 from core.mcp.auth.auth_flow import auth, handle_callback
@@ -40,6 +42,8 @@ from services.tools.tools_manage_service import ToolCommonService
 from services.tools.tools_transform_service import ToolTransformService
 from services.tools.workflow_tools_manage_service import WorkflowToolManageService

+logger = logging.getLogger(__name__)
+

 def is_valid_url(url: str) -> bool:
     if not url:
@@ -945,8 +949,8 @@ class ToolProviderMCPApi(Resource):
         configuration = MCPConfiguration.model_validate(args["configuration"])
         authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None

-        # Create provider in transaction
-        with Session(db.engine) as session, session.begin():
+        # 1) Create provider in a short transaction (no network I/O inside)
+        with session_factory.create_session() as session, session.begin():
             service = MCPToolManageService(session=session)
             result = service.create_provider(
                 tenant_id=tenant_id,
@@ -962,7 +966,28 @@ class ToolProviderMCPApi(Resource):
                 authentication=authentication,
             )

-        # Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
+        # 2) Try to fetch tools immediately after creation so they appear without a second save.
+        # Perform network I/O outside any DB session to avoid holding locks.
+        try:
+            reconnect = MCPToolManageService.reconnect_with_url(
+                server_url=args["server_url"],
+                headers=args.get("headers") or {},
+                timeout=configuration.timeout,
+                sse_read_timeout=configuration.sse_read_timeout,
+            )
+            # Update just-created provider with authed/tools in a new short transaction
+            with session_factory.create_session() as session, session.begin():
+                service = MCPToolManageService(session=session)
+                db_provider = service.get_provider(provider_id=result.id, tenant_id=tenant_id)
+                db_provider.authed = reconnect.authed
+                db_provider.tools = reconnect.tools
+
+                result = ToolTransformService.mcp_provider_to_user_provider(db_provider, for_list=True)
+        except Exception:
+            # Best-effort: if initial fetch fails (e.g., auth required), return created provider as-is
+            logger.warning("Failed to fetch MCP tools after creation", exc_info=True)
+
+        # Final cache invalidation to ensure list views are up to date
         ToolProviderListCache.invalidate_cache(tenant_id)

         return jsonable_encoder(result)
@@ -1081,6 +1106,8 @@ class ToolMCPAuthApi(Resource):
                     credentials=provider_entity.credentials,
                     authed=True,
                 )
+                # Invalidate cache after updating credentials
+                ToolProviderListCache.invalidate_cache(tenant_id)
                 return {"result": "success"}
         except MCPAuthError as e:
             try:
@@ -1094,16 +1121,22 @@ class ToolMCPAuthApi(Resource):
             with Session(db.engine) as session, session.begin():
                 service = MCPToolManageService(session=session)
                 response = service.execute_auth_actions(auth_result)
+            # Invalidate cache after auth actions may have updated provider state
+            ToolProviderListCache.invalidate_cache(tenant_id)
             return response
         except MCPRefreshTokenError as e:
             with Session(db.engine) as session, session.begin():
                 service = MCPToolManageService(session=session)
                 service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
+            # Invalidate cache after clearing credentials
+            ToolProviderListCache.invalidate_cache(tenant_id)
             raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
         except (MCPError, ValueError) as e:
             with Session(db.engine) as session, session.begin():
                 service = MCPToolManageService(session=session)
                 service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
+            # Invalidate cache after clearing credentials
+            ToolProviderListCache.invalidate_cache(tenant_id)
             raise ValueError(f"Failed to connect to MCP server: {e}") from e
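The restructured `ToolProviderMCPApi.post` is worth reading as a pattern: create the row in a short transaction, do the network round-trip with no session open, then persist the result in a second short transaction, treating the fetch as best-effort. A stand-in sketch of the control flow; every name below is illustrative, not a Dify API:

```python
# Illustrative stand-ins mirroring the shape of the hunk above
# (create_provider / reconnect_with_url / update provider).

def create_provider_record(name: str) -> int:
    print(f"txn 1: INSERT provider {name!r}")  # short transaction, commits fast
    return 42  # new provider id


def fetch_tools(server_url: str) -> list[str]:
    # Network I/O happens here, deliberately with *no* DB session open,
    # so a slow MCP server cannot hold row locks.
    return ["search", "summarize"]


def attach_tools(provider_id: int, tools: list[str]) -> None:
    print(f"txn 2: UPDATE provider {provider_id} tools={tools}")  # second short txn


provider_id = create_provider_record("demo-mcp")
try:
    tools = fetch_tools("https://mcp.example.com/sse")
    attach_tools(provider_id, tools)
except Exception:
    # Best-effort, as in the hunk: if the fetch fails (e.g., auth required),
    # the created provider is still returned without tools.
    pass
print("cache invalidated after all transactions commit")  # ToolProviderListCache stand-in
```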
@@ -1,11 +1,15 @@
 import logging
+from collections.abc import Mapping
+from typing import Any

 from flask import make_response, redirect, request
 from flask_restx import Resource, reqparse
+from pydantic import BaseModel, Field
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import BadRequest, Forbidden

 from configs import dify_config
+from constants import HIDDEN_VALUE, UNKNOWN_VALUE
 from controllers.web.error import NotFoundError
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.plugin.entities.plugin_daemon import CredentialType
@@ -32,6 +36,32 @@ from ..wraps import (
 logger = logging.getLogger(__name__)


+class TriggerSubscriptionUpdateRequest(BaseModel):
+    """Request payload for updating a trigger subscription"""
+
+    name: str | None = Field(default=None, description="The name for the subscription")
+    credentials: Mapping[str, Any] | None = Field(default=None, description="The credentials for the subscription")
+    parameters: Mapping[str, Any] | None = Field(default=None, description="The parameters for the subscription")
+    properties: Mapping[str, Any] | None = Field(default=None, description="The properties for the subscription")
+
+
+class TriggerSubscriptionVerifyRequest(BaseModel):
+    """Request payload for verifying subscription credentials."""
+
+    credentials: Mapping[str, Any] = Field(description="The credentials to verify")
+
+
+console_ns.schema_model(
+    TriggerSubscriptionUpdateRequest.__name__,
+    TriggerSubscriptionUpdateRequest.model_json_schema(ref_template="#/definitions/{model}"),
+)
+
+console_ns.schema_model(
+    TriggerSubscriptionVerifyRequest.__name__,
+    TriggerSubscriptionVerifyRequest.model_json_schema(ref_template="#/definitions/{model}"),
+)
+
+
 @console_ns.route("/workspaces/current/trigger-provider/<path:provider>/icon")
 class TriggerProviderIconApi(Resource):
     @setup_required
@@ -155,16 +185,16 @@ parser_api = (


 @console_ns.route(
-    "/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/verify/<path:subscription_builder_id>",
+    "/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/verify-and-update/<path:subscription_builder_id>",
 )
-class TriggerSubscriptionBuilderVerifyApi(Resource):
+class TriggerSubscriptionBuilderVerifyAndUpdateApi(Resource):
     @console_ns.expect(parser_api)
     @setup_required
     @login_required
     @edit_permission_required
     @account_initialization_required
     def post(self, provider, subscription_builder_id):
-        """Verify a subscription instance for a trigger provider"""
+        """Verify and update a subscription instance for a trigger provider"""
         user = current_user
         assert user.current_tenant_id is not None
@@ -289,6 +319,83 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
             raise ValueError(str(e)) from e


+@console_ns.route(
+    "/workspaces/current/trigger-provider/<path:subscription_id>/subscriptions/update",
+)
+class TriggerSubscriptionUpdateApi(Resource):
+    @console_ns.expect(console_ns.models[TriggerSubscriptionUpdateRequest.__name__])
+    @setup_required
+    @login_required
+    @edit_permission_required
+    @account_initialization_required
+    def post(self, subscription_id: str):
+        """Update a subscription instance"""
+        user = current_user
+        assert user.current_tenant_id is not None
+
+        args = TriggerSubscriptionUpdateRequest.model_validate(console_ns.payload)
+
+        subscription = TriggerProviderService.get_subscription_by_id(
+            tenant_id=user.current_tenant_id,
+            subscription_id=subscription_id,
+        )
+        if not subscription:
+            raise NotFoundError(f"Subscription {subscription_id} not found")
+
+        provider_id = TriggerProviderID(subscription.provider_id)
+
+        try:
+            # rename only
+            if (
+                args.name is not None
+                and args.credentials is None
+                and args.parameters is None
+                and args.properties is None
+            ):
+                TriggerProviderService.update_trigger_subscription(
+                    tenant_id=user.current_tenant_id,
+                    subscription_id=subscription_id,
+                    name=args.name,
+                )
+                return 200
+
+            # rebuild for create automatically by the provider
+            match subscription.credential_type:
+                case CredentialType.UNAUTHORIZED:
+                    TriggerProviderService.update_trigger_subscription(
+                        tenant_id=user.current_tenant_id,
+                        subscription_id=subscription_id,
+                        name=args.name,
+                        properties=args.properties,
+                    )
+                    return 200
+                case CredentialType.API_KEY | CredentialType.OAUTH2:
+                    if args.credentials:
+                        new_credentials: dict[str, Any] = {
+                            key: value if value != HIDDEN_VALUE else subscription.credentials.get(key, UNKNOWN_VALUE)
+                            for key, value in args.credentials.items()
+                        }
+                    else:
+                        new_credentials = subscription.credentials
+
+                    TriggerProviderService.rebuild_trigger_subscription(
+                        tenant_id=user.current_tenant_id,
+                        name=args.name,
+                        provider_id=provider_id,
+                        subscription_id=subscription_id,
+                        credentials=new_credentials,
+                        parameters=args.parameters or subscription.parameters,
+                    )
+                    return 200
+                case _:
+                    raise BadRequest("Invalid credential type")
+        except ValueError as e:
+            raise BadRequest(str(e))
+        except Exception as e:
+            logger.exception("Error updating subscription", exc_info=e)
+            raise
+
+
 @console_ns.route(
     "/workspaces/current/trigger-provider/<path:subscription_id>/subscriptions/delete",
 )
@@ -576,3 +683,38 @@ class TriggerOAuthClientManageApi(Resource):
         except Exception as e:
             logger.exception("Error removing OAuth client", exc_info=e)
             raise
+
+
+@console_ns.route(
+    "/workspaces/current/trigger-provider/<path:provider>/subscriptions/verify/<path:subscription_id>",
+)
+class TriggerSubscriptionVerifyApi(Resource):
+    @console_ns.expect(console_ns.models[TriggerSubscriptionVerifyRequest.__name__])
+    @setup_required
+    @login_required
+    @edit_permission_required
+    @account_initialization_required
+    def post(self, provider, subscription_id):
+        """Verify credentials for an existing subscription (edit mode only)"""
+        user = current_user
+        assert user.current_tenant_id is not None
+
+        verify_request: TriggerSubscriptionVerifyRequest = TriggerSubscriptionVerifyRequest.model_validate(
+            console_ns.payload
+        )
+
+        try:
+            result = TriggerProviderService.verify_subscription_credentials(
+                tenant_id=user.current_tenant_id,
+                user_id=user.id,
+                provider_id=TriggerProviderID(provider),
+                subscription_id=subscription_id,
+                credentials=verify_request.credentials,
+            )
+            return result
+        except ValueError as e:
+            logger.warning("Credential verification failed", exc_info=e)
+            raise BadRequest(str(e)) from e
+        except Exception as e:
+            logger.exception("Error verifying subscription credentials", exc_info=e)
+            raise BadRequest(str(e)) from e
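The credential merge inside `TriggerSubscriptionUpdateApi` deserves a second look: the UI echoes masked secrets back as `HIDDEN_VALUE`, and the dict comprehension swaps those placeholders for the stored values so an edit never overwrites a secret with its mask. The same comprehension, runnable in isolation; only the constant *names* come from `constants.py`, their values below are assumptions:

```python
# Assumed placeholder strings -- the real constants live in api/constants/
# and their exact values are not shown in this diff.
HIDDEN_VALUE = "[__HIDDEN__]"
UNKNOWN_VALUE = "[__UNKNOWN__]"

stored = {"api_key": "sk-real-secret", "region": "eu-west"}
incoming = {"api_key": HIDDEN_VALUE, "region": "us-east"}  # user edited region only

new_credentials = {
    key: value if value != HIDDEN_VALUE else stored.get(key, UNKNOWN_VALUE)
    for key, value in incoming.items()
}
print(new_credentials)  # {'api_key': 'sk-real-secret', 'region': 'us-east'}
```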
@@ -7,6 +7,7 @@ from werkzeug.exceptions import NotFound

 import services
 from controllers.common.errors import UnsupportedFileTypeError
+from controllers.common.file_response import enforce_download_for_html
 from controllers.files import files_ns
 from extensions.ext_database import db
 from services.account_service import TenantService
@@ -138,6 +139,13 @@ class FilePreviewApi(Resource):
             response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
             response.headers["Content-Type"] = "application/octet-stream"

+        enforce_download_for_html(
+            response,
+            mime_type=upload_file.mime_type,
+            filename=upload_file.name,
+            extension=upload_file.extension,
+        )
+
         return response
@@ -6,6 +6,7 @@ from pydantic import BaseModel, Field
 from werkzeug.exceptions import Forbidden, NotFound

 from controllers.common.errors import UnsupportedFileTypeError
+from controllers.common.file_response import enforce_download_for_html
 from controllers.files import files_ns
 from core.tools.signature import verify_tool_file_signature
 from core.tools.tool_file_manager import ToolFileManager
@@ -78,4 +79,11 @@ class ToolFileApi(Resource):
             encoded_filename = quote(tool_file.name)
             response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"

+        enforce_download_for_html(
+            response,
+            mime_type=tool_file.mimetype,
+            filename=tool_file.name,
+            extension=extension,
+        )
+
         return response
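Both file endpoints above now route their responses through `enforce_download_for_html` after setting `Content-Disposition`. Its body lives in `controllers/common/file_response.py` and is not part of this diff, but the call sites make the intent clear: HTML-like uploads must never render inline, or a hosted file becomes a stored-XSS vector. A hedged sketch of what such a guard can look like:

```python
# Sketch under assumptions: the real enforce_download_for_html is not shown
# in this diff; this reproduces only the behavior implied by its call sites.
from urllib.parse import quote

HTML_LIKE_MIME = {"text/html", "application/xhtml+xml"}
HTML_LIKE_EXT = {"html", "htm", "xhtml"}


def enforce_download_for_html_sketch(
    headers: dict[str, str], *, mime_type: str | None, filename: str, extension: str | None
) -> None:
    ext = (extension or "").lstrip(".").lower()
    if (mime_type or "").lower() in HTML_LIKE_MIME or ext in HTML_LIKE_EXT:
        # Force a download instead of inline rendering in the browser.
        headers["Content-Type"] = "application/octet-stream"
        headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{quote(filename)}"


headers: dict[str, str] = {"Content-Type": "text/html"}
enforce_download_for_html_sketch(headers, mime_type="text/html", filename="report.html", extension=".html")
print(headers)  # rewritten to application/octet-stream + attachment
```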
@@ -4,7 +4,7 @@ from uuid import UUID
 from flask import request
 from flask_restx import Resource
 from flask_restx._http import HTTPStatus
-from pydantic import BaseModel, Field, model_validator
+from pydantic import BaseModel, Field, field_validator, model_validator
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import BadRequest, NotFound
@@ -51,6 +51,32 @@ class ConversationRenamePayload(BaseModel):
 class ConversationVariablesQuery(BaseModel):
     last_id: UUID | None = Field(default=None, description="Last variable ID for pagination")
     limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return")
+    variable_name: str | None = Field(
+        default=None, description="Filter variables by name", min_length=1, max_length=255
+    )
+
+    @field_validator("variable_name", mode="before")
+    @classmethod
+    def validate_variable_name(cls, v: str | None) -> str | None:
+        """
+        Validate variable_name to prevent injection attacks.
+        """
+        if v is None:
+            return v
+
+        # Only allow safe characters: alphanumeric, underscore, hyphen, period
+        if not v.replace("-", "").replace("_", "").replace(".", "").isalnum():
+            raise ValueError(
+                "Variable name can only contain letters, numbers, hyphens (-), underscores (_), and periods (.)"
+            )
+
+        # Prevent SQL injection patterns
+        dangerous_patterns = ["'", '"', ";", "--", "/*", "*/", "xp_", "sp_"]
+        for pattern in dangerous_patterns:
+            if pattern in v.lower():
+                raise ValueError(f"Variable name contains invalid characters: {pattern}")
+
+        return v


 class ConversationVariableUpdatePayload(BaseModel):
@@ -199,7 +225,7 @@ class ConversationVariablesApi(Resource):

         try:
             return ConversationService.get_conversational_variable(
-                app_model, conversation_id, end_user, query_args.limit, last_id
+                app_model, conversation_id, end_user, query_args.limit, last_id, query_args.variable_name
             )
         except services.errors.conversation.ConversationNotExistsError:
             raise NotFound("Conversation Not Exists.")
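The new `variable_name` filter flows into a database query, so the hunk pairs it with a `field_validator`. The character allow-list is the primary guard; the `dangerous_patterns` loop additionally rejects sequences that survive it, such as doubled hyphens and `xp_`/`sp_` prefixes, since hyphens and underscores are stripped before the `isalnum` check. The validator behaves like this in isolation:

```python
# The validator from the hunk above, exercised standalone.
from pydantic import BaseModel, Field, ValidationError, field_validator


class ConversationVariablesQuery(BaseModel):
    variable_name: str | None = Field(default=None, min_length=1, max_length=255)

    @field_validator("variable_name", mode="before")
    @classmethod
    def validate_variable_name(cls, v: str | None) -> str | None:
        if v is None:
            return v
        if not v.replace("-", "").replace("_", "").replace(".", "").isalnum():
            raise ValueError("only letters, numbers, '-', '_', '.' are allowed")
        dangerous_patterns = ["'", '"', ";", "--", "/*", "*/", "xp_", "sp_"]
        for pattern in dangerous_patterns:
            if pattern in v.lower():
                raise ValueError(f"contains invalid sequence: {pattern}")
        return v


print(ConversationVariablesQuery(variable_name="user_name.v2"))  # accepted

for bad in ["name'; DROP TABLE t --", "a--b", "xp_cmdshell"]:
    try:
        ConversationVariablesQuery(variable_name=bad)
    except ValidationError:
        print(f"rejected: {bad}")
```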
@@ -5,6 +5,7 @@ from flask import Response, request
 from flask_restx import Resource
 from pydantic import BaseModel, Field

+from controllers.common.file_response import enforce_download_for_html
 from controllers.common.schema import register_schema_model
 from controllers.service_api import service_api_ns
 from controllers.service_api.app.error import (
@@ -183,6 +184,13 @@ class FilePreviewApi(Resource):
             # Override content-type for downloads to force download
             response.headers["Content-Type"] = "application/octet-stream"

+        enforce_download_for_html(
+            response,
+            mime_type=upload_file.mime_type,
+            filename=upload_file.name,
+            extension=upload_file.extension,
+        )
+
         # Add caching headers for performance
         response.headers["Cache-Control"] = "public, max-age=3600"  # Cache for 1 hour
@@ -13,7 +13,6 @@ from controllers.service_api.dataset.error import DatasetInUseError, DatasetName
 from controllers.service_api.wraps import (
     DatasetApiResource,
     cloud_edition_billing_rate_limit_check,
-    validate_dataset_token,
 )
 from core.model_runtime.entities.model_entities import ModelType
 from core.provider_manager import ProviderManager
@@ -460,9 +459,8 @@ class DatasetTagsApi(DatasetApiResource):
             401: "Unauthorized - invalid API token",
         }
     )
-    @validate_dataset_token
     @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
-    def get(self, _, dataset_id):
+    def get(self, _):
         """Get all knowledge type tags."""
         assert isinstance(current_user, Account)
         cid = current_user.current_tenant_id
@@ -482,8 +480,7 @@ class DatasetTagsApi(DatasetApiResource):
         }
     )
     @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
-    @validate_dataset_token
-    def post(self, _, dataset_id):
+    def post(self, _):
         """Add a knowledge type tag."""
         assert isinstance(current_user, Account)
         if not (current_user.has_edit_permission or current_user.is_dataset_editor):
@@ -506,8 +503,7 @@ class DatasetTagsApi(DatasetApiResource):
         }
     )
     @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
-    @validate_dataset_token
-    def patch(self, _, dataset_id):
+    def patch(self, _):
         assert isinstance(current_user, Account)
         if not (current_user.has_edit_permission or current_user.is_dataset_editor):
             raise Forbidden()
@@ -533,9 +529,8 @@ class DatasetTagsApi(DatasetApiResource):
             403: "Forbidden - insufficient permissions",
         }
     )
-    @validate_dataset_token
     @edit_permission_required
-    def delete(self, _, dataset_id):
+    def delete(self, _):
         """Delete a knowledge type tag."""
         payload = TagDeletePayload.model_validate(service_api_ns.payload or {})
         TagService.delete_tag(payload.tag_id)
@@ -555,8 +550,7 @@ class DatasetTagBindingApi(DatasetApiResource):
             403: "Forbidden - insufficient permissions",
         }
     )
-    @validate_dataset_token
-    def post(self, _, dataset_id):
+    def post(self, _):
         # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
         assert isinstance(current_user, Account)
         if not (current_user.has_edit_permission or current_user.is_dataset_editor):
@@ -580,8 +574,7 @@ class DatasetTagUnbindingApi(DatasetApiResource):
             403: "Forbidden - insufficient permissions",
         }
     )
-    @validate_dataset_token
-    def post(self, _, dataset_id):
+    def post(self, _):
         # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
         assert isinstance(current_user, Account)
         if not (current_user.has_edit_permission or current_user.is_dataset_editor):
@@ -604,7 +597,6 @@ class DatasetTagsBindingStatusApi(DatasetApiResource):
             401: "Unauthorized - invalid API token",
         }
     )
-    @validate_dataset_token
     def get(self, _, *args, **kwargs):
         """Get all knowledge type tags."""
         dataset_id = kwargs.get("dataset_id")
@@ -1,14 +1,13 @@
 import logging

 from flask import request
-from flask_restx import Resource, marshal_with, reqparse
+from flask_restx import Resource, marshal_with
+from pydantic import BaseModel, ConfigDict, Field
 from werkzeug.exceptions import Unauthorized

 from constants import HEADER_NAME_APP_CODE
 from controllers.common import fields
-from controllers.web import web_ns
-from controllers.web.error import AppUnavailableError
-from controllers.web.wraps import WebApiResource
+from controllers.common.schema import register_schema_models
 from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
 from libs.passport import PassportService
 from libs.token import extract_webapp_passport
@@ -18,9 +17,23 @@ from services.enterprise.enterprise_service import EnterpriseService
 from services.feature_service import FeatureService
 from services.webapp_auth_service import WebAppAuthService

+from . import web_ns
+from .error import AppUnavailableError
+from .wraps import WebApiResource
+
 logger = logging.getLogger(__name__)


+class AppAccessModeQuery(BaseModel):
+    model_config = ConfigDict(populate_by_name=True)
+
+    app_id: str | None = Field(default=None, alias="appId", description="Application ID")
+    app_code: str | None = Field(default=None, alias="appCode", description="Application code")
+
+
+register_schema_models(web_ns, AppAccessModeQuery)
+
+
 @web_ns.route("/parameters")
 class AppParameterApi(WebApiResource):
     """Resource for app variables."""
@@ -96,21 +109,16 @@ class AppAccessMode(Resource):
         }
     )
     def get(self):
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("appId", type=str, required=False, location="args")
-            .add_argument("appCode", type=str, required=False, location="args")
-        )
-        args = parser.parse_args()
+        raw_args = request.args.to_dict()
+        args = AppAccessModeQuery.model_validate(raw_args)

         features = FeatureService.get_system_features()
         if not features.webapp_auth.enabled:
             return {"accessMode": "public"}

-        app_id = args.get("appId")
-        if args.get("appCode"):
-            app_code = args["appCode"]
-            app_id = AppService.get_app_id_by_code(app_code)
+        app_id = args.app_id
+        if args.app_code:
+            app_id = AppService.get_app_id_by_code(args.app_code)

         if not app_id:
             raise ValueError("appId or appCode must be provided")
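`AppAccessModeQuery` leans on two pydantic features to keep the public API camelCase while the code stays snake_case: per-field `alias` values and `populate_by_name=True`, which lets either spelling validate. The model from the hunk, exercised directly:

```python
from pydantic import BaseModel, ConfigDict, Field


class AppAccessModeQuery(BaseModel):
    model_config = ConfigDict(populate_by_name=True)

    app_id: str | None = Field(default=None, alias="appId", description="Application ID")
    app_code: str | None = Field(default=None, alias="appCode", description="Application code")


# Query strings arrive camelCase, exactly what request.args.to_dict() yields:
print(AppAccessModeQuery.model_validate({"appCode": "abc123"}).app_code)   # 'abc123'

# populate_by_name=True also accepts the snake_case field name (handy in tests):
print(AppAccessModeQuery.model_validate({"app_code": "abc123"}).app_code)  # 'abc123'
```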
@@ -2,10 +2,12 @@ import base64
 import secrets

 from flask import request
-from flask_restx import Resource, reqparse
+from flask_restx import Resource
+from pydantic import BaseModel, Field, field_validator
 from sqlalchemy import select
 from sqlalchemy.orm import Session

+from controllers.common.schema import register_schema_models
 from controllers.console.auth.error import (
     AuthenticationFailedError,
     EmailCodeError,
@@ -18,14 +20,40 @@ from controllers.console.error import EmailSendIpLimitError
 from controllers.console.wraps import email_password_login_enabled, only_edition_enterprise, setup_required
 from controllers.web import web_ns
 from extensions.ext_database import db
-from libs.helper import email, extract_remote_ip
+from libs.helper import EmailStr, extract_remote_ip
 from libs.password import hash_password, valid_password
 from models import Account
 from services.account_service import AccountService


+class ForgotPasswordSendPayload(BaseModel):
+    email: EmailStr
+    language: str | None = None
+
+
+class ForgotPasswordCheckPayload(BaseModel):
+    email: EmailStr
+    code: str
+    token: str = Field(min_length=1)
+
+
+class ForgotPasswordResetPayload(BaseModel):
+    token: str = Field(min_length=1)
+    new_password: str
+    password_confirm: str
+
+    @field_validator("new_password", "password_confirm")
+    @classmethod
+    def validate_password(cls, value: str) -> str:
+        return valid_password(value)
+
+
+register_schema_models(web_ns, ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload)
+
+
 @web_ns.route("/forgot-password")
 class ForgotPasswordSendEmailApi(Resource):
+    @web_ns.expect(web_ns.models[ForgotPasswordSendPayload.__name__])
     @only_edition_enterprise
     @setup_required
     @email_password_login_enabled
@@ -40,35 +68,31 @@ class ForgotPasswordSendEmailApi(Resource):
         }
     )
     def post(self):
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("email", type=email, required=True, location="json")
-            .add_argument("language", type=str, required=False, location="json")
-        )
-        args = parser.parse_args()
+        payload = ForgotPasswordSendPayload.model_validate(web_ns.payload or {})

         ip_address = extract_remote_ip(request)
         if AccountService.is_email_send_ip_limit(ip_address):
             raise EmailSendIpLimitError()

-        if args["language"] is not None and args["language"] == "zh-Hans":
+        if payload.language == "zh-Hans":
             language = "zh-Hans"
         else:
             language = "en-US"

         with Session(db.engine) as session:
-            account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
+            account = session.execute(select(Account).filter_by(email=payload.email)).scalar_one_or_none()
             token = None
             if account is None:
                 raise AuthenticationFailedError()
             else:
-                token = AccountService.send_reset_password_email(account=account, email=args["email"], language=language)
+                token = AccountService.send_reset_password_email(account=account, email=payload.email, language=language)

         return {"result": "success", "data": token}


 @web_ns.route("/forgot-password/validity")
 class ForgotPasswordCheckApi(Resource):
+    @web_ns.expect(web_ns.models[ForgotPasswordCheckPayload.__name__])
     @only_edition_enterprise
     @setup_required
     @email_password_login_enabled
@@ -78,45 +102,40 @@ class ForgotPasswordCheckApi(Resource):
         responses={200: "Token is valid", 400: "Bad request - invalid token format", 401: "Invalid or expired token"}
     )
     def post(self):
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("email", type=str, required=True, location="json")
-            .add_argument("code", type=str, required=True, location="json")
-            .add_argument("token", type=str, required=True, nullable=False, location="json")
-        )
-        args = parser.parse_args()
+        payload = ForgotPasswordCheckPayload.model_validate(web_ns.payload or {})

-        user_email = args["email"]
+        user_email = payload.email

-        is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args["email"])
+        is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(payload.email)
         if is_forgot_password_error_rate_limit:
             raise EmailPasswordResetLimitError()

-        token_data = AccountService.get_reset_password_data(args["token"])
+        token_data = AccountService.get_reset_password_data(payload.token)
         if token_data is None:
             raise InvalidTokenError()

         if user_email != token_data.get("email"):
             raise InvalidEmailError()

-        if args["code"] != token_data.get("code"):
-            AccountService.add_forgot_password_error_rate_limit(args["email"])
+        if payload.code != token_data.get("code"):
+            AccountService.add_forgot_password_error_rate_limit(payload.email)
             raise EmailCodeError()

         # Verified, revoke the first token
-        AccountService.revoke_reset_password_token(args["token"])
+        AccountService.revoke_reset_password_token(payload.token)

         # Refresh token data by generating a new token
         _, new_token = AccountService.generate_reset_password_token(
-            user_email, code=args["code"], additional_data={"phase": "reset"}
+            user_email, code=payload.code, additional_data={"phase": "reset"}
         )

-        AccountService.reset_forgot_password_error_rate_limit(args["email"])
+        AccountService.reset_forgot_password_error_rate_limit(payload.email)
         return {"is_valid": True, "email": token_data.get("email"), "token": new_token}


 @web_ns.route("/forgot-password/resets")
 class ForgotPasswordResetApi(Resource):
+    @web_ns.expect(web_ns.models[ForgotPasswordResetPayload.__name__])
     @only_edition_enterprise
     @setup_required
     @email_password_login_enabled
@@ -131,20 +150,14 @@ class ForgotPasswordResetApi(Resource):
         }
     )
     def post(self):
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("token", type=str, required=True, nullable=False, location="json")
-            .add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
-            .add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
-        )
-        args = parser.parse_args()
+        payload = ForgotPasswordResetPayload.model_validate(web_ns.payload or {})

         # Validate passwords match
-        if args["new_password"] != args["password_confirm"]:
+        if payload.new_password != payload.password_confirm:
             raise PasswordMismatchError()

         # Validate token and get reset data
-        reset_data = AccountService.get_reset_password_data(args["token"])
+        reset_data = AccountService.get_reset_password_data(payload.token)
         if not reset_data:
             raise InvalidTokenError()
         # Must use token in reset phase
@@ -152,11 +165,11 @@ class ForgotPasswordResetApi(Resource):
             raise InvalidTokenError()

         # Revoke token to prevent reuse
-        AccountService.revoke_reset_password_token(args["token"])
+        AccountService.revoke_reset_password_token(payload.token)

         # Generate secure salt and hash password
         salt = secrets.token_bytes(16)
-        password_hashed = hash_password(args["new_password"], salt)
+        password_hashed = hash_password(payload.new_password, salt)

         email = reset_data.get("email", "")
|
||||||
|
|
||||||
|
|
@ -170,7 +183,7 @@ class ForgotPasswordResetApi(Resource):
|
||||||
|
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
||||||
def _update_existing_account(self, account, password_hashed, salt, session):
|
def _update_existing_account(self, account: Account, password_hashed, salt, session):
|
||||||
# Update existing account credentials
|
# Update existing account credentials
|
||||||
account.password = base64.b64encode(password_hashed).decode()
|
account.password = base64.b64encode(password_hashed).decode()
|
||||||
account.password_salt = base64.b64encode(salt).decode()
|
account.password_salt = base64.b64encode(salt).decode()
|
||||||
|
|
|
||||||
|
|
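The reqparse-to-Pydantic migration above follows one pattern throughout: a `BaseModel` describes the JSON body, and `model_validate(web_ns.payload or {})` replaces `parser.parse_args()`. A minimal sketch of that pattern — the field descriptions are illustrative, and the mapping of validation failures to HTTP 400 is handled by framework error handlers in the real controllers:

```python
from pydantic import BaseModel, Field, ValidationError

class ForgotPasswordCheckPayload(BaseModel):
    email: str = Field(description="Account email")
    code: str = Field(description="Verification code")
    token: str = Field(description="Reset-flow token")

raw = {"email": "user@example.com", "code": "123456", "token": "abc"}
try:
    payload = ForgotPasswordCheckPayload.model_validate(raw)  # replaces parser.parse_args()
except ValidationError as exc:
    # reqparse aborted with a 400 automatically; with Pydantic an error
    # handler (or the caller) decides how the failure maps to HTTP.
    print(exc.errors())
else:
    print(payload.email)
```

Note the `web_ns.payload or {}` guard: a missing request body surfaces as field-level validation errors rather than an attribute error.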
```diff
@@ -1,9 +1,12 @@
 import logging
+from typing import Literal
 
-from flask_restx import fields, marshal_with, reqparse
-from flask_restx.inputs import int_range
+from flask import request
+from flask_restx import fields, marshal_with
+from pydantic import BaseModel, Field, field_validator
 from werkzeug.exceptions import InternalServerError, NotFound
 
+from controllers.common.schema import register_schema_models
 from controllers.web import web_ns
 from controllers.web.error import (
     AppMoreLikeThisDisabledError,
@@ -38,6 +41,33 @@ from services.message_service import MessageService
 logger = logging.getLogger(__name__)
 
 
+class MessageListQuery(BaseModel):
+    conversation_id: str = Field(description="Conversation UUID")
+    first_id: str | None = Field(default=None, description="First message ID for pagination")
+    limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)")
+
+    @field_validator("conversation_id", "first_id")
+    @classmethod
+    def validate_uuid(cls, value: str | None) -> str | None:
+        if value is None:
+            return value
+        return uuid_value(value)
+
+
+class MessageFeedbackPayload(BaseModel):
+    rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
+    content: str | None = Field(default=None, description="Feedback content")
+
+
+class MessageMoreLikeThisQuery(BaseModel):
+    response_mode: Literal["blocking", "streaming"] = Field(
+        description="Response mode",
+    )
+
+
+register_schema_models(web_ns, MessageListQuery, MessageFeedbackPayload, MessageMoreLikeThisQuery)
+
+
 @web_ns.route("/messages")
 class MessageListApi(WebApiResource):
     message_fields = {
@@ -68,7 +98,11 @@ class MessageListApi(WebApiResource):
     @web_ns.doc(
         params={
             "conversation_id": {"description": "Conversation UUID", "type": "string", "required": True},
-            "first_id": {"description": "First message ID for pagination", "type": "string", "required": False},
+            "first_id": {
+                "description": "First message ID for pagination",
+                "type": "string",
+                "required": False,
+            },
             "limit": {
                 "description": "Number of messages to return (1-100)",
                 "type": "integer",
@@ -93,17 +127,12 @@ class MessageListApi(WebApiResource):
         if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
 
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("conversation_id", required=True, type=uuid_value, location="args")
-            .add_argument("first_id", type=uuid_value, location="args")
-            .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
-        )
-        args = parser.parse_args()
+        raw_args = request.args.to_dict()
+        query = MessageListQuery.model_validate(raw_args)
 
         try:
             return MessageService.pagination_by_first_id(
-                app_model, end_user, args["conversation_id"], args["first_id"], args["limit"]
+                app_model, end_user, query.conversation_id, query.first_id, query.limit
             )
         except ConversationNotExistsError:
             raise NotFound("Conversation Not Exists.")
@@ -128,7 +157,7 @@ class MessageFeedbackApi(WebApiResource):
                 "enum": ["like", "dislike"],
                 "required": False,
             },
-            "content": {"description": "Feedback content/comment", "type": "string", "required": False},
+            "content": {"description": "Feedback content", "type": "string", "required": False},
         }
     )
     @web_ns.doc(
@@ -145,20 +174,15 @@ class MessageFeedbackApi(WebApiResource):
     def post(self, app_model, end_user, message_id):
         message_id = str(message_id)
 
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
-            .add_argument("content", type=str, location="json", default=None)
-        )
-        args = parser.parse_args()
+        payload = MessageFeedbackPayload.model_validate(web_ns.payload or {})
 
         try:
             MessageService.create_feedback(
                 app_model=app_model,
                 message_id=message_id,
                 user=end_user,
-                rating=args.get("rating"),
-                content=args.get("content"),
+                rating=payload.rating,
+                content=payload.content,
             )
         except MessageNotExistsError:
             raise NotFound("Message Not Exists.")
@@ -170,17 +194,7 @@ class MessageFeedbackApi(WebApiResource):
 class MessageMoreLikeThisApi(WebApiResource):
     @web_ns.doc("Generate More Like This")
     @web_ns.doc(description="Generate a new completion similar to an existing message (completion apps only).")
-    @web_ns.doc(
-        params={
-            "message_id": {"description": "Message UUID", "type": "string", "required": True},
-            "response_mode": {
-                "description": "Response mode",
-                "type": "string",
-                "enum": ["blocking", "streaming"],
-                "required": True,
-            },
-        }
-    )
+    @web_ns.expect(web_ns.models[MessageMoreLikeThisQuery.__name__])
     @web_ns.doc(
         responses={
             200: "Success",
@@ -197,12 +211,10 @@ class MessageMoreLikeThisApi(WebApiResource):
 
         message_id = str(message_id)
 
-        parser = reqparse.RequestParser().add_argument(
-            "response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args"
-        )
-        args = parser.parse_args()
+        raw_args = request.args.to_dict()
+        query = MessageMoreLikeThisQuery.model_validate(raw_args)
 
-        streaming = args["response_mode"] == "streaming"
+        streaming = query.response_mode == "streaming"
 
         try:
             response = AppGenerateService.generate_more_like_this(
```
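The query-string models work the same way, except the raw input comes from `request.args.to_dict()`, which yields only strings. A sketch of the coercion this relies on (the UUID validator from the model above is omitted here for brevity):

```python
from pydantic import BaseModel, Field

class MessageListQuery(BaseModel):
    conversation_id: str
    first_id: str | None = None
    limit: int = Field(default=20, ge=1, le=100)

# request.args.to_dict() produces only strings, e.g. {"limit": "20"}
query = MessageListQuery.model_validate({"conversation_id": "abc123", "limit": "20"})
assert query.limit == 20  # "20" coerced to int; ge/le replaces int_range(1, 100)
assert query.first_id is None
```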
```diff
@@ -1,7 +1,8 @@
 import urllib.parse
 
 import httpx
-from flask_restx import marshal_with, reqparse
+from flask_restx import marshal_with
+from pydantic import BaseModel, Field, HttpUrl
 
 import services
 from controllers.common import helpers
@@ -10,14 +11,23 @@ from controllers.common.errors import (
     RemoteFileUploadError,
     UnsupportedFileTypeError,
 )
-from controllers.web import web_ns
-from controllers.web.wraps import WebApiResource
 from core.file import helpers as file_helpers
 from core.helper import ssrf_proxy
 from extensions.ext_database import db
 from fields.file_fields import build_file_with_signed_url_model, build_remote_file_info_model
 from services.file_service import FileService
 
+from ..common.schema import register_schema_models
+from . import web_ns
+from .wraps import WebApiResource
+
+
+class RemoteFileUploadPayload(BaseModel):
+    url: HttpUrl = Field(description="Remote file URL")
+
+
+register_schema_models(web_ns, RemoteFileUploadPayload)
+
 
 @web_ns.route("/remote-files/<path:url>")
 class RemoteFileInfoApi(WebApiResource):
@@ -97,10 +107,8 @@ class RemoteFileUploadApi(WebApiResource):
             FileTooLargeError: File exceeds size limit
             UnsupportedFileTypeError: File type not supported
         """
-        parser = reqparse.RequestParser().add_argument("url", type=str, required=True, help="URL is required")
-        args = parser.parse_args()
-
-        url = args["url"]
+        payload = RemoteFileUploadPayload.model_validate(web_ns.payload or {})
+        url = str(payload.url)
 
         try:
             resp = ssrf_proxy.head(url=url)
```
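A minimal sketch of why the handler converts with `str(payload.url)`: in Pydantic v2, `HttpUrl` validates scheme and host up front but is not a plain `str`, so it must be stringified before being handed to HTTP helpers. Validation may also normalize the URL (for example, adding a trailing slash to a bare host):

```python
from pydantic import BaseModel, HttpUrl

class RemoteFileUploadPayload(BaseModel):
    url: HttpUrl

payload = RemoteFileUploadPayload.model_validate({"url": "https://example.com/file.png"})
url = str(payload.url)  # back to a plain string for the ssrf_proxy client
assert url == "https://example.com/file.png"
```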
```diff
@@ -105,8 +105,9 @@ class BaseAppGenerator:
             variable_entity.type in {VariableEntityType.FILE, VariableEntityType.FILE_LIST}
             and not variable_entity.required
         ):
-            # Treat empty string (frontend default) or empty list as unset
-            if not value and isinstance(value, (str, list)):
+            # Treat empty string (frontend default) as unset
+            # For FILE_LIST, allow empty list [] to pass through
+            if isinstance(value, str) and not value:
                 return None
 
         if variable_entity.type in {
```
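The behavior change above is narrow but easy to miss: the old check `not value and isinstance(value, (str, list))` collapsed both `""` and `[]` to `None`, so an explicitly empty FILE_LIST was indistinguishable from an unset one. A sketch of the new semantics:

```python
# Minimal sketch of the new check: only the empty string (the frontend's
# default for an unset FILE input) is treated as unset; an explicit empty
# FILE_LIST value [] now passes through unchanged.
def normalize_optional_file_input(value):
    if isinstance(value, str) and not value:
        return None
    return value

assert normalize_optional_file_input("") is None
assert normalize_optional_file_input([]) == []  # kept, no longer coerced to None
```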
```diff
@@ -90,6 +90,7 @@ class AppQueueManager:
         """
         self._clear_task_belong_cache()
         self._q.put(None)
+        self._graph_runtime_state = None  # Release reference to allow GC to reclaim memory
 
     def _clear_task_belong_cache(self) -> None:
         """
```
```diff
@@ -1,9 +1,14 @@
+from collections.abc import Mapping
 from textwrap import dedent
+from typing import Any
 
 from core.helper.code_executor.template_transformer import TemplateTransformer
 
 
 class Jinja2TemplateTransformer(TemplateTransformer):
+    # Use separate placeholder for base64-encoded template to avoid confusion
+    _template_b64_placeholder: str = "{{template_b64}}"
+
     @classmethod
     def transform_response(cls, response: str):
         """
@@ -13,18 +18,35 @@ class Jinja2TemplateTransformer(TemplateTransformer):
         """
         return {"result": cls.extract_result_str_from_response(response)}
 
+    @classmethod
+    def assemble_runner_script(cls, code: str, inputs: Mapping[str, Any]) -> str:
+        """
+        Override base class to use base64 encoding for template code.
+        This prevents issues with special characters (quotes, newlines) in templates
+        breaking the generated Python script. Fixes #26818.
+        """
+        script = cls.get_runner_script()
+        # Encode template as base64 to safely embed any content including quotes
+        code_b64 = cls.serialize_code(code)
+        script = script.replace(cls._template_b64_placeholder, code_b64)
+        inputs_str = cls.serialize_inputs(inputs)
+        script = script.replace(cls._inputs_placeholder, inputs_str)
+        return script
+
     @classmethod
     def get_runner_script(cls) -> str:
         runner_script = dedent(f"""
-        # declare main function
-        def main(**inputs):
-            import jinja2
-            template = jinja2.Template('''{cls._code_placeholder}''')
-            return template.render(**inputs)
-
         import json
         from base64 import b64decode
 
+        # declare main function
+        def main(**inputs):
+            import jinja2
+            # Decode base64-encoded template to handle special characters safely
+            template_code = b64decode('{cls._template_b64_placeholder}').decode('utf-8')
+            template = jinja2.Template(template_code)
+            return template.render(**inputs)
+
         # decode and prepare input dict
         inputs_obj = json.loads(b64decode('{cls._inputs_placeholder}').decode('utf-8'))
```
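The motivation for the base64 placeholder: interpolating raw template text into a `'''...'''` literal breaks as soon as the template itself contains a triple quote or an awkward escape. A self-contained sketch of the failure mode and the fix (names here are illustrative, not the transformer's API):

```python
from base64 import b64decode, b64encode

template = "Hello ''' {{ name }}"  # would terminate the old '''...''' literal early
code_b64 = b64encode(template.encode("utf-8")).decode("utf-8")

# base64 output is [A-Za-z0-9+/=], so it is always safe inside single quotes
runner = f"""
from base64 import b64decode
template_code = b64decode('{code_b64}').decode('utf-8')
print(template_code)
"""
exec(compile(runner, "<runner>", "exec"))  # prints the template intact
```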
```diff
@@ -13,6 +13,15 @@ class TemplateTransformer(ABC):
     _inputs_placeholder: str = "{{inputs}}"
     _result_tag: str = "<<RESULT>>"
 
+    @classmethod
+    def serialize_code(cls, code: str) -> str:
+        """
+        Serialize template code to base64 to safely embed in generated script.
+        This prevents issues with special characters like quotes breaking the script.
+        """
+        code_bytes = code.encode("utf-8")
+        return b64encode(code_bytes).decode("utf-8")
+
     @classmethod
     def transform_caller(cls, code: str, inputs: Mapping[str, Any]) -> tuple[str, str]:
         """
```
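A quick round-trip property check for the new helper: any template text, including quotes and newlines, survives the embed/decode cycle unchanged. This is a standalone re-implementation for illustration, not the class itself:

```python
from base64 import b64decode, b64encode

def serialize_code(code: str) -> str:
    # mirrors the classmethod above: utf-8 encode, then base64
    return b64encode(code.encode("utf-8")).decode("utf-8")

original = "line1\n'''line2'''\n\"line3\""
assert b64decode(serialize_code(original)).decode("utf-8") == original
```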
```diff
@@ -72,6 +72,22 @@ def _get_ssrf_client(ssl_verify_enabled: bool) -> httpx.Client:
     )
 
 
+def _get_user_provided_host_header(headers: dict | None) -> str | None:
+    """
+    Extract the user-provided Host header from the headers dict.
+
+    This is needed because when using a forward proxy, httpx may override the Host header.
+    We preserve the user's explicit Host header to support virtual hosting and other use cases.
+    """
+    if not headers:
+        return None
+    # Case-insensitive lookup for Host header
+    for key, value in headers.items():
+        if key.lower() == "host":
+            return value
+    return None
+
+
 def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
     if "allow_redirects" in kwargs:
         allow_redirects = kwargs.pop("allow_redirects")
@@ -90,10 +106,24 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
     verify_option = kwargs.pop("ssl_verify", dify_config.HTTP_REQUEST_NODE_SSL_VERIFY)
     client = _get_ssrf_client(verify_option)
 
+    # Preserve user-provided Host header
+    # When using a forward proxy, httpx may override the Host header based on the URL.
+    # We extract and preserve any explicitly set Host header to support virtual hosting.
+    headers = kwargs.get("headers", {})
+    user_provided_host = _get_user_provided_host_header(headers)
+
     retries = 0
     while retries <= max_retries:
         try:
+            # Build the request manually to preserve the Host header
+            # httpx may override the Host header when using a proxy, so we use
+            # the request API to explicitly set headers before sending
+            headers = {k: v for k, v in headers.items() if k.lower() != "host"}
+            if user_provided_host is not None:
+                headers["host"] = user_provided_host
+            kwargs["headers"] = headers
             response = client.request(method=method, url=url, **kwargs)
 
             # Check for SSRF protection by Squid proxy
             if response.status_code in (401, 403):
                 # Check if this is a Squid SSRF rejection
```
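The helper and the per-request header rewrite exist because HTTP header names are case-insensitive: a caller may send `Host`, `host`, or `HOST`, and httpx may replace whichever spelling it finds when routing through a forward proxy. A standalone sketch of the lookup:

```python
def get_user_provided_host(headers: dict | None) -> str | None:
    # Case-insensitive scan; returns the first Host value, leaves others untouched
    if not headers:
        return None
    for key, value in headers.items():
        if key.lower() == "host":
            return value
    return None

assert get_user_provided_host({"HOST": "internal.example"}) == "internal.example"
assert get_user_provided_host({"Accept": "*/*"}) is None
```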
```diff
@@ -1,6 +1,6 @@
 import json
 import logging
-from typing import Any
+from typing import Any, cast
 
 from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
 from extensions.ext_redis import redis_client, redis_fallback
@@ -50,7 +50,9 @@ class ToolProviderListCache:
             redis_client.delete(cache_key)
         else:
             # Invalidate all caches for this tenant
-            pattern = f"tool_providers:tenant_id:{tenant_id}:*"
-            keys = list(redis_client.scan_iter(pattern))
-            if keys:
-                redis_client.delete(*keys)
+            keys = ["builtin", "model", "api", "workflow", "mcp"]
+            pipeline = redis_client.pipeline()
+            for key in keys:
+                cache_key = ToolProviderListCache._generate_cache_key(tenant_id, cast(ToolProviderTypeApiLiteral, key))
+                pipeline.delete(cache_key)
+            pipeline.execute()
```
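Enumerating the five provider types and deleting them in one pipeline trades the old `SCAN` (which walks the keyspace and can be slow on a busy Redis) for a fixed number of `DEL`s in a single round trip. A sketch with a redis-py style client; the exact key suffix produced by `_generate_cache_key` is an assumption here:

```python
PROVIDER_TYPES = ["builtin", "model", "api", "workflow", "mcp"]

def invalidate_tenant_tool_cache(redis_client, tenant_id: str) -> None:
    pipeline = redis_client.pipeline()
    for provider_type in PROVIDER_TYPES:
        # key layout assumed to mirror _generate_cache_key
        pipeline.delete(f"tool_providers:tenant_id:{tenant_id}:{provider_type}")
    pipeline.execute()  # one round trip, five DELs
```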
```diff
@@ -396,7 +396,7 @@ class IndexingRunner:
                 datasource_type=DatasourceType.NOTION,
                 notion_info=NotionInfo.model_validate(
                     {
-                        "credential_id": data_source_info["credential_id"],
+                        "credential_id": data_source_info.get("credential_id"),
                         "notion_workspace_id": data_source_info["notion_workspace_id"],
                         "notion_obj_id": data_source_info["notion_page_id"],
                         "notion_page_type": data_source_info["type"],
```
```diff
@@ -61,6 +61,7 @@ class SSETransport:
         self.timeout = timeout
         self.sse_read_timeout = sse_read_timeout
         self.endpoint_url: str | None = None
+        self.event_source: EventSource | None = None
 
     def _validate_endpoint_url(self, endpoint_url: str) -> bool:
         """Validate that the endpoint URL matches the connection origin.
@@ -237,6 +238,9 @@ class SSETransport:
         write_queue: WriteQueue = queue.Queue()
         status_queue: StatusQueue = queue.Queue()
 
+        # Store event_source for graceful shutdown
+        self.event_source = event_source
+
         # Start SSE reader thread
         executor.submit(self.sse_reader, event_source, read_queue, status_queue)
 
@@ -296,6 +300,13 @@ def sse_client(
         logger.exception("Error connecting to SSE endpoint")
         raise
     finally:
+        # Close the SSE connection to unblock the reader thread
+        if transport.event_source is not None:
+            try:
+                transport.event_source.response.close()
+            except RuntimeError:
+                pass
+
         # Clean up queues
         if read_queue:
             read_queue.put(None)
```
```diff
@@ -8,6 +8,7 @@ and session management.
 
 import logging
 import queue
+import threading
 from collections.abc import Callable, Generator
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import contextmanager
@@ -103,6 +104,9 @@ class StreamableHTTPTransport:
             CONTENT_TYPE: JSON,
             **self.headers,
         }
+        self.stop_event = threading.Event()
+        self._active_responses: list[httpx.Response] = []
+        self._lock = threading.Lock()
 
     def _update_headers_with_session(self, base_headers: dict[str, str]) -> dict[str, str]:
         """Update headers with session ID if available."""
@@ -111,6 +115,30 @@ class StreamableHTTPTransport:
             headers[MCP_SESSION_ID] = self.session_id
         return headers
 
+    def _register_response(self, response: httpx.Response):
+        """Register a response for cleanup on shutdown."""
+        with self._lock:
+            self._active_responses.append(response)
+
+    def _unregister_response(self, response: httpx.Response):
+        """Unregister a response after it's closed."""
+        with self._lock:
+            try:
+                self._active_responses.remove(response)
+            except ValueError as e:
+                logger.debug("Ignoring error during response unregister: %s", e)
+
+    def close_active_responses(self):
+        """Close all active SSE connections to unblock threads."""
+        with self._lock:
+            responses_to_close = list(self._active_responses)
+            self._active_responses.clear()
+        for response in responses_to_close:
+            try:
+                response.close()
+            except RuntimeError as e:
+                logger.debug("Ignoring error during active response close: %s", e)
+
     def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
         """Check if the message is an initialization request."""
         return isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize"
@@ -195,10 +223,20 @@ class StreamableHTTPTransport:
             event_source.response.raise_for_status()
             logger.debug("GET SSE connection established")
 
+            # Register response for cleanup
+            self._register_response(event_source.response)
+
+            try:
                 for sse in event_source.iter_sse():
+                    if self.stop_event.is_set():
+                        logger.debug("GET stream received stop signal")
+                        break
                     self._handle_sse_event(sse, server_to_client_queue)
+            finally:
+                self._unregister_response(event_source.response)
 
         except Exception as exc:
+            if not self.stop_event.is_set():
                 logger.debug("GET stream error (non-fatal): %s", exc)
 
     def _handle_resumption_request(self, ctx: RequestContext):
@@ -224,7 +262,14 @@ class StreamableHTTPTransport:
             event_source.response.raise_for_status()
             logger.debug("Resumption GET SSE connection established")
 
+            # Register response for cleanup
+            self._register_response(event_source.response)
+
+            try:
                 for sse in event_source.iter_sse():
+                    if self.stop_event.is_set():
+                        logger.debug("Resumption stream received stop signal")
+                        break
                     is_complete = self._handle_sse_event(
                         sse,
                         ctx.server_to_client_queue,
@@ -233,6 +278,8 @@ class StreamableHTTPTransport:
                     )
                     if is_complete:
                         break
+            finally:
+                self._unregister_response(event_source.response)
 
     def _handle_post_request(self, ctx: RequestContext):
         """Handle a POST request with response processing."""
@@ -266,6 +313,9 @@ class StreamableHTTPTransport:
         if is_initialization:
             self._maybe_extract_session_id_from_response(response)
 
+        # Per https://modelcontextprotocol.io/specification/2025-06-18/basic#notifications:
+        # The server MUST NOT send a response to notifications.
+        if isinstance(message.root, JSONRPCRequest):
             content_type = cast(str, response.headers.get(CONTENT_TYPE, "").lower())
 
             if content_type.startswith(JSON):
@@ -295,8 +345,15 @@ class StreamableHTTPTransport:
     def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext):
         """Handle SSE response from the server."""
         try:
+            # Register response for cleanup
+            self._register_response(response)
+
             event_source = EventSource(response)
+            try:
                 for sse in event_source.iter_sse():
+                    if self.stop_event.is_set():
+                        logger.debug("SSE response stream received stop signal")
+                        break
                     is_complete = self._handle_sse_event(
                         sse,
                         ctx.server_to_client_queue,
@@ -304,7 +361,10 @@ class StreamableHTTPTransport:
                     )
                     if is_complete:
                         break
+            finally:
+                self._unregister_response(response)
         except Exception as e:
+            if not self.stop_event.is_set():
                 ctx.server_to_client_queue.put(e)
 
     def _handle_unexpected_content_type(
@@ -345,6 +405,11 @@ class StreamableHTTPTransport:
         """
         while True:
             try:
+                # Check if we should stop
+                if self.stop_event.is_set():
+                    logger.debug("Post writer received stop signal")
+                    break
+
                 # Read message from client queue with timeout to check stop_event periodically
                 session_message = client_to_server_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
                 if session_message is None:
@@ -381,6 +446,7 @@ class StreamableHTTPTransport:
             except queue.Empty:
                 continue
             except Exception as exc:
+                if not self.stop_event.is_set():
                     server_to_client_queue.put(exc)
 
     def terminate_session(self, client: httpx.Client):
@@ -465,6 +531,12 @@ def streamablehttp_client(
             transport.get_session_id,
         )
     finally:
+        # Set stop event to signal all threads to stop
+        transport.stop_event.set()
+
+        # Close all active SSE connections to unblock threads
+        transport.close_active_responses()
+
         if transport.session_id and terminate_on_close:
             transport.terminate_session(client)
```
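The additions across this transport boil down to one shutdown pattern: reader threads poll a shared `threading.Event`, and the owner closes every registered streaming response so loops blocked inside `iter_sse()` unblock promptly. A condensed, framework-free sketch of that pattern:

```python
import threading

class ShutdownRegistry:
    """Condensed sketch of the stop_event + active-response pattern above."""

    def __init__(self) -> None:
        self.stop_event = threading.Event()
        self._active: list = []
        self._lock = threading.Lock()

    def register(self, resource) -> None:
        with self._lock:
            self._active.append(resource)

    def shutdown(self) -> None:
        self.stop_event.set()  # readers check this and break out of their loops
        with self._lock:
            active, self._active = self._active, []
        for resource in active:
            try:
                resource.close()  # unblocks threads parked on the stream
            except RuntimeError:
                pass  # already closed; nothing to do
```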
```diff
@@ -54,7 +54,7 @@ def generate_dotted_order(run_id: str, start_time: Union[str, datetime], parent_
     generate dotted_order for langsmith
     """
     start_time = datetime.fromisoformat(start_time) if isinstance(start_time, str) else start_time
-    timestamp = start_time.strftime("%Y%m%dT%H%M%S%f")[:-3] + "Z"
+    timestamp = start_time.strftime("%Y%m%dT%H%M%S%f") + "Z"
     current_segment = f"{timestamp}{run_id}"
 
     if parent_dotted_order is None:
```
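The removed `[:-3]` slice had been truncating `%f` (six microsecond digits) down to milliseconds; keeping the full value gives dotted_order segments microsecond resolution, which reduces ordering collisions for runs created within the same millisecond:

```python
from datetime import datetime, timezone

now = datetime(2024, 1, 2, 3, 4, 5, 678901, tzinfo=timezone.utc)
old = now.strftime("%Y%m%dT%H%M%S%f")[:-3] + "Z"  # milliseconds only
new = now.strftime("%Y%m%dT%H%M%S%f") + "Z"       # full microseconds
assert old == "20240102T030405678Z"
assert new == "20240102T030405678901Z"
```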
```diff
@@ -76,7 +76,7 @@ class PluginParameter(BaseModel):
     auto_generate: PluginParameterAutoGenerate | None = None
     template: PluginParameterTemplate | None = None
     required: bool = False
-    default: Union[float, int, str, bool] | None = None
+    default: Union[float, int, str, bool, list, dict] | None = None
     min: Union[float, int] | None = None
     max: Union[float, int] | None = None
     precision: int | None = None
```
```diff
@@ -90,13 +90,17 @@ class Jieba(BaseKeyword):
         sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table or {}, query, k)
 
         documents = []
-        for chunk_index in sorted_chunk_indices:
-            segment_query = db.session.query(DocumentSegment).where(
-                DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index
-            )
-            if document_ids_filter:
-                segment_query = segment_query.where(DocumentSegment.document_id.in_(document_ids_filter))
-            segment = segment_query.first()
+        segment_query_stmt = db.session.query(DocumentSegment).where(
+            DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id.in_(sorted_chunk_indices)
+        )
+        if document_ids_filter:
+            segment_query_stmt = segment_query_stmt.where(DocumentSegment.document_id.in_(document_ids_filter))
+        segments = db.session.execute(segment_query_stmt).scalars().all()
+        segment_map = {segment.index_node_id: segment for segment in segments}
+        for chunk_index in sorted_chunk_indices:
+            segment = segment_map.get(chunk_index)
 
             if segment:
                 documents.append(
```
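This is the classic N+1 fix: one `IN (...)` query plus an in-memory map replaces a `SELECT` per keyword hit, and a final pass over `sorted_chunk_indices` keeps the original relevance order. A sketch with stand-in objects (names illustrative):

```python
class Row:
    def __init__(self, index_node_id: str) -> None:
        self.index_node_id = index_node_id

def fetch_all_by_ids(ids: list[str]) -> list[Row]:
    # stand-in for the single WHERE index_node_id IN (...) query
    return [Row(i) for i in ["n1", "n3"]]  # "n9" has no matching segment

sorted_chunk_indices = ["n3", "n1", "n9"]
rows = fetch_all_by_ids(sorted_chunk_indices)
row_map = {row.index_node_id: row for row in rows}
# relevance order preserved; missing ids silently skipped, as before
ordered = [row_map[i] for i in sorted_chunk_indices if i in row_map]
assert [r.index_node_id for r in ordered] == ["n3", "n1"]
```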
```diff
@@ -7,12 +7,13 @@ from sqlalchemy import select
 from sqlalchemy.orm import Session, load_only
 
 from configs import dify_config
+from core.db.session_factory import session_factory
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
 from core.rag.data_post_processor.data_post_processor import DataPostProcessor
 from core.rag.datasource.keyword.keyword_factory import Keyword
 from core.rag.datasource.vdb.vector_factory import Vector
-from core.rag.embedding.retrieval import RetrievalSegments
+from core.rag.embedding.retrieval import RetrievalChildChunk, RetrievalSegments
 from core.rag.entities.metadata_entities import MetadataCondition
 from core.rag.index_processor.constant.doc_type import DocType
 from core.rag.index_processor.constant.index_type import IndexStructureType
@@ -138,37 +139,47 @@ class RetrievalService:
 
     @classmethod
     def _deduplicate_documents(cls, documents: list[Document]) -> list[Document]:
-        """Deduplicate documents based on doc_id to avoid duplicate chunks in hybrid search."""
+        """Deduplicate documents in O(n) while preserving first-seen order.
+
+        Rules:
+        - For provider == "dify" and metadata["doc_id"] exists: keep the doc with the highest
+          metadata["score"] among duplicates; if a later duplicate has no score, ignore it.
+        - For non-dify documents (or dify without doc_id): deduplicate by content key
+          (provider, page_content), keeping the first occurrence.
+        """
         if not documents:
             return documents
 
-        unique_documents = []
-        seen_doc_ids = set()
-
-        for document in documents:
-            # For dify provider documents, use doc_id for deduplication
-            if document.provider == "dify" and document.metadata is not None and "doc_id" in document.metadata:
-                doc_id = document.metadata["doc_id"]
-                if doc_id not in seen_doc_ids:
-                    seen_doc_ids.add(doc_id)
-                    unique_documents.append(document)
-                # If duplicate, keep the one with higher score
-                elif "score" in document.metadata:
-                    # Find existing document with same doc_id and compare scores
-                    for i, existing_doc in enumerate(unique_documents):
-                        if (
-                            existing_doc.metadata
-                            and existing_doc.metadata.get("doc_id") == doc_id
-                            and existing_doc.metadata.get("score", 0) < document.metadata.get("score", 0)
-                        ):
-                            unique_documents[i] = document
-                            break
-            else:
-                # For non-dify documents, use content-based deduplication
-                if document not in unique_documents:
-                    unique_documents.append(document)
-
-        return unique_documents
+        # Map of dedup key -> chosen Document
+        chosen: dict[tuple, Document] = {}
+        # Preserve the order of first appearance of each dedup key
+        order: list[tuple] = []
+
+        for doc in documents:
+            is_dify = doc.provider == "dify"
+            doc_id = (doc.metadata or {}).get("doc_id") if is_dify else None
+
+            if is_dify and doc_id:
+                key = ("dify", doc_id)
+                if key not in chosen:
+                    chosen[key] = doc
+                    order.append(key)
+                else:
+                    # Only replace if the new one has a score and it's strictly higher
+                    if "score" in doc.metadata:
+                        new_score = float(doc.metadata.get("score", 0.0))
+                        old_score = float(chosen[key].metadata.get("score", 0.0)) if chosen[key].metadata else 0.0
+                        if new_score > old_score:
+                            chosen[key] = doc
+            else:
+                # Content-based dedup for non-dify or dify without doc_id
+                content_key = (doc.provider or "dify", doc.page_content)
+                if content_key not in chosen:
+                    chosen[content_key] = doc
+                    order.append(content_key)
+                # If duplicate content appears, we keep the first occurrence (no score comparison)
+
+        return [chosen[k] for k in order]
```
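A tiny worked example of the new dedup rules, using plain dicts in place of `Document` objects to keep the sketch self-contained:

```python
docs = [
    {"provider": "dify", "metadata": {"doc_id": "a", "score": 0.4}, "page_content": "x"},
    {"provider": "dify", "metadata": {"doc_id": "a", "score": 0.9}, "page_content": "x"},  # wins: higher score
    {"provider": "external", "metadata": {}, "page_content": "y"},
    {"provider": "external", "metadata": {}, "page_content": "y"},  # dropped: first occurrence kept
]

chosen, order = {}, []
for doc in docs:
    doc_id = doc["metadata"].get("doc_id") if doc["provider"] == "dify" else None
    key = ("dify", doc_id) if doc_id else (doc["provider"], doc["page_content"])
    if key not in chosen:
        chosen[key] = doc
        order.append(key)
    elif doc_id and doc["metadata"].get("score", 0.0) > chosen[key]["metadata"].get("score", 0.0):
        chosen[key] = doc

result = [chosen[k] for k in order]
assert len(result) == 2 and result[0]["metadata"]["score"] == 0.9
```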
```diff
     @classmethod
     def _get_dataset(cls, dataset_id: str) -> Dataset | None:
@@ -370,13 +381,13 @@ class RetrievalService:
             records = []
             include_segment_ids = set()
             segment_child_map = {}
-            segment_file_map = {}
-            with Session(bind=db.engine, expire_on_commit=False) as session:
-                # Process documents
-                for document in documents:
-                    segment_id = None
-                    attachment_info = None
-                    child_chunk = None
-                    document_id = document.metadata.get("document_id")
-                    if document_id not in dataset_documents:
-                        continue
+            valid_dataset_documents = {}
+            image_doc_ids: list[Any] = []
+            child_index_node_ids = []
+            index_node_ids = []
+            doc_to_document_map = {}
+            for document in documents:
+                document_id = document.metadata.get("document_id")
+                if document_id not in dataset_documents:
+                    continue
@@ -384,157 +395,162 @@ class RetrievalService:
                 dataset_document = dataset_documents[document_id]
                 if not dataset_document:
                     continue
+                valid_dataset_documents[document_id] = dataset_document
 
                 if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
-                    # Handle parent-child documents
-                    if document.metadata.get("doc_type") == DocType.IMAGE:
-                        attachment_info_dict = cls.get_segment_attachment_info(
-                            dataset_document.dataset_id,
-                            dataset_document.tenant_id,
-                            document.metadata.get("doc_id") or "",
-                            session,
-                        )
-                        if attachment_info_dict:
-                            attachment_info = attachment_info_dict["attachment_info"]
-                            segment_id = attachment_info_dict["segment_id"]
-                    else:
-                        child_index_node_id = document.metadata.get("doc_id")
-                        child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id)
-                        child_chunk = session.scalar(child_chunk_stmt)
-
-                        if not child_chunk:
-                            continue
-                        segment_id = child_chunk.segment_id
-
-                    if not segment_id:
-                        continue
-
-                    segment = (
-                        session.query(DocumentSegment)
-                        .where(
-                            DocumentSegment.dataset_id == dataset_document.dataset_id,
-                            DocumentSegment.enabled == True,
-                            DocumentSegment.status == "completed",
-                            DocumentSegment.id == segment_id,
-                        )
-                        .first()
-                    )
-
-                    if not segment:
-                        continue
-                    if segment.id not in include_segment_ids:
-                        include_segment_ids.add(segment.id)
-                        if child_chunk:
-                            child_chunk_detail = {
-                                "id": child_chunk.id,
-                                "content": child_chunk.content,
-                                "position": child_chunk.position,
-                                "score": document.metadata.get("score", 0.0),
-                            }
-                            map_detail = {
-                                "max_score": document.metadata.get("score", 0.0),
-                                "child_chunks": [child_chunk_detail],
-                            }
-                            segment_child_map[segment.id] = map_detail
-                        record = {
-                            "segment": segment,
-                        }
-                        if attachment_info:
-                            segment_file_map[segment.id] = [attachment_info]
-                        records.append(record)
-                    else:
-                        if child_chunk:
-                            child_chunk_detail = {
-                                "id": child_chunk.id,
-                                "content": child_chunk.content,
-                                "position": child_chunk.position,
-                                "score": document.metadata.get("score", 0.0),
-                            }
-                            if segment.id in segment_child_map:
-                                segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
-                                segment_child_map[segment.id]["max_score"] = max(
-                                    segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
-                                )
-                            else:
-                                segment_child_map[segment.id] = {
-                                    "max_score": document.metadata.get("score", 0.0),
-                                    "child_chunks": [child_chunk_detail],
-                                }
-                        if attachment_info:
-                            if segment.id in segment_file_map:
-                                segment_file_map[segment.id].append(attachment_info)
-                            else:
-                                segment_file_map[segment.id] = [attachment_info]
-                else:
-                    # Handle normal documents
-                    segment = None
-                    if document.metadata.get("doc_type") == DocType.IMAGE:
-                        attachment_info_dict = cls.get_segment_attachment_info(
-                            dataset_document.dataset_id,
-                            dataset_document.tenant_id,
-                            document.metadata.get("doc_id") or "",
-                            session,
-                        )
-                        if attachment_info_dict:
-                            attachment_info = attachment_info_dict["attachment_info"]
-                            segment_id = attachment_info_dict["segment_id"]
-                            document_segment_stmt = select(DocumentSegment).where(
-                                DocumentSegment.dataset_id == dataset_document.dataset_id,
-                                DocumentSegment.enabled == True,
-                                DocumentSegment.status == "completed",
-                                DocumentSegment.id == segment_id,
-                            )
-                            segment = session.scalar(document_segment_stmt)
-                            if segment:
-                                segment_file_map[segment.id] = [attachment_info]
-                    else:
-                        index_node_id = document.metadata.get("doc_id")
-                        if not index_node_id:
-                            continue
-                        document_segment_stmt = select(DocumentSegment).where(
-                            DocumentSegment.dataset_id == dataset_document.dataset_id,
-                            DocumentSegment.enabled == True,
-                            DocumentSegment.status == "completed",
-                            DocumentSegment.index_node_id == index_node_id,
-                        )
-                        segment = session.scalar(document_segment_stmt)
-
-                    if not segment:
-                        continue
-                    if segment.id not in include_segment_ids:
-                        include_segment_ids.add(segment.id)
-                        record = {
-                            "segment": segment,
-                            "score": document.metadata.get("score"),  # type: ignore
-                        }
-                        if attachment_info:
-                            segment_file_map[segment.id] = [attachment_info]
-                        records.append(record)
-                    else:
-                        if attachment_info:
-                            attachment_infos = segment_file_map.get(segment.id, [])
-                            if attachment_info not in attachment_infos:
-                                attachment_infos.append(attachment_info)
-                                segment_file_map[segment.id] = attachment_infos
+                    doc_id = document.metadata.get("doc_id") or ""
+                    doc_to_document_map[doc_id] = document
+                    if document.metadata.get("doc_type") == DocType.IMAGE:
+                        image_doc_ids.append(doc_id)
+                    else:
+                        child_index_node_ids.append(doc_id)
+                else:
+                    doc_id = document.metadata.get("doc_id") or ""
+                    doc_to_document_map[doc_id] = document
+                    if document.metadata.get("doc_type") == DocType.IMAGE:
+                        image_doc_ids.append(doc_id)
+                    else:
+                        index_node_ids.append(doc_id)
+
+            image_doc_ids = [i for i in image_doc_ids if i]
+            child_index_node_ids = [i for i in child_index_node_ids if i]
+            index_node_ids = [i for i in index_node_ids if i]
+
+            segment_ids: list[str] = []
+            index_node_segments: list[DocumentSegment] = []
+            segments: list[DocumentSegment] = []
+            attachment_map: dict[str, list[dict[str, Any]]] = {}
+            child_chunk_map: dict[str, list[ChildChunk]] = {}
+            doc_segment_map: dict[str, list[str]] = {}
+
+            with session_factory.create_session() as session:
+                attachments = cls.get_segment_attachment_infos(image_doc_ids, session)
+                for attachment in attachments:
+                    segment_ids.append(attachment["segment_id"])
+                    if attachment["segment_id"] in attachment_map:
+                        attachment_map[attachment["segment_id"]].append(attachment["attachment_info"])
+                    else:
+                        attachment_map[attachment["segment_id"]] = [attachment["attachment_info"]]
+                    if attachment["segment_id"] in doc_segment_map:
+                        doc_segment_map[attachment["segment_id"]].append(attachment["attachment_id"])
+                    else:
+                        doc_segment_map[attachment["segment_id"]] = [attachment["attachment_id"]]
+                child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id.in_(child_index_node_ids))
+                child_index_nodes = session.execute(child_chunk_stmt).scalars().all()
+
+                for i in child_index_nodes:
+                    segment_ids.append(i.segment_id)
+                    if i.segment_id in child_chunk_map:
+                        child_chunk_map[i.segment_id].append(i)
+                    else:
+                        child_chunk_map[i.segment_id] = [i]
+                    if i.segment_id in doc_segment_map:
+                        doc_segment_map[i.segment_id].append(i.index_node_id)
+                    else:
+                        doc_segment_map[i.segment_id] = [i.index_node_id]
+
+                if index_node_ids:
+                    document_segment_stmt = select(DocumentSegment).where(
+                        DocumentSegment.enabled == True,
+                        DocumentSegment.status == "completed",
+                        DocumentSegment.index_node_id.in_(index_node_ids),
+                    )
+                    index_node_segments = session.execute(document_segment_stmt).scalars().all()  # type: ignore
+                    for index_node_segment in index_node_segments:
+                        doc_segment_map[index_node_segment.id] = [index_node_segment.index_node_id]
+                if segment_ids:
+                    document_segment_stmt = select(DocumentSegment).where(
+                        DocumentSegment.enabled == True,
+                        DocumentSegment.status == "completed",
+                        DocumentSegment.id.in_(segment_ids),
+                    )
+                    segments = session.execute(document_segment_stmt).scalars().all()  # type: ignore
+
+                if index_node_segments:
+                    segments.extend(index_node_segments)
+
+                for segment in segments:
+                    child_chunks: list[ChildChunk] = child_chunk_map.get(segment.id, [])
+                    attachment_infos: list[dict[str, Any]] = attachment_map.get(segment.id, [])
+                    ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get(segment.document_id)
+
+                    if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
+                        if segment.id not in include_segment_ids:
+                            include_segment_ids.add(segment.id)
+                            if child_chunks or attachment_infos:
+                                child_chunk_details = []
+                                max_score = 0.0
+                                for child_chunk in child_chunks:
+                                    document = doc_to_document_map[child_chunk.index_node_id]
+                                    child_chunk_detail = {
+                                        "id": child_chunk.id,
+                                        "content": child_chunk.content,
+                                        "position": child_chunk.position,
+                                        "score": document.metadata.get("score", 0.0) if document else 0.0,
+                                    }
+                                    child_chunk_details.append(child_chunk_detail)
+                                    max_score = max(max_score, document.metadata.get("score", 0.0) if document else 0.0)
+                                for attachment_info in attachment_infos:
+                                    file_document = doc_to_document_map[attachment_info["id"]]
+                                    max_score = max(
+                                        max_score, file_document.metadata.get("score", 0.0) if file_document else 0.0
+                                    )
+
+                                map_detail = {
+                                    "max_score": max_score,
+                                    "child_chunks": child_chunk_details,
+                                }
+                                segment_child_map[segment.id] = map_detail
+                            record: dict[str, Any] = {
+                                "segment": segment,
+                            }
+                            records.append(record)
+                    else:
+                        if segment.id not in include_segment_ids:
+                            include_segment_ids.add(segment.id)
+                            max_score = 0.0
+                            segment_document = doc_to_document_map.get(segment.index_node_id)
+                            if segment_document:
+                                max_score = max(max_score, segment_document.metadata.get("score", 0.0))
+                            for attachment_info in attachment_infos:
+                                file_doc = doc_to_document_map.get(attachment_info["id"])
+                                if file_doc:
+                                    max_score = max(max_score, file_doc.metadata.get("score", 0.0))
+                            record = {
+                                "segment": segment,
+                                "score": max_score,
+                            }
+                            records.append(record)
 
                 # Add child chunks information to records
                 for record in records:
                     if record["segment"].id in segment_child_map:
                         record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks")  # type: ignore
-                        record["score"] = segment_child_map[record["segment"].id]["max_score"]
-                    if record["segment"].id in segment_file_map:
-                        record["files"] = segment_file_map[record["segment"].id]  # type: ignore[assignment]
+                        record["score"] = segment_child_map[record["segment"].id]["max_score"]  # type: ignore
+                    if record["segment"].id in attachment_map:
+                        record["files"] = attachment_map[record["segment"].id]  # type: ignore[assignment]
 
-            result = []
+            result: list[RetrievalSegments] = []
             for record in records:
                 # Extract segment
                 segment = record["segment"]
 
                 # Extract child_chunks, ensuring it's a list or None
-                child_chunks = record.get("child_chunks")
-                if not isinstance(child_chunks, list):
-                    child_chunks = None
+                raw_child_chunks = record.get("child_chunks")
+                child_chunks_list: list[RetrievalChildChunk] | None = None
+                if isinstance(raw_child_chunks, list):
+                    # Sort by score descending
+                    sorted_chunks = sorted(raw_child_chunks, key=lambda x: x.get("score", 0.0), reverse=True)
+                    child_chunks_list = [
+                        RetrievalChildChunk(
+                            id=chunk["id"],
+                            content=chunk["content"],
+                            score=chunk.get("score", 0.0),
+                            position=chunk["position"],
+                        )
+                        for chunk in sorted_chunks
+                    ]
 
                 # Extract files, ensuring it's a list or None
                 files = record.get("files")
@@ -551,11 +567,11 @@ class RetrievalService:
 
                 # Create RetrievalSegments object
                 retrieval_segment = RetrievalSegments(
-                    segment=segment, child_chunks=child_chunks, score=score, files=files
+                    segment=segment, child_chunks=child_chunks_list, score=score, files=files
                 )
                 result.append(retrieval_segment)
 
-            return result
+            return sorted(result, key=lambda x: x.score if x.score is not None else 0.0, reverse=True)
         except Exception as e:
             db.session.rollback()
             raise e
@@ -565,6 +581,8 @@ class RetrievalService:
         flask_app: Flask,
         retrieval_method: RetrievalMethod,
         dataset: Dataset,
+        all_documents: list[Document],
+        exceptions: list[str],
         query: str | None = None,
         top_k: int = 4,
         score_threshold: float | None = 0.0,
@@ -573,8 +591,6 @@ class RetrievalService:
         weights: dict | None = None,
        document_ids_filter: list[str] | None = None,
         attachment_id: str | None = None,
-        all_documents: list[Document] = [],
-        exceptions: list[str] = [],
     ):
         if not query and not attachment_id:
             return
@@ -696,3 +712,37 @@ class RetrievalService:
             }
             return {"attachment_info": attachment_info, "segment_id": attachment_binding.segment_id}
         return None
+
+    @classmethod
+    def get_segment_attachment_infos(cls, attachment_ids: list[str], session: Session) -> list[dict[str, Any]]:
+        attachment_infos = []
+        upload_files = session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all()
+        if upload_files:
+            upload_file_ids = [upload_file.id for upload_file in upload_files]
+            attachment_bindings = (
+                session.query(SegmentAttachmentBinding)
+                .where(SegmentAttachmentBinding.attachment_id.in_(upload_file_ids))
+                .all()
+            )
+            attachment_binding_map = {binding.attachment_id: binding for binding in attachment_bindings}
+
+            if attachment_bindings:
+                for upload_file in upload_files:
+                    attachment_binding = attachment_binding_map.get(upload_file.id)
+                    attachment_info = {
+                        "id": upload_file.id,
```
|
||||||
|
"name": upload_file.name,
|
||||||
|
"extension": "." + upload_file.extension,
|
||||||
|
"mime_type": upload_file.mime_type,
|
||||||
|
"source_url": sign_upload_file(upload_file.id, upload_file.extension),
|
||||||
|
"size": upload_file.size,
|
||||||
|
}
|
||||||
|
if attachment_binding:
|
||||||
|
attachment_infos.append(
|
||||||
|
{
|
||||||
|
"attachment_id": attachment_binding.attachment_id,
|
||||||
|
"attachment_info": attachment_info,
|
||||||
|
"segment_id": attachment_binding.segment_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return attachment_infos
|
||||||
|
|
|
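The rewritten scoring above takes, for each deduplicated segment, the maximum score across the segment's own hit and any attachment hits, and the method now returns results sorted by that score. A minimal sketch of the aggregation, using plain dicts in place of Dify's document and segment objects (names are illustrative):

```python
# Sketch of the max-score aggregation above, with stand-in dicts.
def best_score(segment_hit: dict | None, attachment_hits: list[dict]) -> float:
    max_score = 0.0
    if segment_hit:
        max_score = max(max_score, segment_hit.get("score", 0.0))
    for hit in attachment_hits:
        max_score = max(max_score, hit.get("score", 0.0))
    return max_score

records = [
    {"id": "a", "score": best_score({"score": 0.4}, [{"score": 0.9}])},
    {"id": "b", "score": best_score({"score": 0.7}, [])},
]
# Like the new `return sorted(...)` above: highest score first, None treated as 0.0.
records.sort(key=lambda r: r["score"] if r["score"] is not None else 0.0, reverse=True)
print([r["id"] for r in records])  # ['a', 'b']
```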
@@ -289,7 +289,8 @@ class OracleVector(BaseVector):
         words = pseg.cut(query)
         current_entity = ""
         for word, pos in words:
-            if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}:  # nr: 人名,ns: 地名,nt: 机构名
+            # `nr`: Person, `ns`: Location, `nt`: Organization
+            if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}:
                 current_entity += word
             else:
                 if current_entity:

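For context, the Oracle full-text path tokenizes the query with jieba's part-of-speech tagger and stitches consecutive noun-like tokens into entities. A hedged sketch of how `pseg.cut` behaves (requires the `jieba` package; the exact segmentation depends on its dictionaries):

```python
import jieba.posseg as pseg  # pip install jieba

# pseg.cut yields (word, flag) pairs; flags such as "nr" (person), "ns" (place),
# and "nt" (organization) drive the entity stitching in the loop above.
for word, flag in pseg.cut("华为在深圳发布新产品"):
    print(word, flag)
```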
@@ -255,7 +255,10 @@ class PGVector(BaseVector):
             return

         with self._get_cursor() as cur:
-            cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
+            cur.execute("SELECT 1 FROM pg_extension WHERE extname = 'vector'")
+            if not cur.fetchone():
+                cur.execute("CREATE EXTENSION vector")
+
             cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
             # PG hnsw index only support 2000 dimension or less
             # ref: https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing
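Probing `pg_extension` first makes the common path a plain read: `CREATE EXTENSION` is only attempted when `vector` is genuinely missing, which presumably avoids needing extension-creation privileges on every startup against databases where it is already installed. The pattern in isolation (cursor per the DB-API, as used by `_get_cursor` above):

```python
# A minimal sketch of the probe-then-create logic above, assuming a DB-API cursor.
def ensure_vector_extension(cur) -> None:
    cur.execute("SELECT 1 FROM pg_extension WHERE extname = 'vector'")
    if not cur.fetchone():
        # Only reached when the extension is missing, so routine startups
        # against an already-provisioned database stay read-only here.
        cur.execute("CREATE EXTENSION vector")
```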
@@ -213,7 +213,7 @@ class VastbaseVector(BaseVector):

         with self._get_cursor() as cur:
             cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
-            # Vastbase 支持的向量维度取值范围为 [1,16000]
+            # Vastbase supports vector dimensions in the range [1, 16000]
             if dimension <= 16000:
                 cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
             redis_client.set(collection_exist_cache_key, 1, ex=3600)

@@ -25,7 +25,7 @@ class FirecrawlApp:
         }
         if params:
             json_data.update(params)
-        response = self._post_request(f"{self.base_url}/v2/scrape", json_data, headers)
+        response = self._post_request(self._build_url("v2/scrape"), json_data, headers)
         if response.status_code == 200:
             response_data = response.json()
             data = response_data["data"]

@@ -42,7 +42,7 @@ class FirecrawlApp:
         json_data = {"url": url}
         if params:
             json_data.update(params)
-        response = self._post_request(f"{self.base_url}/v2/crawl", json_data, headers)
+        response = self._post_request(self._build_url("v2/crawl"), json_data, headers)
         if response.status_code == 200:
             # There's also another two fields in the response: "success" (bool) and "url" (str)
             job_id = response.json().get("id")

@@ -58,7 +58,7 @@ class FirecrawlApp:
         if params:
             # Pass through provided params, including optional "sitemap": "only" | "include" | "skip"
             json_data.update(params)
-        response = self._post_request(f"{self.base_url}/v2/map", json_data, headers)
+        response = self._post_request(self._build_url("v2/map"), json_data, headers)
         if response.status_code == 200:
             return cast(dict[str, Any], response.json())
         elif response.status_code in {402, 409, 500, 429, 408}:

@@ -69,7 +69,7 @@ class FirecrawlApp:

     def check_crawl_status(self, job_id) -> dict[str, Any]:
         headers = self._prepare_headers()
-        response = self._get_request(f"{self.base_url}/v2/crawl/{job_id}", headers)
+        response = self._get_request(self._build_url(f"v2/crawl/{job_id}"), headers)
         if response.status_code == 200:
             crawl_status_response = response.json()
             if crawl_status_response.get("status") == "completed":

@@ -120,6 +120,10 @@ class FirecrawlApp:
     def _prepare_headers(self) -> dict[str, Any]:
         return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}

+    def _build_url(self, path: str) -> str:
+        # ensure exactly one slash between base and path, regardless of user-provided base_url
+        return f"{self.base_url.rstrip('/')}/{path.lstrip('/')}"
+
     def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5) -> httpx.Response:
         for attempt in range(retries):
             response = httpx.post(url, headers=headers, json=data)

@@ -139,7 +143,11 @@ class FirecrawlApp:
         return response

     def _handle_error(self, response, action):
-        error_message = response.json().get("error", "Unknown error occurred")
+        try:
+            payload = response.json()
+            error_message = payload.get("error") or payload.get("message") or response.text or "Unknown error occurred"
+        except json.JSONDecodeError:
+            error_message = response.text or "Unknown error occurred"
         raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}")  # type: ignore[return]

     def search(self, query: str, params: dict[str, Any] | None = None) -> dict[str, Any]:

@@ -160,7 +168,7 @@ class FirecrawlApp:
         }
         if params:
             json_data.update(params)
-        response = self._post_request(f"{self.base_url}/v2/search", json_data, headers)
+        response = self._post_request(self._build_url("v2/search"), json_data, headers)
         if response.status_code == 200:
             response_data = response.json()
             if not response_data.get("success"):
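Every Firecrawl endpoint now goes through `_build_url`, which normalizes slashes so a user-supplied `base_url` with or without a trailing slash produces the same request URL. A quick standalone check of the joining behavior (re-implementing the helper above outside the class):

```python
def build_url(base_url: str, path: str) -> str:
    # same normalization as FirecrawlApp._build_url above
    return f"{base_url.rstrip('/')}/{path.lstrip('/')}"

assert build_url("https://api.firecrawl.dev", "v2/scrape") == "https://api.firecrawl.dev/v2/scrape"
assert build_url("https://api.firecrawl.dev/", "/v2/scrape") == "https://api.firecrawl.dev/v2/scrape"
```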
@@ -48,13 +48,21 @@ class NotionExtractor(BaseExtractor):
         if notion_access_token:
             self._notion_access_token = notion_access_token
         else:
+            try:
                 self._notion_access_token = self._get_access_token(tenant_id, self._credential_id)
-            if not self._notion_access_token:
+            except Exception as e:
+                logger.warning(
+                    (
+                        "Failed to get Notion access token from datasource credentials: %s, "
+                        "falling back to environment variable NOTION_INTEGRATION_TOKEN"
+                    ),
+                    e,
+                )
                 integration_token = dify_config.NOTION_INTEGRATION_TOKEN
                 if integration_token is None:
                     raise ValueError(
                         "Must specify `integration_token` or set environment variable `NOTION_INTEGRATION_TOKEN`."
-                    )
+                    ) from e

                 self._notion_access_token = integration_token

@@ -231,7 +231,7 @@ class BaseIndexProcessor(ABC):

         if not filename:
             parsed_url = urlparse(image_url)
-            # unquote 处理 URL 中的中文
+            # Decode percent-encoded characters in the URL path.
             path = unquote(parsed_url.path)
             filename = os.path.basename(path)
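The translated comment above covers the general case: the URL path may carry percent-encoded characters (Chinese filenames in the original comment), so it is decoded before taking the basename. A small standalone illustration with a hypothetical URL:

```python
import os
from urllib.parse import unquote, urlparse

url = "https://example.com/files/%E6%96%87%E6%A1%A3.pdf"  # hypothetical URL
path = unquote(urlparse(url).path)  # "/files/文档.pdf"
print(os.path.basename(path))       # "文档.pdf"
```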
@@ -7,7 +7,7 @@ from collections.abc import Generator, Mapping
 from typing import Any, Union, cast

 from flask import Flask, current_app
-from sqlalchemy import and_, or_, select
+from sqlalchemy import and_, literal, or_, select
 from sqlalchemy.orm import Session

 from core.app.app_config.entities import (

@@ -151,20 +151,14 @@ class DatasetRetrieval:
         if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features:
             planning_strategy = PlanningStrategy.ROUTER
         available_datasets = []
-        for dataset_id in dataset_ids:
-            # get dataset from dataset id
-            dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
-            dataset = db.session.scalar(dataset_stmt)
-
-            # pass if dataset is not available
-            if not dataset:
+        dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id.in_(dataset_ids))
+        datasets: list[Dataset] = db.session.execute(dataset_stmt).scalars().all()  # type: ignore
+        for dataset in datasets:
+            if dataset.available_document_count == 0 and dataset.provider != "external":
                 continue

-            # pass if dataset is not available
-            if dataset and dataset.available_document_count == 0 and dataset.provider != "external":
-                continue
-
             available_datasets.append(dataset)

         if inputs:
             inputs = {key: str(value) for key, value in inputs.items()}
         else:

@@ -282,26 +276,35 @@ class DatasetRetrieval:
                 )
                 context_files.append(attachment_info)
         if show_retrieve_source:
-            for record in records:
-                segment = record.segment
-                dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
+            dataset_ids = [record.segment.dataset_id for record in records]
+            document_ids = [record.segment.document_id for record in records]
             dataset_document_stmt = select(DatasetDocument).where(
-                DatasetDocument.id == segment.document_id,
+                DatasetDocument.id.in_(document_ids),
                 DatasetDocument.enabled == True,
                 DatasetDocument.archived == False,
             )
-            document = db.session.scalar(dataset_document_stmt)
-            if dataset and document:
+            documents = db.session.execute(dataset_document_stmt).scalars().all()  # type: ignore
+            dataset_stmt = select(Dataset).where(
+                Dataset.id.in_(dataset_ids),
+            )
+            datasets = db.session.execute(dataset_stmt).scalars().all()  # type: ignore
+            dataset_map = {i.id: i for i in datasets}
+            document_map = {i.id: i for i in documents}
+            for record in records:
+                segment = record.segment
+                dataset_item = dataset_map.get(segment.dataset_id)
+                document_item = document_map.get(segment.document_id)
+                if dataset_item and document_item:
                     source = RetrievalSourceMetadata(
-                        dataset_id=dataset.id,
-                        dataset_name=dataset.name,
-                        document_id=document.id,
-                        document_name=document.name,
-                        data_source_type=document.data_source_type,
+                        dataset_id=dataset_item.id,
+                        dataset_name=dataset_item.name,
+                        document_id=document_item.id,
+                        document_name=document_item.name,
+                        data_source_type=document_item.data_source_type,
                         segment_id=segment.id,
                         retriever_from=invoke_from.to_source(),
                         score=record.score or 0.0,
-                        doc_metadata=document.doc_metadata,
+                        doc_metadata=document_item.doc_metadata,
                     )

                     if invoke_from.to_source() == "dev":

@@ -1033,7 +1036,7 @@ class DatasetRetrieval:
         if automatic_metadata_filters:
             conditions = []
             for sequence, filter in enumerate(automatic_metadata_filters):
-                self._process_metadata_filter_func(
+                self.process_metadata_filter_func(
                     sequence,
                     filter.get("condition"),  # type: ignore
                     filter.get("metadata_name"),  # type: ignore

@@ -1069,7 +1072,7 @@ class DatasetRetrieval:
                         value=expected_value,
                     )
                 )
-            filters = self._process_metadata_filter_func(
+            filters = self.process_metadata_filter_func(
                 sequence,
                 condition.comparison_operator,
                 metadata_name,

@@ -1165,8 +1168,9 @@ class DatasetRetrieval:
             return None
         return automatic_metadata_filters

-    def _process_metadata_filter_func(
-        self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
+    @classmethod
+    def process_metadata_filter_func(
+        cls, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
     ):
         if value is None and condition not in ("empty", "not empty"):
             return filters

@@ -1215,6 +1219,20 @@ class DatasetRetrieval:

             case "≥" | ">=":
                 filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() >= value)
+            case "in" | "not in":
+                if isinstance(value, str):
+                    value_list = [v.strip() for v in value.split(",") if v.strip()]
+                elif isinstance(value, (list, tuple)):
+                    value_list = [str(v) for v in value if v is not None]
+                else:
+                    value_list = [str(value)] if value is not None else []
+
+                if not value_list:
+                    # `field in []` is False, `field not in []` is True
+                    filters.append(literal(condition == "not in"))
+                else:
+                    op = json_field.in_ if condition == "in" else json_field.notin_
+                    filters.append(op(value_list))
             case _:
                 pass
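The new `in`/`not in` branch short-circuits empty value lists with `sqlalchemy.literal`, matching SQL semantics: `x IN ()` is never true and `x NOT IN ()` is vacuously true. The branch in isolation (the column name is a stand-in, not Dify's actual field):

```python
from sqlalchemy import column, literal

# Stand-in for the JSON-extracted metadata field used above (illustrative name).
json_field = column("doc_metadata_value")

def in_filter(condition: str, value_list: list[str]):
    if not value_list:
        # SQL semantics: `x IN ()` is never true, `x NOT IN ()` is vacuously true.
        return literal(condition == "not in")
    op = json_field.in_ if condition == "in" else json_field.notin_
    return op(value_list)

clause = in_filter("in", ["alice", "bob"])  # doc_metadata_value IN (:param_1, :param_2)
```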
@@ -153,11 +153,11 @@ class ToolInvokeMessage(BaseModel):
     @classmethod
     def transform_variable_value(cls, values):
         """
-        Only basic types and lists are allowed.
+        Only basic types, lists, and None are allowed.
         """
         value = values.get("variable_value")
-        if not isinstance(value, dict | list | str | int | float | bool):
-            raise ValueError("Only basic types and lists are allowed.")
+        if value is not None and not isinstance(value, dict | list | str | int | float | bool):
+            raise ValueError("Only basic types, lists, and None are allowed.")

         # if stream is true, the value must be a string
         if values.get("stream"):
@@ -6,7 +6,15 @@ from typing import Any

 from core.mcp.auth_client import MCPClientWithAuthRetry
 from core.mcp.error import MCPConnectionError
-from core.mcp.types import AudioContent, CallToolResult, ImageContent, TextContent
+from core.mcp.types import (
+    AudioContent,
+    BlobResourceContents,
+    CallToolResult,
+    EmbeddedResource,
+    ImageContent,
+    TextContent,
+    TextResourceContents,
+)
 from core.tools.__base.tool import Tool
 from core.tools.__base.tool_runtime import ToolRuntime
 from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType

@@ -53,10 +61,19 @@ class MCPTool(Tool):
         for content in result.content:
             if isinstance(content, TextContent):
                 yield from self._process_text_content(content)
-            elif isinstance(content, ImageContent):
-                yield self._process_image_content(content)
-            elif isinstance(content, AudioContent):
-                yield self._process_audio_content(content)
+            elif isinstance(content, ImageContent | AudioContent):
+                yield self.create_blob_message(
+                    blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType}
+                )
+            elif isinstance(content, EmbeddedResource):
+                resource = content.resource
+                if isinstance(resource, TextResourceContents):
+                    yield self.create_text_message(resource.text)
+                elif isinstance(resource, BlobResourceContents):
+                    mime_type = resource.mimeType or "application/octet-stream"
+                    yield self.create_blob_message(blob=base64.b64decode(resource.blob), meta={"mime_type": mime_type})
+                else:
+                    raise ToolInvokeError(f"Unsupported embedded resource type: {type(resource)}")
             else:
                 logger.warning("Unsupported content type=%s", type(content))

@@ -101,14 +118,6 @@ class MCPTool:
         for item in json_list:
             yield self.create_json_message(item)

-    def _process_image_content(self, content: ImageContent) -> ToolInvokeMessage:
-        """Process image content and return a blob message."""
-        return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType})
-
-    def _process_audio_content(self, content: AudioContent) -> ToolInvokeMessage:
-        """Process audio content and return a blob message."""
-        return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType})
-
     def fork_tool_runtime(self, runtime: ToolRuntime) -> "MCPTool":
         return MCPTool(
             entity=self.entity,
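Collapsing the image and audio branches relies on `isinstance` accepting PEP 604 union types, available since Python 3.10 (this project requires >=3.11 per the pyproject hunk below); both content types expose the same `data`/`mimeType` fields, so one blob path suffices. A self-contained illustration with stand-in classes:

```python
class ImageContent:  # stand-in for core.mcp.types.ImageContent (illustrative)
    data = "aGVsbG8="
    mimeType = "image/png"

class AudioContent:  # stand-in for core.mcp.types.AudioContent (illustrative)
    data = "aGVsbG8="
    mimeType = "audio/wav"

content = ImageContent()
# PEP 604 unions are valid isinstance targets on Python 3.10+, which is what
# lets `elif isinstance(content, ImageContent | AudioContent)` replace two branches.
assert isinstance(content, ImageContent | AudioContent)
```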
@@ -5,6 +5,7 @@ from sqlalchemy.orm import Session

 from core.app.app_config.entities import VariableEntity, VariableEntityType
 from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
+from core.db.session_factory import session_factory
 from core.plugin.entities.parameters import PluginParameterOption
 from core.tools.__base.tool_provider import ToolProviderController
 from core.tools.__base.tool_runtime import ToolRuntime

@@ -47,33 +48,30 @@ class WorkflowToolProviderController(ToolProviderController):

     @classmethod
     def from_db(cls, db_provider: WorkflowToolProvider) -> "WorkflowToolProviderController":
-        with Session(db.engine, expire_on_commit=False) as session, session.begin():
-            provider = session.get(WorkflowToolProvider, db_provider.id) if db_provider.id else None
-            if not provider:
-                raise ValueError("workflow provider not found")
-            app = session.get(App, provider.app_id)
+        with session_factory.create_session() as session, session.begin():
+            app = session.get(App, db_provider.app_id)
             if not app:
                 raise ValueError("app not found")

-            user = session.get(Account, provider.user_id) if provider.user_id else None
+            user = session.get(Account, db_provider.user_id) if db_provider.user_id else None

             controller = WorkflowToolProviderController(
                 entity=ToolProviderEntity(
                     identity=ToolProviderIdentity(
                         author=user.name if user else "",
-                        name=provider.label,
-                        label=I18nObject(en_US=provider.label, zh_Hans=provider.label),
-                        description=I18nObject(en_US=provider.description, zh_Hans=provider.description),
-                        icon=provider.icon,
+                        name=db_provider.label,
+                        label=I18nObject(en_US=db_provider.label, zh_Hans=db_provider.label),
+                        description=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description),
+                        icon=db_provider.icon,
                     ),
                     credentials_schema=[],
                     plugin_id=None,
                 ),
-                provider_id=provider.id or "",
+                provider_id="",
             )

             controller.tools = [
-                controller._get_db_provider_tool(provider, app, session=session, user=user),
+                controller._get_db_provider_tool(db_provider, app, session=session, user=user),
             ]

             return controller

@@ -67,12 +67,16 @@ def create_trigger_provider_encrypter_for_subscription(


 def delete_cache_for_subscription(tenant_id: str, provider_id: str, subscription_id: str):
-    cache = TriggerProviderCredentialsCache(
+    TriggerProviderCredentialsCache(
         tenant_id=tenant_id,
         provider_id=provider_id,
         credential_id=subscription_id,
-    )
-    cache.delete()
+    ).delete()
+    TriggerProviderPropertiesCache(
+        tenant_id=tenant_id,
+        provider_id=provider_id,
+        subscription_id=subscription_id,
+    ).delete()


 def create_trigger_provider_encrypter_for_properties(

@@ -247,6 +247,7 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
     ERROR_STRATEGY = "error_strategy"  # node in continue on error mode return the field
     LOOP_VARIABLE_MAP = "loop_variable_map"  # single loop variable output
     DATASOURCE_INFO = "datasource_info"
+    COMPLETED_REASON = "completed_reason"  # completed reason for loop node


 class WorkflowNodeExecutionStatus(StrEnum):

@@ -6,7 +6,7 @@ from collections import defaultdict
 from collections.abc import Mapping, Sequence
 from typing import TYPE_CHECKING, Any, cast

-from sqlalchemy import and_, func, literal, or_, select
+from sqlalchemy import and_, func, or_, select
 from sqlalchemy.orm import sessionmaker

 from core.app.app_config.entities import DatasetRetrieveConfigEntity

@@ -460,7 +460,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeData]):
         if automatic_metadata_filters:
             conditions = []
             for sequence, filter in enumerate(automatic_metadata_filters):
-                self._process_metadata_filter_func(
+                DatasetRetrieval.process_metadata_filter_func(
                     sequence,
                     filter.get("condition", ""),
                     filter.get("metadata_name", ""),

@@ -504,7 +504,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeData]):
                         value=expected_value,
                     )
                 )
-            filters = self._process_metadata_filter_func(
+            filters = DatasetRetrieval.process_metadata_filter_func(
                 sequence,
                 condition.comparison_operator,
                 metadata_name,

@@ -603,87 +603,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeData]):
                 return [], usage
         return automatic_metadata_filters, usage

-    def _process_metadata_filter_func(
-        self, sequence: int, condition: str, metadata_name: str, value: Any, filters: list[Any]
-    ) -> list[Any]:
-        if value is None and condition not in ("empty", "not empty"):
-            return filters
-
-        json_field = Document.doc_metadata[metadata_name].as_string()
-
-        match condition:
-            case "contains":
-                filters.append(json_field.like(f"%{value}%"))
-
-            case "not contains":
-                filters.append(json_field.notlike(f"%{value}%"))
-
-            case "start with":
-                filters.append(json_field.like(f"{value}%"))
-
-            case "end with":
-                filters.append(json_field.like(f"%{value}"))
-            case "in":
-                if isinstance(value, str):
-                    value_list = [v.strip() for v in value.split(",") if v.strip()]
-                elif isinstance(value, (list, tuple)):
-                    value_list = [str(v) for v in value if v is not None]
-                else:
-                    value_list = [str(value)] if value is not None else []
-
-                if not value_list:
-                    filters.append(literal(False))
-                else:
-                    filters.append(json_field.in_(value_list))
-
-            case "not in":
-                if isinstance(value, str):
-                    value_list = [v.strip() for v in value.split(",") if v.strip()]
-                elif isinstance(value, (list, tuple)):
-                    value_list = [str(v) for v in value if v is not None]
-                else:
-                    value_list = [str(value)] if value is not None else []
-
-                if not value_list:
-                    filters.append(literal(True))
-                else:
-                    filters.append(json_field.notin_(value_list))
-
-            case "is" | "=":
-                if isinstance(value, str):
-                    filters.append(json_field == value)
-                elif isinstance(value, (int, float)):
-                    filters.append(Document.doc_metadata[metadata_name].as_float() == value)
-
-            case "is not" | "≠":
-                if isinstance(value, str):
-                    filters.append(json_field != value)
-                elif isinstance(value, (int, float)):
-                    filters.append(Document.doc_metadata[metadata_name].as_float() != value)
-
-            case "empty":
-                filters.append(Document.doc_metadata[metadata_name].is_(None))
-
-            case "not empty":
-                filters.append(Document.doc_metadata[metadata_name].isnot(None))
-
-            case "before" | "<":
-                filters.append(Document.doc_metadata[metadata_name].as_float() < value)
-
-            case "after" | ">":
-                filters.append(Document.doc_metadata[metadata_name].as_float() > value)
-
-            case "≤" | "<=":
-                filters.append(Document.doc_metadata[metadata_name].as_float() <= value)
-
-            case "≥" | ">=":
-                filters.append(Document.doc_metadata[metadata_name].as_float() >= value)
-
-            case _:
-                pass
-
-        return filters
-
     @classmethod
     def _extract_variable_selector_to_variable_mapping(
         cls,
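With the node-local copy deleted above, the knowledge-retrieval node and `DatasetRetrieval` now share a single implementation, exposed as a classmethod. An illustrative invocation (the argument values are made up):

```python
# Hypothetical call; the classmethod appends the matching SQLAlchemy clause(s)
# to `filters` and returns the list, which is how both call sites above use it.
filters: list = []
DatasetRetrieval.process_metadata_filter_func(0, "in", "author", "alice, bob", filters)
```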
@@ -1,3 +1,4 @@
+from enum import StrEnum
 from typing import Annotated, Any, Literal

 from pydantic import AfterValidator, BaseModel, Field, field_validator

@@ -96,3 +97,8 @@ class LoopState(BaseLoopState):
         Get current output.
         """
         return self.current_output
+
+
+class LoopCompletedReason(StrEnum):
+    LOOP_BREAK = "loop_break"
+    LOOP_COMPLETED = "loop_completed"

@@ -29,7 +29,7 @@ from core.workflow.node_events import (
 )
 from core.workflow.nodes.base import LLMUsageTrackingMixin
 from core.workflow.nodes.base.node import Node
-from core.workflow.nodes.loop.entities import LoopNodeData, LoopVariableData
+from core.workflow.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopVariableData
 from core.workflow.utils.condition.processor import ConditionProcessor
 from factories.variable_factory import TypeMismatchError, build_segment_with_type, segment_to_variable
 from libs.datetime_utils import naive_utc_now

@@ -96,6 +96,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
         loop_duration_map: dict[str, float] = {}
         single_loop_variable_map: dict[str, dict[str, Any]] = {}  # single loop variable output
         loop_usage = LLMUsage.empty_usage()
+        loop_node_ids = self._extract_loop_node_ids_from_config(self.graph_config, self._node_id)

         # Start Loop event
         yield LoopStartedEvent(

@@ -118,6 +119,8 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
             loop_count = 0

         for i in range(loop_count):
+            # Clear stale variables from previous loop iterations to avoid streaming old values
+            self._clear_loop_subgraph_variables(loop_node_ids)
             graph_engine = self._create_graph_engine(start_at=start_at, root_node_id=root_node_id)

             loop_start_time = naive_utc_now()

@@ -177,7 +180,11 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
                     WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
                     WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
                     WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
-                    "completed_reason": "loop_break" if reach_break_condition else "loop_completed",
+                    WorkflowNodeExecutionMetadataKey.COMPLETED_REASON: (
+                        LoopCompletedReason.LOOP_BREAK
+                        if reach_break_condition
+                        else LoopCompletedReason.LOOP_COMPLETED.value
+                    ),
                     WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
                     WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
                 },

@@ -274,6 +281,17 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
             if WorkflowNodeExecutionMetadataKey.LOOP_ID not in current_metadata:
                 event.node_run_result.metadata = {**current_metadata, **loop_metadata}

+    def _clear_loop_subgraph_variables(self, loop_node_ids: set[str]) -> None:
+        """
+        Remove variables produced by loop sub-graph nodes from previous iterations.
+
+        Keeping stale variables causes a freshly created response coordinator in the
+        next iteration to fall back to outdated values when no stream chunks exist.
+        """
+        variable_pool = self.graph_runtime_state.variable_pool
+        for node_id in loop_node_ids:
+            variable_pool.remove([node_id])
+
     @classmethod
     def _extract_variable_selector_to_variable_mapping(
         cls,
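A side note on the metadata value above: one branch stores the `LoopCompletedReason` member and the other its `.value`, but since the enum derives from `StrEnum` the member already is its string value, so both serialize identically:

```python
from enum import StrEnum

class LoopCompletedReason(StrEnum):  # mirrors the entity added above
    LOOP_BREAK = "loop_break"
    LOOP_COMPLETED = "loop_completed"

assert LoopCompletedReason.LOOP_BREAK == "loop_break"
assert str(LoopCompletedReason.LOOP_BREAK) == LoopCompletedReason.LOOP_BREAK.value
```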
@@ -281,7 +281,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):

         # handle invoke result

-        text = invoke_result.message.content or ""
+        text = invoke_result.message.get_text_content()
         if not isinstance(text, str):
             raise InvalidTextContentTypeError(f"Invalid text content type: {type(text)}. Expected str.")

|
||||||
self.bucket_name,
|
self.bucket_name,
|
||||||
connect_timeout=30,
|
connect_timeout=30,
|
||||||
region=region,
|
region=region,
|
||||||
|
cloudbox_id=dify_config.ALIYUN_CLOUDBOX_ID,
|
||||||
)
|
)
|
||||||
|
|
||||||
def save(self, filename, data):
|
def save(self, filename, data):
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@ class HuaweiObsStorage(BaseStorage):
|
||||||
access_key_id=dify_config.HUAWEI_OBS_ACCESS_KEY,
|
access_key_id=dify_config.HUAWEI_OBS_ACCESS_KEY,
|
||||||
secret_access_key=dify_config.HUAWEI_OBS_SECRET_KEY,
|
secret_access_key=dify_config.HUAWEI_OBS_SECRET_KEY,
|
||||||
server=dify_config.HUAWEI_OBS_SERVER,
|
server=dify_config.HUAWEI_OBS_SERVER,
|
||||||
|
path_style=dify_config.HUAWEI_OBS_PATH_STYLE,
|
||||||
)
|
)
|
||||||
|
|
||||||
def save(self, filename, data):
|
def save(self, filename, data):
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,14 @@ class TencentCosStorage(BaseStorage):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.bucket_name = dify_config.TENCENT_COS_BUCKET_NAME
|
self.bucket_name = dify_config.TENCENT_COS_BUCKET_NAME
|
||||||
|
if dify_config.TENCENT_COS_CUSTOM_DOMAIN:
|
||||||
|
config = CosConfig(
|
||||||
|
Domain=dify_config.TENCENT_COS_CUSTOM_DOMAIN,
|
||||||
|
SecretId=dify_config.TENCENT_COS_SECRET_ID,
|
||||||
|
SecretKey=dify_config.TENCENT_COS_SECRET_KEY,
|
||||||
|
Scheme=dify_config.TENCENT_COS_SCHEME,
|
||||||
|
)
|
||||||
|
else:
|
||||||
config = CosConfig(
|
config = CosConfig(
|
||||||
Region=dify_config.TENCENT_COS_REGION,
|
Region=dify_config.TENCENT_COS_REGION,
|
||||||
SecretId=dify_config.TENCENT_COS_SECRET_ID,
|
SecretId=dify_config.TENCENT_COS_SECRET_ID,
|
||||||
|
|
|
||||||
|
|
@@ -1,6 +1,6 @@
 [project]
 name = "dify-api"
-version = "1.11.1"
+version = "1.11.2"
 requires-python = ">=3.11,<3.13"

 dependencies = [

@@ -69,7 +69,7 @@ dependencies = [
     "pydantic-extra-types~=2.10.3",
     "pydantic-settings~=2.11.0",
     "pyjwt~=2.10.1",
-    "pypdfium2==4.30.0",
+    "pypdfium2==5.2.0",
     "python-docx~=1.1.0",
     "python-dotenv==1.0.1",
     "pyyaml~=6.0.1",

@@ -155,6 +155,7 @@ class AppDslService:
             parsed_url.scheme == "https"
             and parsed_url.netloc == "github.com"
             and parsed_url.path.endswith((".yml", ".yaml"))
+            and "/blob/" in parsed_url.path
         ):
             yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com")
             yaml_url = yaml_url.replace("/blob/", "/")
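The extra `"/blob/" in parsed_url.path` guard makes sure the rewrite only fires for GitHub file URLs that actually contain a `/blob/` segment. The rewrite turns a GitHub web URL into its raw-content counterpart (example URL is hypothetical):

```python
yaml_url = "https://github.com/acme/dsl-samples/blob/main/app.yml"  # hypothetical
if "/blob/" in yaml_url:
    yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com")
    yaml_url = yaml_url.replace("/blob/", "/")
print(yaml_url)  # https://raw.githubusercontent.com/acme/dsl-samples/main/app.yml
```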
@@ -14,7 +14,8 @@ from enums.quota_type import QuotaType, unlimited
 from extensions.otel import AppGenerateHandler, trace_span
 from models.model import Account, App, AppMode, EndUser
 from models.workflow import Workflow
-from services.errors.app import InvokeRateLimitError, QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError
+from services.errors.app import QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError
+from services.errors.llm import InvokeRateLimitError
 from services.workflow_service import WorkflowService

@@ -21,7 +21,7 @@ from models.model import App, EndUser
 from models.trigger import WorkflowTriggerLog
 from models.workflow import Workflow
 from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
-from services.errors.app import InvokeRateLimitError, QuotaExceededError, WorkflowNotFoundError
+from services.errors.app import QuotaExceededError, WorkflowNotFoundError, WorkflowQuotaLimitError
 from services.workflow.entities import AsyncTriggerResponse, TriggerData, WorkflowTaskData
 from services.workflow.queue_dispatcher import QueueDispatcherManager, QueuePriority
 from services.workflow_service import WorkflowService

@@ -141,7 +141,7 @@ class AsyncWorkflowService:
                 trigger_log_repo.update(trigger_log)
                 session.commit()

-                raise InvokeRateLimitError(
+                raise WorkflowQuotaLimitError(
                     f"Workflow execution quota limit reached for tenant {trigger_data.tenant_id}"
                 ) from e

@@ -26,7 +26,7 @@ class FirecrawlAuth(ApiKeyAuthBase):
             "limit": 1,
             "scrapeOptions": {"onlyMainContent": True},
         }
-        response = self._post_request(f"{self.base_url}/v1/crawl", options, headers)
+        response = self._post_request(self._build_url("v1/crawl"), options, headers)
         if response.status_code == 200:
             return True
         else:

@@ -35,15 +35,17 @@ class FirecrawlAuth(ApiKeyAuthBase):
     def _prepare_headers(self):
         return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}

+    def _build_url(self, path: str) -> str:
+        # ensure exactly one slash between base and path, regardless of user-provided base_url
+        return f"{self.base_url.rstrip('/')}/{path.lstrip('/')}"
+
     def _post_request(self, url, data, headers):
         return httpx.post(url, headers=headers, json=data)

     def _handle_error(self, response):
-        if response.status_code in {402, 409, 500}:
-            error_message = response.json().get("error", "Unknown error occurred")
+        try:
+            payload = response.json()
+        except json.JSONDecodeError:
+            payload = {}
+        error_message = payload.get("error") or payload.get("message") or (response.text or "Unknown error occurred")
         raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
-        else:
-            if response.text:
-                error_message = json.loads(response.text).get("error", "Unknown error occurred")
-                raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
-            raise Exception(f"Unexpected error occurred while trying to authorize. Status code: {response.status_code}")
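Both `_handle_error` rewrites (here and in `FirecrawlApp` above) guard against non-JSON error bodies: `httpx.Response.json()` delegates to the standard-library decoder, so an empty or HTML body raises `json.JSONDecodeError`, which previously escaped as an unrelated crash. The fallback order in isolation:

```python
import json

def extract_error_message(response) -> str:
    # Prefer structured fields, fall back to the raw body, then a default.
    try:
        payload = response.json()
    except json.JSONDecodeError:
        payload = {}
    return payload.get("error") or payload.get("message") or (response.text or "Unknown error occurred")
```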
@@ -6,7 +6,9 @@ from typing import Any, Union
 from sqlalchemy import asc, desc, func, or_, select
 from sqlalchemy.orm import Session

+from configs import dify_config
 from core.app.entities.app_invoke_entities import InvokeFrom
+from core.db.session_factory import session_factory
 from core.llm_generator.llm_generator import LLMGenerator
 from core.variables.types import SegmentType
 from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory

@@ -202,6 +204,7 @@ class ConversationService:
         user: Union[Account, EndUser] | None,
         limit: int,
         last_id: str | None,
+        variable_name: str | None = None,
     ) -> InfiniteScrollPagination:
         conversation = cls.get_conversation(app_model, conversation_id, user)

@@ -212,7 +215,25 @@ class ConversationService:
             .order_by(ConversationVariable.created_at)
         )

-        with Session(db.engine) as session:
+        # Apply variable_name filter if provided
+        if variable_name:
+            escaped_variable_name = variable_name.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
+            # Filter using JSON extraction to match variable names case-insensitively
+            if dify_config.DB_TYPE in ["mysql", "oceanbase", "seekdb"]:
+                stmt = stmt.where(
+                    func.json_extract(ConversationVariable.data, "$.name").ilike(
+                        f"%{escaped_variable_name}%", escape="\\"
+                    )
+                )
+            elif dify_config.DB_TYPE == "postgresql":
+                stmt = stmt.where(
+                    func.json_extract_path_text(ConversationVariable.data, "name").ilike(
+                        f"%{escaped_variable_name}%", escape="\\"
+                    )
+                )
+
+        with session_factory.create_session() as session:
             if last_id:
                 last_variable = session.scalar(stmt.where(ConversationVariable.id == last_id))
                 if not last_variable:

@@ -279,7 +300,7 @@ class ConversationService:
             .where(ConversationVariable.id == variable_id)
         )

-        with Session(db.engine) as session:
+        with session_factory.create_session() as session:
             existing_variable = session.scalar(stmt)
             if not existing_variable:
                 raise ConversationVariableNotExistsError()
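Because the variable name is interpolated into an `ILIKE` pattern, `%`, `_`, and backslash are escaped first, and `escape="\\"` tells the database which escape character was used; otherwise a search for `a_b` would match any `aXb`. Minimal illustration of the pattern construction:

```python
variable_name = "a_b"  # user input
escaped = variable_name.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
pattern = f"%{escaped}%"
print(pattern)  # %a\_b% -- the "_" is now literal under ESCAPE '\'
```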
@@ -3458,7 +3458,7 @@ class SegmentService:
         if keyword:
             query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))

-        query = query.order_by(DocumentSegment.position.asc())
+        query = query.order_by(DocumentSegment.position.asc(), DocumentSegment.id.asc())
         paginated_segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)

         return paginated_segments.items, paginated_segments.total
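Ordering by `position` alone is non-deterministic when several segments share a position, so page boundaries could shift between requests; adding `id` as a tiebreaker makes the ordering total. The equivalent ordering in plain Python:

```python
segments = [{"position": 1, "id": "b"}, {"position": 1, "id": "a"}]
# Stable total order: position first, id breaks ties (as in the order_by above).
segments.sort(key=lambda s: (s["position"], s["id"]))
print([s["id"] for s in segments])  # ['a', 'b']
```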
@@ -110,5 +110,5 @@ class EnterpriseService:
         if not app_id:
             raise ValueError("app_id must be provided.")

-        body = {"appId": app_id}
-        EnterpriseRequest.send_request("DELETE", "/webapp/clean", json=body)
+        params = {"appId": app_id}
+        EnterpriseRequest.send_request("DELETE", "/webapp/clean", params=params)
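Moving the app id from a JSON body to query parameters sidesteps servers and proxies that drop or reject DELETE request bodies, which have no defined semantics in HTTP. Assuming an httpx-style client underneath, the direct equivalent would be roughly:

```python
import httpx

# Hypothetical direct call mirroring EnterpriseRequest.send_request above.
httpx.delete("https://enterprise.example.com/webapp/clean", params={"appId": "app-123"})
```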
@@ -18,8 +18,8 @@ class WorkflowIdFormatError(Exception):
     pass


-class InvokeRateLimitError(Exception):
-    """Raised when rate limit is exceeded for workflow invocations."""
+class WorkflowQuotaLimitError(Exception):
+    """Raised when workflow execution quota is exceeded (for async/background workflows)."""

     pass

@@ -105,3 +105,49 @@ class PluginParameterService:
             )
             .options
         )
+
+    @staticmethod
+    def get_dynamic_select_options_with_credentials(
+        tenant_id: str,
+        user_id: str,
+        plugin_id: str,
+        provider: str,
+        action: str,
+        parameter: str,
+        credential_id: str,
+        credentials: Mapping[str, Any],
+    ) -> Sequence[PluginParameterOption]:
+        """
+        Get dynamic select options using provided credentials directly.
+        Used for edit mode when credentials have been modified but not yet saved.
+
+        Security: credential_id is validated against tenant_id to ensure
+        users can only access their own credentials.
+        """
+        from constants import HIDDEN_VALUE
+
+        # Get original subscription to replace hidden values (with tenant_id check for security)
+        original_subscription = TriggerProviderService.get_subscription_by_id(tenant_id, credential_id)
+        if not original_subscription:
+            raise ValueError(f"Subscription {credential_id} not found")
+
+        # Replace [__HIDDEN__] with original values
+        resolved_credentials: dict[str, Any] = {
+            key: (original_subscription.credentials.get(key) if value == HIDDEN_VALUE else value)
+            for key, value in credentials.items()
+        }
+
+        return (
+            DynamicSelectClient()
+            .fetch_dynamic_select_options(
+                tenant_id,
+                user_id,
+                plugin_id,
+                provider,
+                action,
+                resolved_credentials,
+                original_subscription.credential_type or CredentialType.UNAUTHORIZED.value,
+                parameter,
+            )
+            .options
+        )
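The comprehension above restores secrets the UI masked: any field sent back as the `[__HIDDEN__]` sentinel is replaced with the stored value, while edited fields pass through untouched. A worked example with toy data:

```python
HIDDEN_VALUE = "[__HIDDEN__]"  # sentinel, per the comment in the diff above
stored = {"api_key": "sk-real", "region": "eu"}
submitted = {"api_key": HIDDEN_VALUE, "region": "us"}

resolved = {k: (stored.get(k) if v == HIDDEN_VALUE else v) for k, v in submitted.items()}
print(resolved)  # {'api_key': 'sk-real', 'region': 'us'}
```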
@@ -286,12 +286,12 @@ class BuiltinToolManageService:

             session.add(db_provider)
             session.commit()

-            # Invalidate tool providers cache
-            ToolProviderListCache.invalidate_cache(tenant_id)
         except Exception as e:
             session.rollback()
             raise ValueError(str(e))

+        # Invalidate tool providers cache
+        ToolProviderListCache.invalidate_cache(tenant_id, "builtin")
         return {"result": "success"}

     @staticmethod

@@ -319,8 +319,14 @@ class MCPToolManageService:
         except MCPError as e:
             raise ValueError(f"Failed to connect to MCP server: {e}")

-        # Update database with retrieved tools
-        db_provider.tools = json.dumps([tool.model_dump() for tool in tools])
+        # Update database with retrieved tools (ensure description is a non-null string)
+        tools_payload = []
+        for tool in tools:
+            data = tool.model_dump()
+            if data.get("description") is None:
+                data["description"] = ""
+            tools_payload.append(data)
+        db_provider.tools = json.dumps(tools_payload)
         db_provider.authed = True
         db_provider.updated_at = datetime.now()
         self._session.flush()
@@ -620,6 +626,21 @@ class MCPToolManageService:
             server_url_hash=new_server_url_hash,
         )

+    @staticmethod
+    def reconnect_with_url(
+        *,
+        server_url: str,
+        headers: dict[str, str],
+        timeout: float | None,
+        sse_read_timeout: float | None,
+    ) -> ReconnectResult:
+        return MCPToolManageService._reconnect_with_url(
+            server_url=server_url,
+            headers=headers,
+            timeout=timeout,
+            sse_read_timeout=sse_read_timeout,
+        )
+
     @staticmethod
     def _reconnect_with_url(
         *,
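
The new `reconnect_with_url` is a thin keyword-only public wrapper over the existing private `_reconnect_with_url`, giving callers a stable entry point. An illustrative call (values are placeholders; the `authed`/`tools` fields of `ReconnectResult` appear in the next hunk):

```python
result = MCPToolManageService.reconnect_with_url(
    server_url="https://mcp.example.com/sse",     # placeholder MCP endpoint
    headers={"Authorization": "Bearer <token>"},  # placeholder auth header
    timeout=30.0,
    sse_read_timeout=300.0,
)
if result.authed:
    print(result.tools)  # JSON string of the tool list
```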
@@ -642,9 +663,16 @@ class MCPToolManageService:
                 sse_read_timeout=sse_read_timeout,
             ) as mcp_client:
                 tools = mcp_client.list_tools()
+                # Ensure tool descriptions are non-null in payload
+                tools_payload = []
+                for t in tools:
+                    d = t.model_dump()
+                    if d.get("description") is None:
+                        d["description"] = ""
+                    tools_payload.append(d)
                 return ReconnectResult(
                     authed=True,
-                    tools=json.dumps([tool.model_dump() for tool in tools]),
+                    tools=json.dumps(tools_payload),
                     encrypted_credentials=EMPTY_CREDENTIALS_JSON,
                 )
         except MCPAuthError:
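
This None-description normalization now appears twice (here and in the `@@ -319,8` hunk above); if one copy drifts, the two code paths will disagree. A shared helper would be one way to keep them in sync - a sketch, not part of the commit:

```python
import json
from typing import Any

def dump_tools_with_safe_descriptions(tools: list[Any]) -> str:
    """Serialize tools, coercing a missing description to an empty string."""
    payload = []
    for tool in tools:
        data = tool.model_dump()  # pydantic-style models, as in the diff
        if data.get("description") is None:
            data["description"] = ""
        payload.append(data)
    return json.dumps(payload)
```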
@@ -5,8 +5,8 @@ from datetime import datetime
 from typing import Any

 from sqlalchemy import or_, select
-from sqlalchemy.orm import Session

+from core.db.session_factory import session_factory
 from core.helper.tool_provider_cache import ToolProviderListCache
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.tools.__base.tool_provider import ToolProviderController
@@ -68,7 +68,6 @@ class WorkflowToolManageService:
         if workflow is None:
             raise ValueError(f"Workflow not found for app {workflow_app_id}")

-        with Session(db.engine, expire_on_commit=False) as session, session.begin():
         workflow_tool_provider = WorkflowToolProvider(
             tenant_id=tenant_id,
             user_id=user_id,
@@ -81,13 +80,15 @@ class WorkflowToolManageService:
             privacy_policy=privacy_policy,
             version=workflow.version,
         )
-        session.add(workflow_tool_provider)

         try:
             WorkflowToolProviderController.from_db(workflow_tool_provider)
         except Exception as e:
             raise ValueError(str(e))

+        with session_factory.create_session() as session, session.begin():
+            session.add(workflow_tool_provider)
+
         if labels is not None:
             ToolLabelManager.update_tool_labels(
                 ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels
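
The two hunks above replace a hand-built `Session(db.engine, ...)` with the central `session_factory`, and only open the transaction after `WorkflowToolProviderController.from_db` has validated the provider, so a validation failure never touches the database. The pattern in isolation, assuming `create_session()` yields a standard SQLAlchemy session (API inferred from the import added in this commit):

```python
from core.db.session_factory import session_factory

def persist(obj) -> None:
    # session.begin() commits when the block exits cleanly and rolls back
    # if it raises, so no explicit commit/rollback calls are needed.
    with session_factory.create_session() as session, session.begin():
        session.add(obj)
```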
@@ -94,16 +94,23 @@ class TriggerProviderService:

         provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
         for subscription in subscriptions:
-            encrypter, _ = create_trigger_provider_encrypter_for_subscription(
+            credential_encrypter, _ = create_trigger_provider_encrypter_for_subscription(
                 tenant_id=tenant_id,
                 controller=provider_controller,
                 subscription=subscription,
             )
             subscription.credentials = dict(
-                encrypter.mask_credentials(dict(encrypter.decrypt(subscription.credentials)))
+                credential_encrypter.mask_credentials(dict(credential_encrypter.decrypt(subscription.credentials)))
             )
-            subscription.properties = dict(encrypter.mask_credentials(dict(encrypter.decrypt(subscription.properties))))
-            subscription.parameters = dict(encrypter.mask_credentials(dict(encrypter.decrypt(subscription.parameters))))
+            properties_encrypter, _ = create_trigger_provider_encrypter_for_properties(
+                tenant_id=tenant_id,
+                controller=provider_controller,
+                subscription=subscription,
+            )
+            subscription.properties = dict(
+                properties_encrypter.mask_credentials(dict(properties_encrypter.decrypt(subscription.properties)))
+            )
+            subscription.parameters = dict(subscription.parameters)
             count = workflows_in_use_map.get(subscription.id)
             subscription.workflows_in_use = count if count is not None else 0
@@ -209,6 +216,101 @@ class TriggerProviderService:
             logger.exception("Failed to add trigger provider")
             raise ValueError(str(e))

+    @classmethod
+    def update_trigger_subscription(
+        cls,
+        tenant_id: str,
+        subscription_id: str,
+        name: str | None = None,
+        properties: Mapping[str, Any] | None = None,
+        parameters: Mapping[str, Any] | None = None,
+        credentials: Mapping[str, Any] | None = None,
+        credential_expires_at: int | None = None,
+        expires_at: int | None = None,
+    ) -> None:
+        """
+        Update an existing trigger subscription.
+
+        :param tenant_id: Tenant ID
+        :param subscription_id: Subscription instance ID
+        :param name: Optional new name for this subscription
+        :param properties: Optional new properties
+        :param parameters: Optional new parameters
+        :param credentials: Optional new credentials
+        :param credential_expires_at: Optional new credential expiration timestamp
+        :param expires_at: Optional new expiration timestamp
+        :return: Success response with updated subscription info
+        """
+        with Session(db.engine, expire_on_commit=False) as session:
+            # Use distributed lock to prevent race conditions on the same subscription
+            lock_key = f"trigger_subscription_update_lock:{tenant_id}_{subscription_id}"
+            with redis_client.lock(lock_key, timeout=20):
+                subscription: TriggerSubscription | None = (
+                    session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
+                )
+                if not subscription:
+                    raise ValueError(f"Trigger subscription {subscription_id} not found")
+
+                provider_id = TriggerProviderID(subscription.provider_id)
+                provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
+
+                # Check for name uniqueness if name is being updated
+                if name is not None and name != subscription.name:
+                    existing = (
+                        session.query(TriggerSubscription)
+                        .filter_by(tenant_id=tenant_id, provider_id=str(provider_id), name=name)
+                        .first()
+                    )
+                    if existing:
+                        raise ValueError(f"Subscription name '{name}' already exists for this provider")
+                    subscription.name = name
+
+                # Update properties if provided
+                if properties is not None:
+                    properties_encrypter, _ = create_provider_encrypter(
+                        tenant_id=tenant_id,
+                        config=provider_controller.get_properties_schema(),
+                        cache=NoOpProviderCredentialCache(),
+                    )
+                    # Handle hidden values - preserve original encrypted values
+                    original_properties = properties_encrypter.decrypt(subscription.properties)
+                    new_properties: dict[str, Any] = {
+                        key: value if value != HIDDEN_VALUE else original_properties.get(key, UNKNOWN_VALUE)
+                        for key, value in properties.items()
+                    }
+                    subscription.properties = dict(properties_encrypter.encrypt(new_properties))
+
+                # Update parameters if provided
+                if parameters is not None:
+                    subscription.parameters = dict(parameters)
+
+                # Update credentials if provided
+                if credentials is not None:
+                    credential_type = CredentialType.of(subscription.credential_type)
+                    credential_encrypter, _ = create_provider_encrypter(
+                        tenant_id=tenant_id,
+                        config=provider_controller.get_credential_schema_config(credential_type),
+                        cache=NoOpProviderCredentialCache(),
+                    )
+                    subscription.credentials = dict(credential_encrypter.encrypt(dict(credentials)))
+
+                # Update credential expiration timestamp if provided
+                if credential_expires_at is not None:
+                    subscription.credential_expires_at = credential_expires_at
+
+                # Update expiration timestamp if provided
+                if expires_at is not None:
+                    subscription.expires_at = expires_at
+
+                session.commit()
+
+                # Clear subscription cache
+                delete_cache_for_subscription(
+                    tenant_id=tenant_id,
+                    provider_id=subscription.provider_id,
+                    subscription_id=subscription.id,
+                )
+
     @classmethod
     def get_subscription_by_id(cls, tenant_id: str, subscription_id: str | None = None) -> TriggerSubscription | None:
         """
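
`update_trigger_subscription` serializes concurrent writers on the same subscription with a Redis lock rather than a database row lock. The same pattern in isolation, using redis-py's built-in `Lock` as a context manager (key format copied from the diff; connection details omitted):

```python
import redis

redis_client = redis.Redis()  # connection details are deployment-specific

def update_subscription(tenant_id: str, subscription_id: str) -> None:
    lock_key = f"trigger_subscription_update_lock:{tenant_id}_{subscription_id}"
    # Blocks until acquired; timeout=20 makes the lock auto-expire after 20s,
    # so a crashed worker cannot hold it forever.
    with redis_client.lock(lock_key, timeout=20):
        ...  # read-modify-write the subscription row here
```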
@@ -257,8 +359,6 @@ class TriggerProviderService:
             raise ValueError(f"Trigger provider subscription {subscription_id} not found")

         credential_type: CredentialType = CredentialType.of(subscription.credential_type)
-        is_auto_created: bool = credential_type in [CredentialType.OAUTH2, CredentialType.API_KEY]
-        if is_auto_created:
         provider_id = TriggerProviderID(subscription.provider_id)
         provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
             tenant_id=tenant_id, provider_id=provider_id
@@ -268,6 +368,9 @@ class TriggerProviderService:
             controller=provider_controller,
             subscription=subscription,
         )

+        is_auto_created: bool = credential_type in [CredentialType.OAUTH2, CredentialType.API_KEY]
+        if is_auto_created:
             try:
                 TriggerManager.unsubscribe_trigger(
                     tenant_id=tenant_id,
@@ -280,8 +383,8 @@ class TriggerProviderService:
             except Exception as e:
                 logger.exception("Error unsubscribing trigger", exc_info=e)

-        # Clear cache
         session.delete(subscription)
+        # Clear cache
         delete_cache_for_subscription(
             tenant_id=tenant_id,
             provider_id=subscription.provider_id,
@@ -688,3 +791,188 @@ class TriggerProviderService:
         )
         subscription.properties = dict(properties_encrypter.decrypt(subscription.properties))
         return subscription
+
+    @classmethod
+    def verify_subscription_credentials(
+        cls,
+        tenant_id: str,
+        user_id: str,
+        provider_id: TriggerProviderID,
+        subscription_id: str,
+        credentials: Mapping[str, Any],
+    ) -> dict[str, Any]:
+        """
+        Verify credentials for an existing subscription without updating it.
+
+        This is used in edit mode to validate new credentials before rebuild.
+
+        :param tenant_id: Tenant ID
+        :param user_id: User ID
+        :param provider_id: Provider identifier
+        :param subscription_id: Subscription ID
+        :param credentials: New credentials to verify
+        :return: dict with 'verified' boolean
+        """
+        provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
+        if not provider_controller:
+            raise ValueError(f"Provider {provider_id} not found")
+
+        subscription = cls.get_subscription_by_id(
+            tenant_id=tenant_id,
+            subscription_id=subscription_id,
+        )
+        if not subscription:
+            raise ValueError(f"Subscription {subscription_id} not found")
+
+        credential_type = CredentialType.of(subscription.credential_type)
+
+        # For API Key, validate the new credentials
+        if credential_type == CredentialType.API_KEY:
+            new_credentials: dict[str, Any] = {
+                key: value if value != HIDDEN_VALUE else subscription.credentials.get(key, UNKNOWN_VALUE)
+                for key, value in credentials.items()
+            }
+            try:
+                provider_controller.validate_credentials(user_id, credentials=new_credentials)
+                return {"verified": True}
+            except Exception as e:
+                raise ValueError(f"Invalid credentials: {e}") from e
+
+        return {"verified": True}
+
+    @classmethod
+    def rebuild_trigger_subscription(
+        cls,
+        tenant_id: str,
+        provider_id: TriggerProviderID,
+        subscription_id: str,
+        credentials: Mapping[str, Any],
+        parameters: Mapping[str, Any],
+        name: str | None = None,
+    ) -> None:
+        """
+        Create a subscription builder for rebuilding an existing subscription.
+
+        This method creates a builder pre-filled with data from the rebuild request,
+        keeping the same subscription_id and endpoint_id so the webhook URL remains unchanged.
+
+        :param tenant_id: Tenant ID
+        :param name: Name for the subscription
+        :param subscription_id: Subscription ID
+        :param provider_id: Provider identifier
+        :param credentials: Credentials for the subscription
+        :param parameters: Parameters for the subscription
+        :return: SubscriptionBuilderApiEntity
+        """
+        provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
+        if not provider_controller:
+            raise ValueError(f"Provider {provider_id} not found")
+
+        # Use distributed lock to prevent race conditions on the same subscription
+        lock_key = f"trigger_subscription_rebuild_lock:{tenant_id}_{subscription_id}"
+        with redis_client.lock(lock_key, timeout=20):
+            with Session(db.engine, expire_on_commit=False) as session:
+                try:
+                    # Get subscription within the transaction
+                    subscription: TriggerSubscription | None = (
+                        session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
+                    )
+                    if not subscription:
+                        raise ValueError(f"Subscription {subscription_id} not found")
+
+                    credential_type = CredentialType.of(subscription.credential_type)
+                    if credential_type not in [CredentialType.OAUTH2, CredentialType.API_KEY]:
+                        raise ValueError("Credential type not supported for rebuild")
+
+                    # Decrypt existing credentials for merging
+                    credential_encrypter, _ = create_trigger_provider_encrypter_for_subscription(
+                        tenant_id=tenant_id,
+                        controller=provider_controller,
+                        subscription=subscription,
+                    )
+                    decrypted_credentials = dict(credential_encrypter.decrypt(subscription.credentials))
+
+                    # Merge credentials: if caller passed HIDDEN_VALUE, retain existing decrypted value
+                    merged_credentials: dict[str, Any] = {
+                        key: value if value != HIDDEN_VALUE else decrypted_credentials.get(key, UNKNOWN_VALUE)
+                        for key, value in credentials.items()
+                    }
+
+                    user_id = subscription.user_id
+
+                    # TODO: Trying to invoke update api of the plugin trigger provider
+
+                    # FALLBACK: If the update api is not implemented,
+                    # delete the previous subscription and create a new one
+
+                    # Unsubscribe the previous subscription (external call, but we'll handle errors)
+                    try:
+                        TriggerManager.unsubscribe_trigger(
+                            tenant_id=tenant_id,
+                            user_id=user_id,
+                            provider_id=provider_id,
+                            subscription=subscription.to_entity(),
+                            credentials=decrypted_credentials,
+                            credential_type=credential_type,
+                        )
+                    except Exception as e:
+                        logger.exception("Error unsubscribing trigger during rebuild", exc_info=e)
+                        # Continue anyway - the subscription might already be deleted externally
+
+                    # Create a new subscription with the same subscription_id and endpoint_id (external call)
+                    new_subscription: TriggerSubscriptionEntity = TriggerManager.subscribe_trigger(
+                        tenant_id=tenant_id,
+                        user_id=user_id,
+                        provider_id=provider_id,
+                        endpoint=generate_plugin_trigger_endpoint_url(subscription.endpoint_id),
+                        parameters=parameters,
+                        credentials=merged_credentials,
+                        credential_type=credential_type,
+                    )
+
+                    # Update the subscription in the same transaction
+                    # Inline update logic to reuse the same session
+                    if name is not None and name != subscription.name:
+                        existing = (
+                            session.query(TriggerSubscription)
+                            .filter_by(tenant_id=tenant_id, provider_id=str(provider_id), name=name)
+                            .first()
+                        )
+                        if existing and existing.id != subscription.id:
+                            raise ValueError(f"Subscription name '{name}' already exists for this provider")
+                        subscription.name = name
+
+                    # Update parameters
+                    subscription.parameters = dict(parameters)
+
+                    # Update credentials with merged (and encrypted) values
+                    subscription.credentials = dict(credential_encrypter.encrypt(merged_credentials))
+
+                    # Update properties
+                    if new_subscription.properties:
+                        properties_encrypter, _ = create_provider_encrypter(
+                            tenant_id=tenant_id,
+                            config=provider_controller.get_properties_schema(),
+                            cache=NoOpProviderCredentialCache(),
+                        )
+                        subscription.properties = dict(properties_encrypter.encrypt(dict(new_subscription.properties)))
+
+                    # Update expiration timestamp
+                    if new_subscription.expires_at is not None:
+                        subscription.expires_at = new_subscription.expires_at
+
+                    # Commit the transaction
+                    session.commit()
+
+                    # Clear subscription cache
+                    delete_cache_for_subscription(
+                        tenant_id=tenant_id,
+                        provider_id=subscription.provider_id,
+                        subscription_id=subscription.id,
+                    )
+
+                except Exception as e:
+                    # Rollback on any error
+                    session.rollback()
+                    logger.exception("Failed to rebuild trigger subscription", exc_info=e)
+                    raise
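
`rebuild_trigger_subscription` implements update as delete-and-recreate while reusing the stored `endpoint_id`, so the webhook URL a third party already points at never changes. An illustrative call (all IDs and values are placeholders):

```python
TriggerProviderService.rebuild_trigger_subscription(
    tenant_id="tenant-1",
    provider_id=TriggerProviderID("langgenius/github_trigger/github"),  # placeholder
    subscription_id="sub-1",
    credentials={"api_key": "[__HIDDEN__]"},  # masked -> stored value is kept
    parameters={"events": ["push"]},
    name="repo-push",
)
```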
Some files were not shown because too many files have changed in this diff.