mirror of https://github.com/langgenius/dify.git
Compare commits
89 Commits
| Author | SHA1 | Date |
|---|---|---|
|
|
9b6b2f3195 | |
|
|
ae43ad5cb6 | |
|
|
5b02e5dcb6 | |
|
|
e3ef33366d | |
|
|
ee1d0df927 | |
|
|
184077c37c | |
|
|
3015e9be73 | |
|
|
2bb1e24fb4 | |
|
|
cad7101534 | |
|
|
e856287b65 | |
|
|
27be89c984 | |
|
|
fa69cce1e7 | |
|
|
f28a08a696 | |
|
|
8129b04143 | |
|
|
1b8e80a722 | |
|
|
0421387672 | |
|
|
2aaaa4bd34 | |
|
|
64dc98e607 | |
|
|
9007109a6b | |
|
|
925168383b | |
|
|
e6f3528bb0 | |
|
|
fb5edd0bf6 | |
|
|
de53c78125 | |
|
|
3a59ae9617 | |
|
|
69589807fd | |
|
|
6ca44eea28 | |
|
|
bf76f10653 | |
|
|
c1af6a7127 | |
|
|
1873b5a766 | |
|
|
9fbc7fa379 | |
|
|
2399d00d86 | |
|
|
3505516e8e | |
|
|
faef04cdf7 | |
|
|
0ba9b9e6b5 | |
|
|
30dd50ff83 | |
|
|
5338cf85b1 | |
|
|
673209d086 | |
|
|
43758ec85d | |
|
|
20944e7e1a | |
|
|
7a5d2728a1 | |
|
|
14bff10201 | |
|
|
9a6b4147bc | |
|
|
2c919efa69 | |
|
|
6d0e36479b | |
|
|
09be869f58 | |
|
|
0b1439fee4 | |
|
|
dfd2dd5c68 | |
|
|
3ae7788933 | |
|
|
446df6b50d | |
|
|
d9cecabe93 | |
|
|
b71a0d3f04 | |
|
|
d546d525b4 | |
|
|
a46dc2f37e | |
|
|
8b38e3f79d | |
|
|
44ab8a3376 | |
|
|
1e86535c4a | |
|
|
5b1c08c19c | |
|
|
6202c566e9 | |
|
|
a00ac1b5b1 | |
|
|
bf56c2e9db | |
|
|
543ce38a6c | |
|
|
1f2c85c916 | |
|
|
2b01f85d61 | |
|
|
d8010a7fbc | |
|
|
b067ad2f0a | |
|
|
b85564cae5 | |
|
|
c393d7a2dc | |
|
|
f610f6895f | |
|
|
d20a8d5b77 | |
|
|
8611301722 | |
|
|
6044f0666a | |
|
|
8d26e6ab28 | |
|
|
61d255a6e6 | |
|
|
f0d02b4b91 | |
|
|
d100354851 | |
|
|
93d1b2fc32 | |
|
|
fa1009b938 | |
|
|
fd64156f9d | |
|
|
bdd8a35b9d | |
|
|
b892906d71 | |
|
|
7e06225ce2 | |
|
|
f08d847c20 | |
|
|
44fc0c614c | |
|
|
0f3ffbee2c | |
|
|
08d5eee993 | |
|
|
9885e92854 | |
|
|
f2555b0bb1 | |
|
|
c3bb95d71d | |
|
|
996c7d9e16 |
|
|
@ -0,0 +1,8 @@
|
||||||
|
{
|
||||||
|
"enabledPlugins": {
|
||||||
|
"feature-dev@claude-plugins-official": true,
|
||||||
|
"context7@claude-plugins-official": true,
|
||||||
|
"typescript-lsp@claude-plugins-official": true,
|
||||||
|
"pyright-lsp@claude-plugins-official": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -1,19 +0,0 @@
|
||||||
{
|
|
||||||
"permissions": {
|
|
||||||
"allow": [],
|
|
||||||
"deny": []
|
|
||||||
},
|
|
||||||
"env": {
|
|
||||||
"__comment": "Environment variables for MCP servers. Override in .claude/settings.local.json with actual values.",
|
|
||||||
"GITHUB_PERSONAL_ACCESS_TOKEN": "ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
|
|
||||||
},
|
|
||||||
"enabledMcpjsonServers": [
|
|
||||||
"context7",
|
|
||||||
"sequential-thinking",
|
|
||||||
"github",
|
|
||||||
"fetch",
|
|
||||||
"playwright",
|
|
||||||
"ide"
|
|
||||||
],
|
|
||||||
"enableAllProjectMcpServers": true
|
|
||||||
}
|
|
||||||
|
|
@ -0,0 +1,483 @@
|
||||||
|
---
|
||||||
|
name: component-refactoring
|
||||||
|
description: Refactor high-complexity React components in Dify frontend. Use when `pnpm analyze-component --json` shows complexity > 50 or lineCount > 300, when the user asks for code splitting, hook extraction, or complexity reduction, or when `pnpm analyze-component` warns to refactor before testing; avoid for simple/well-structured components, third-party wrappers, or when the user explicitly wants testing without refactoring.
|
||||||
|
---
|
||||||
|
|
||||||
|
# Dify Component Refactoring Skill
|
||||||
|
|
||||||
|
Refactor high-complexity React components in the Dify frontend codebase with the patterns and workflow below.
|
||||||
|
|
||||||
|
> **Complexity Threshold**: Components with complexity > 50 (measured by `pnpm analyze-component`) should be refactored before testing.
|
||||||
|
|
||||||
|
## Quick Reference
|
||||||
|
|
||||||
|
### Commands (run from `web/`)
|
||||||
|
|
||||||
|
Use paths relative to `web/` (e.g., `app/components/...`).
|
||||||
|
Use `refactor-component` for refactoring prompts and `analyze-component` for testing prompts and metrics.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd web
|
||||||
|
|
||||||
|
# Generate refactoring prompt
|
||||||
|
pnpm refactor-component <path>
|
||||||
|
|
||||||
|
# Output refactoring analysis as JSON
|
||||||
|
pnpm refactor-component <path> --json
|
||||||
|
|
||||||
|
# Generate testing prompt (after refactoring)
|
||||||
|
pnpm analyze-component <path>
|
||||||
|
|
||||||
|
# Output testing analysis as JSON
|
||||||
|
pnpm analyze-component <path> --json
|
||||||
|
```
|
||||||
|
|
||||||
|
### Complexity Analysis
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Analyze component complexity
|
||||||
|
pnpm analyze-component <path> --json
|
||||||
|
|
||||||
|
# Key metrics to check:
|
||||||
|
# - complexity: normalized score 0-100 (target < 50)
|
||||||
|
# - maxComplexity: highest single function complexity
|
||||||
|
# - lineCount: total lines (target < 300)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Complexity Score Interpretation
|
||||||
|
|
||||||
|
| Score | Level | Action |
|
||||||
|
|-------|-------|--------|
|
||||||
|
| 0-25 | 🟢 Simple | Ready for testing |
|
||||||
|
| 26-50 | 🟡 Medium | Consider minor refactoring |
|
||||||
|
| 51-75 | 🟠 Complex | **Refactor before testing** |
|
||||||
|
| 76-100 | 🔴 Very Complex | **Must refactor** |
|
||||||
|
|
||||||
|
## Core Refactoring Patterns
|
||||||
|
|
||||||
|
### Pattern 1: Extract Custom Hooks
|
||||||
|
|
||||||
|
**When**: Component has complex state management, multiple `useState`/`useEffect`, or business logic mixed with UI.
|
||||||
|
|
||||||
|
**Dify Convention**: Place hooks in a `hooks/` subdirectory or alongside the component as `use-<feature>.ts`.
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// ❌ Before: Complex state logic in component
|
||||||
|
const Configuration: FC = () => {
|
||||||
|
const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
|
||||||
|
const [datasetConfigs, setDatasetConfigs] = useState<DatasetConfigs>(...)
|
||||||
|
const [completionParams, setCompletionParams] = useState<FormValue>({})
|
||||||
|
|
||||||
|
// 50+ lines of state management logic...
|
||||||
|
|
||||||
|
return <div>...</div>
|
||||||
|
}
|
||||||
|
|
||||||
|
// ✅ After: Extract to custom hook
|
||||||
|
// hooks/use-model-config.ts
|
||||||
|
export const useModelConfig = (appId: string) => {
|
||||||
|
const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
|
||||||
|
const [completionParams, setCompletionParams] = useState<FormValue>({})
|
||||||
|
|
||||||
|
// Related state management logic here
|
||||||
|
|
||||||
|
return { modelConfig, setModelConfig, completionParams, setCompletionParams }
|
||||||
|
}
|
||||||
|
|
||||||
|
// Component becomes cleaner
|
||||||
|
const Configuration: FC = () => {
|
||||||
|
const { modelConfig, setModelConfig } = useModelConfig(appId)
|
||||||
|
return <div>...</div>
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Dify Examples**:
|
||||||
|
- `web/app/components/app/configuration/hooks/use-advanced-prompt-config.ts`
|
||||||
|
- `web/app/components/app/configuration/debug/hooks.tsx`
|
||||||
|
- `web/app/components/workflow/hooks/use-workflow.ts`
|
||||||
|
|
||||||
|
### Pattern 2: Extract Sub-Components
|
||||||
|
|
||||||
|
**When**: Single component has multiple UI sections, conditional rendering blocks, or repeated patterns.
|
||||||
|
|
||||||
|
**Dify Convention**: Place sub-components in subdirectories or as separate files in the same directory.
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// ❌ Before: Monolithic JSX with multiple sections
|
||||||
|
const AppInfo = () => {
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
{/* 100 lines of header UI */}
|
||||||
|
{/* 100 lines of operations UI */}
|
||||||
|
{/* 100 lines of modals */}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ✅ After: Split into focused components
|
||||||
|
// app-info/
|
||||||
|
// ├── index.tsx (orchestration only)
|
||||||
|
// ├── app-header.tsx (header UI)
|
||||||
|
// ├── app-operations.tsx (operations UI)
|
||||||
|
// └── app-modals.tsx (modal management)
|
||||||
|
|
||||||
|
const AppInfo = () => {
|
||||||
|
const { showModal, setShowModal } = useAppInfoModals()
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
<AppHeader appDetail={appDetail} />
|
||||||
|
<AppOperations onAction={handleAction} />
|
||||||
|
<AppModals show={showModal} onClose={() => setShowModal(null)} />
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Dify Examples**:
|
||||||
|
- `web/app/components/app/configuration/` directory structure
|
||||||
|
- `web/app/components/workflow/nodes/` per-node organization
|
||||||
|
|
||||||
|
### Pattern 3: Simplify Conditional Logic
|
||||||
|
|
||||||
|
**When**: Deep nesting (> 3 levels), complex ternaries, or multiple `if/else` chains.
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// ❌ Before: Deeply nested conditionals
|
||||||
|
const Template = useMemo(() => {
|
||||||
|
if (appDetail?.mode === AppModeEnum.CHAT) {
|
||||||
|
switch (locale) {
|
||||||
|
case LanguagesSupported[1]:
|
||||||
|
return <TemplateChatZh />
|
||||||
|
case LanguagesSupported[7]:
|
||||||
|
return <TemplateChatJa />
|
||||||
|
default:
|
||||||
|
return <TemplateChatEn />
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (appDetail?.mode === AppModeEnum.ADVANCED_CHAT) {
|
||||||
|
// Another 15 lines...
|
||||||
|
}
|
||||||
|
// More conditions...
|
||||||
|
}, [appDetail, locale])
|
||||||
|
|
||||||
|
// ✅ After: Use lookup tables + early returns
|
||||||
|
const TEMPLATE_MAP = {
|
||||||
|
[AppModeEnum.CHAT]: {
|
||||||
|
[LanguagesSupported[1]]: TemplateChatZh,
|
||||||
|
[LanguagesSupported[7]]: TemplateChatJa,
|
||||||
|
default: TemplateChatEn,
|
||||||
|
},
|
||||||
|
[AppModeEnum.ADVANCED_CHAT]: {
|
||||||
|
[LanguagesSupported[1]]: TemplateAdvancedChatZh,
|
||||||
|
// ...
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
const Template = useMemo(() => {
|
||||||
|
const modeTemplates = TEMPLATE_MAP[appDetail?.mode]
|
||||||
|
if (!modeTemplates) return null
|
||||||
|
|
||||||
|
const TemplateComponent = modeTemplates[locale] || modeTemplates.default
|
||||||
|
return <TemplateComponent appDetail={appDetail} />
|
||||||
|
}, [appDetail, locale])
|
||||||
|
```
|
||||||
|
|
||||||
|
### Pattern 4: Extract API/Data Logic
|
||||||
|
|
||||||
|
**When**: Component directly handles API calls, data transformation, or complex async operations.
|
||||||
|
|
||||||
|
**Dify Convention**: Use `@tanstack/react-query` hooks from `web/service/use-*.ts` or create custom data hooks.
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// ❌ Before: API logic in component
|
||||||
|
const MCPServiceCard = () => {
|
||||||
|
const [basicAppConfig, setBasicAppConfig] = useState({})
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (isBasicApp && appId) {
|
||||||
|
(async () => {
|
||||||
|
const res = await fetchAppDetail({ url: '/apps', id: appId })
|
||||||
|
setBasicAppConfig(res?.model_config || {})
|
||||||
|
})()
|
||||||
|
}
|
||||||
|
}, [appId, isBasicApp])
|
||||||
|
|
||||||
|
// More API-related logic...
|
||||||
|
}
|
||||||
|
|
||||||
|
// ✅ After: Extract to data hook using React Query
|
||||||
|
// use-app-config.ts
|
||||||
|
import { useQuery } from '@tanstack/react-query'
|
||||||
|
import { get } from '@/service/base'
|
||||||
|
|
||||||
|
const NAME_SPACE = 'appConfig'
|
||||||
|
|
||||||
|
export const useAppConfig = (appId: string, isBasicApp: boolean) => {
|
||||||
|
return useQuery({
|
||||||
|
enabled: isBasicApp && !!appId,
|
||||||
|
queryKey: [NAME_SPACE, 'detail', appId],
|
||||||
|
queryFn: () => get<AppDetailResponse>(`/apps/${appId}`),
|
||||||
|
select: data => data?.model_config || {},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Component becomes cleaner
|
||||||
|
const MCPServiceCard = () => {
|
||||||
|
const { data: config, isLoading } = useAppConfig(appId, isBasicApp)
|
||||||
|
// UI only
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**React Query Best Practices in Dify**:
|
||||||
|
- Define `NAME_SPACE` for query key organization
|
||||||
|
- Use `enabled` option for conditional fetching
|
||||||
|
- Use `select` for data transformation
|
||||||
|
- Export invalidation hooks: `useInvalidXxx`
|
||||||
|
|
||||||
|
**Dify Examples**:
|
||||||
|
- `web/service/use-workflow.ts`
|
||||||
|
- `web/service/use-common.ts`
|
||||||
|
- `web/service/knowledge/use-dataset.ts`
|
||||||
|
- `web/service/knowledge/use-document.ts`
|
||||||
|
|
||||||
|
### Pattern 5: Extract Modal/Dialog Management
|
||||||
|
|
||||||
|
**When**: Component manages multiple modals with complex open/close states.
|
||||||
|
|
||||||
|
**Dify Convention**: Modals should be extracted with their state management.
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// ❌ Before: Multiple modal states in component
|
||||||
|
const AppInfo = () => {
|
||||||
|
const [showEditModal, setShowEditModal] = useState(false)
|
||||||
|
const [showDuplicateModal, setShowDuplicateModal] = useState(false)
|
||||||
|
const [showConfirmDelete, setShowConfirmDelete] = useState(false)
|
||||||
|
const [showSwitchModal, setShowSwitchModal] = useState(false)
|
||||||
|
const [showImportDSLModal, setShowImportDSLModal] = useState(false)
|
||||||
|
// 5+ more modal states...
|
||||||
|
}
|
||||||
|
|
||||||
|
// ✅ After: Extract to modal management hook
|
||||||
|
type ModalType = 'edit' | 'duplicate' | 'delete' | 'switch' | 'import' | null
|
||||||
|
|
||||||
|
const useAppInfoModals = () => {
|
||||||
|
const [activeModal, setActiveModal] = useState<ModalType>(null)
|
||||||
|
|
||||||
|
const openModal = useCallback((type: ModalType) => setActiveModal(type), [])
|
||||||
|
const closeModal = useCallback(() => setActiveModal(null), [])
|
||||||
|
|
||||||
|
return {
|
||||||
|
activeModal,
|
||||||
|
openModal,
|
||||||
|
closeModal,
|
||||||
|
isOpen: (type: ModalType) => activeModal === type,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Pattern 6: Extract Form Logic
|
||||||
|
|
||||||
|
**When**: Complex form validation, submission handling, or field transformation.
|
||||||
|
|
||||||
|
**Dify Convention**: Use `@tanstack/react-form` patterns from `web/app/components/base/form/`.
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// ✅ Use existing form infrastructure
|
||||||
|
import { useAppForm } from '@/app/components/base/form'
|
||||||
|
|
||||||
|
const ConfigForm = () => {
|
||||||
|
const form = useAppForm({
|
||||||
|
defaultValues: { name: '', description: '' },
|
||||||
|
onSubmit: handleSubmit,
|
||||||
|
})
|
||||||
|
|
||||||
|
return <form.Provider>...</form.Provider>
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Dify-Specific Refactoring Guidelines
|
||||||
|
|
||||||
|
### 1. Context Provider Extraction
|
||||||
|
|
||||||
|
**When**: Component provides complex context values with multiple states.
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// ❌ Before: Large context value object
|
||||||
|
const value = {
|
||||||
|
appId, isAPIKeySet, isTrailFinished, mode, modelModeType,
|
||||||
|
promptMode, isAdvancedMode, isAgent, isOpenAI, isFunctionCall,
|
||||||
|
// 50+ more properties...
|
||||||
|
}
|
||||||
|
return <ConfigContext.Provider value={value}>...</ConfigContext.Provider>
|
||||||
|
|
||||||
|
// ✅ After: Split into domain-specific contexts
|
||||||
|
<ModelConfigProvider value={modelConfigValue}>
|
||||||
|
<DatasetConfigProvider value={datasetConfigValue}>
|
||||||
|
<UIConfigProvider value={uiConfigValue}>
|
||||||
|
{children}
|
||||||
|
</UIConfigProvider>
|
||||||
|
</DatasetConfigProvider>
|
||||||
|
</ModelConfigProvider>
|
||||||
|
```
|
||||||
|
|
||||||
|
**Dify Reference**: `web/context/` directory structure
|
||||||
|
|
||||||
|
### 2. Workflow Node Components
|
||||||
|
|
||||||
|
**When**: Refactoring workflow node components (`web/app/components/workflow/nodes/`).
|
||||||
|
|
||||||
|
**Conventions**:
|
||||||
|
- Keep node logic in `use-interactions.ts`
|
||||||
|
- Extract panel UI to separate files
|
||||||
|
- Use `_base` components for common patterns
|
||||||
|
|
||||||
|
```
|
||||||
|
nodes/<node-type>/
|
||||||
|
├── index.tsx # Node registration
|
||||||
|
├── node.tsx # Node visual component
|
||||||
|
├── panel.tsx # Configuration panel
|
||||||
|
├── use-interactions.ts # Node-specific hooks
|
||||||
|
└── types.ts # Type definitions
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Configuration Components
|
||||||
|
|
||||||
|
**When**: Refactoring app configuration components.
|
||||||
|
|
||||||
|
**Conventions**:
|
||||||
|
- Separate config sections into subdirectories
|
||||||
|
- Use existing patterns from `web/app/components/app/configuration/`
|
||||||
|
- Keep feature toggles in dedicated components
|
||||||
|
|
||||||
|
### 4. Tool/Plugin Components
|
||||||
|
|
||||||
|
**When**: Refactoring tool-related components (`web/app/components/tools/`).
|
||||||
|
|
||||||
|
**Conventions**:
|
||||||
|
- Follow existing modal patterns
|
||||||
|
- Use service hooks from `web/service/use-tools.ts`
|
||||||
|
- Keep provider-specific logic isolated
|
||||||
|
|
||||||
|
## Refactoring Workflow
|
||||||
|
|
||||||
|
### Step 1: Generate Refactoring Prompt
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pnpm refactor-component <path>
|
||||||
|
```
|
||||||
|
|
||||||
|
This command will:
|
||||||
|
- Analyze component complexity and features
|
||||||
|
- Identify specific refactoring actions needed
|
||||||
|
- Generate a prompt for AI assistant (auto-copied to clipboard on macOS)
|
||||||
|
- Provide detailed requirements based on detected patterns
|
||||||
|
|
||||||
|
### Step 2: Analyze Details
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pnpm analyze-component <path> --json
|
||||||
|
```
|
||||||
|
|
||||||
|
Identify:
|
||||||
|
- Total complexity score
|
||||||
|
- Max function complexity
|
||||||
|
- Line count
|
||||||
|
- Features detected (state, effects, API, etc.)
|
||||||
|
|
||||||
|
### Step 3: Plan
|
||||||
|
|
||||||
|
Create a refactoring plan based on detected features:
|
||||||
|
|
||||||
|
| Detected Feature | Refactoring Action |
|
||||||
|
|------------------|-------------------|
|
||||||
|
| `hasState: true` + `hasEffects: true` | Extract custom hook |
|
||||||
|
| `hasAPI: true` | Extract data/service hook |
|
||||||
|
| `hasEvents: true` (many) | Extract event handlers |
|
||||||
|
| `lineCount > 300` | Split into sub-components |
|
||||||
|
| `maxComplexity > 50` | Simplify conditional logic |
|
||||||
|
|
||||||
|
### Step 4: Execute Incrementally
|
||||||
|
|
||||||
|
1. **Extract one piece at a time**
|
||||||
|
2. **Run lint, type-check, and tests after each extraction**
|
||||||
|
3. **Verify functionality before next step**
|
||||||
|
|
||||||
|
```
|
||||||
|
For each extraction:
|
||||||
|
┌────────────────────────────────────────┐
|
||||||
|
│ 1. Extract code │
|
||||||
|
│ 2. Run: pnpm lint:fix │
|
||||||
|
│ 3. Run: pnpm type-check:tsgo │
|
||||||
|
│ 4. Run: pnpm test │
|
||||||
|
│ 5. Test functionality manually │
|
||||||
|
│ 6. PASS? → Next extraction │
|
||||||
|
│ FAIL? → Fix before continuing │
|
||||||
|
└────────────────────────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
### Step 5: Verify
|
||||||
|
|
||||||
|
After refactoring:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Re-run refactor command to verify improvements
|
||||||
|
pnpm refactor-component <path>
|
||||||
|
|
||||||
|
# If complexity < 25 and lines < 200, you'll see:
|
||||||
|
# ✅ COMPONENT IS WELL-STRUCTURED
|
||||||
|
|
||||||
|
# For detailed metrics:
|
||||||
|
pnpm analyze-component <path> --json
|
||||||
|
|
||||||
|
# Target metrics:
|
||||||
|
# - complexity < 50
|
||||||
|
# - lineCount < 300
|
||||||
|
# - maxComplexity < 30
|
||||||
|
```
|
||||||
|
|
||||||
|
## Common Mistakes to Avoid
|
||||||
|
|
||||||
|
### ❌ Over-Engineering
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// ❌ Too many tiny hooks
|
||||||
|
const useButtonText = () => useState('Click')
|
||||||
|
const useButtonDisabled = () => useState(false)
|
||||||
|
const useButtonLoading = () => useState(false)
|
||||||
|
|
||||||
|
// ✅ Cohesive hook with related state
|
||||||
|
const useButtonState = () => {
|
||||||
|
const [text, setText] = useState('Click')
|
||||||
|
const [disabled, setDisabled] = useState(false)
|
||||||
|
const [loading, setLoading] = useState(false)
|
||||||
|
return { text, setText, disabled, setDisabled, loading, setLoading }
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### ❌ Breaking Existing Patterns
|
||||||
|
|
||||||
|
- Follow existing directory structures
|
||||||
|
- Maintain naming conventions
|
||||||
|
- Preserve export patterns for compatibility
|
||||||
|
|
||||||
|
### ❌ Premature Abstraction
|
||||||
|
|
||||||
|
- Only extract when there's clear complexity benefit
|
||||||
|
- Don't create abstractions for single-use code
|
||||||
|
- Keep refactored code in the same domain area
|
||||||
|
|
||||||
|
## References
|
||||||
|
|
||||||
|
### Dify Codebase Examples
|
||||||
|
|
||||||
|
- **Hook extraction**: `web/app/components/app/configuration/hooks/`
|
||||||
|
- **Component splitting**: `web/app/components/app/configuration/`
|
||||||
|
- **Service hooks**: `web/service/use-*.ts`
|
||||||
|
- **Workflow patterns**: `web/app/components/workflow/hooks/`
|
||||||
|
- **Form patterns**: `web/app/components/base/form/`
|
||||||
|
|
||||||
|
### Related Skills
|
||||||
|
|
||||||
|
- `frontend-testing` - For testing refactored components
|
||||||
|
- `web/testing/testing.md` - Testing specification
|
||||||
|
|
@ -0,0 +1,493 @@
|
||||||
|
# Complexity Reduction Patterns
|
||||||
|
|
||||||
|
This document provides patterns for reducing cognitive complexity in Dify React components.
|
||||||
|
|
||||||
|
## Understanding Complexity
|
||||||
|
|
||||||
|
### SonarJS Cognitive Complexity
|
||||||
|
|
||||||
|
The `pnpm analyze-component` tool uses SonarJS cognitive complexity metrics:
|
||||||
|
|
||||||
|
- **Total Complexity**: Sum of all functions' complexity in the file
|
||||||
|
- **Max Complexity**: Highest single function complexity
|
||||||
|
|
||||||
|
### What Increases Complexity
|
||||||
|
|
||||||
|
| Pattern | Complexity Impact |
|
||||||
|
|---------|-------------------|
|
||||||
|
| `if/else` | +1 per branch |
|
||||||
|
| Nested conditions | +1 per nesting level |
|
||||||
|
| `switch/case` | +1 per case |
|
||||||
|
| `for/while/do` | +1 per loop |
|
||||||
|
| `&&`/`||` chains | +1 per operator |
|
||||||
|
| Nested callbacks | +1 per nesting level |
|
||||||
|
| `try/catch` | +1 per catch |
|
||||||
|
| Ternary expressions | +1 per nesting |
|
||||||
|
|
||||||
|
## Pattern 1: Replace Conditionals with Lookup Tables
|
||||||
|
|
||||||
|
**Before** (complexity: ~15):
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
const Template = useMemo(() => {
|
||||||
|
if (appDetail?.mode === AppModeEnum.CHAT) {
|
||||||
|
switch (locale) {
|
||||||
|
case LanguagesSupported[1]:
|
||||||
|
return <TemplateChatZh appDetail={appDetail} />
|
||||||
|
case LanguagesSupported[7]:
|
||||||
|
return <TemplateChatJa appDetail={appDetail} />
|
||||||
|
default:
|
||||||
|
return <TemplateChatEn appDetail={appDetail} />
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (appDetail?.mode === AppModeEnum.ADVANCED_CHAT) {
|
||||||
|
switch (locale) {
|
||||||
|
case LanguagesSupported[1]:
|
||||||
|
return <TemplateAdvancedChatZh appDetail={appDetail} />
|
||||||
|
case LanguagesSupported[7]:
|
||||||
|
return <TemplateAdvancedChatJa appDetail={appDetail} />
|
||||||
|
default:
|
||||||
|
return <TemplateAdvancedChatEn appDetail={appDetail} />
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (appDetail?.mode === AppModeEnum.WORKFLOW) {
|
||||||
|
// Similar pattern...
|
||||||
|
}
|
||||||
|
return null
|
||||||
|
}, [appDetail, locale])
|
||||||
|
```
|
||||||
|
|
||||||
|
**After** (complexity: ~3):
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// Define lookup table outside component
|
||||||
|
const TEMPLATE_MAP: Record<AppModeEnum, Record<string, FC<TemplateProps>>> = {
|
||||||
|
[AppModeEnum.CHAT]: {
|
||||||
|
[LanguagesSupported[1]]: TemplateChatZh,
|
||||||
|
[LanguagesSupported[7]]: TemplateChatJa,
|
||||||
|
default: TemplateChatEn,
|
||||||
|
},
|
||||||
|
[AppModeEnum.ADVANCED_CHAT]: {
|
||||||
|
[LanguagesSupported[1]]: TemplateAdvancedChatZh,
|
||||||
|
[LanguagesSupported[7]]: TemplateAdvancedChatJa,
|
||||||
|
default: TemplateAdvancedChatEn,
|
||||||
|
},
|
||||||
|
[AppModeEnum.WORKFLOW]: {
|
||||||
|
[LanguagesSupported[1]]: TemplateWorkflowZh,
|
||||||
|
[LanguagesSupported[7]]: TemplateWorkflowJa,
|
||||||
|
default: TemplateWorkflowEn,
|
||||||
|
},
|
||||||
|
// ...
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean component logic
|
||||||
|
const Template = useMemo(() => {
|
||||||
|
if (!appDetail?.mode) return null
|
||||||
|
|
||||||
|
const templates = TEMPLATE_MAP[appDetail.mode]
|
||||||
|
if (!templates) return null
|
||||||
|
|
||||||
|
const TemplateComponent = templates[locale] ?? templates.default
|
||||||
|
return <TemplateComponent appDetail={appDetail} />
|
||||||
|
}, [appDetail, locale])
|
||||||
|
```
|
||||||
|
|
||||||
|
## Pattern 2: Use Early Returns
|
||||||
|
|
||||||
|
**Before** (complexity: ~10):
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
const handleSubmit = () => {
|
||||||
|
if (isValid) {
|
||||||
|
if (hasChanges) {
|
||||||
|
if (isConnected) {
|
||||||
|
submitData()
|
||||||
|
} else {
|
||||||
|
showConnectionError()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
showNoChangesMessage()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
showValidationError()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**After** (complexity: ~4):
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
const handleSubmit = () => {
|
||||||
|
if (!isValid) {
|
||||||
|
showValidationError()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!hasChanges) {
|
||||||
|
showNoChangesMessage()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!isConnected) {
|
||||||
|
showConnectionError()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
submitData()
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Pattern 3: Extract Complex Conditions
|
||||||
|
|
||||||
|
**Before** (complexity: high):
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
const canPublish = (() => {
|
||||||
|
if (mode !== AppModeEnum.COMPLETION) {
|
||||||
|
if (!isAdvancedMode)
|
||||||
|
return true
|
||||||
|
|
||||||
|
if (modelModeType === ModelModeType.completion) {
|
||||||
|
if (!hasSetBlockStatus.history || !hasSetBlockStatus.query)
|
||||||
|
return false
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return !promptEmpty
|
||||||
|
})()
|
||||||
|
```
|
||||||
|
|
||||||
|
**After** (complexity: lower):
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// Extract to named functions
|
||||||
|
const canPublishInCompletionMode = () => !promptEmpty
|
||||||
|
|
||||||
|
const canPublishInChatMode = () => {
|
||||||
|
if (!isAdvancedMode) return true
|
||||||
|
if (modelModeType !== ModelModeType.completion) return true
|
||||||
|
return hasSetBlockStatus.history && hasSetBlockStatus.query
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean main logic
|
||||||
|
const canPublish = mode === AppModeEnum.COMPLETION
|
||||||
|
? canPublishInCompletionMode()
|
||||||
|
: canPublishInChatMode()
|
||||||
|
```
|
||||||
|
|
||||||
|
## Pattern 4: Replace Chained Ternaries
|
||||||
|
|
||||||
|
**Before** (complexity: ~5):
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
const statusText = serverActivated
|
||||||
|
? t('status.running')
|
||||||
|
: serverPublished
|
||||||
|
? t('status.inactive')
|
||||||
|
: appUnpublished
|
||||||
|
? t('status.unpublished')
|
||||||
|
: t('status.notConfigured')
|
||||||
|
```
|
||||||
|
|
||||||
|
**After** (complexity: ~2):
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
const getStatusText = () => {
|
||||||
|
if (serverActivated) return t('status.running')
|
||||||
|
if (serverPublished) return t('status.inactive')
|
||||||
|
if (appUnpublished) return t('status.unpublished')
|
||||||
|
return t('status.notConfigured')
|
||||||
|
}
|
||||||
|
|
||||||
|
const statusText = getStatusText()
|
||||||
|
```
|
||||||
|
|
||||||
|
Or use lookup:
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
const STATUS_TEXT_MAP = {
|
||||||
|
running: 'status.running',
|
||||||
|
inactive: 'status.inactive',
|
||||||
|
unpublished: 'status.unpublished',
|
||||||
|
notConfigured: 'status.notConfigured',
|
||||||
|
} as const
|
||||||
|
|
||||||
|
const getStatusKey = (): keyof typeof STATUS_TEXT_MAP => {
|
||||||
|
if (serverActivated) return 'running'
|
||||||
|
if (serverPublished) return 'inactive'
|
||||||
|
if (appUnpublished) return 'unpublished'
|
||||||
|
return 'notConfigured'
|
||||||
|
}
|
||||||
|
|
||||||
|
const statusText = t(STATUS_TEXT_MAP[getStatusKey()])
|
||||||
|
```
|
||||||
|
|
||||||
|
## Pattern 5: Flatten Nested Loops
|
||||||
|
|
||||||
|
**Before** (complexity: high):
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
const processData = (items: Item[]) => {
|
||||||
|
const results: ProcessedItem[] = []
|
||||||
|
|
||||||
|
for (const item of items) {
|
||||||
|
if (item.isValid) {
|
||||||
|
for (const child of item.children) {
|
||||||
|
if (child.isActive) {
|
||||||
|
for (const prop of child.properties) {
|
||||||
|
if (prop.value !== null) {
|
||||||
|
results.push({
|
||||||
|
itemId: item.id,
|
||||||
|
childId: child.id,
|
||||||
|
propValue: prop.value,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return results
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**After** (complexity: lower):
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// Use functional approach
|
||||||
|
const processData = (items: Item[]) => {
|
||||||
|
return items
|
||||||
|
.filter(item => item.isValid)
|
||||||
|
.flatMap(item =>
|
||||||
|
item.children
|
||||||
|
.filter(child => child.isActive)
|
||||||
|
.flatMap(child =>
|
||||||
|
child.properties
|
||||||
|
.filter(prop => prop.value !== null)
|
||||||
|
.map(prop => ({
|
||||||
|
itemId: item.id,
|
||||||
|
childId: child.id,
|
||||||
|
propValue: prop.value,
|
||||||
|
}))
|
||||||
|
)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Pattern 6: Extract Event Handler Logic
|
||||||
|
|
||||||
|
**Before** (complexity: high in component):
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
const Component = () => {
|
||||||
|
const handleSelect = (data: DataSet[]) => {
|
||||||
|
if (isEqual(data.map(item => item.id), dataSets.map(item => item.id))) {
|
||||||
|
hideSelectDataSet()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
formattingChangedDispatcher()
|
||||||
|
let newDatasets = data
|
||||||
|
if (data.find(item => !item.name)) {
|
||||||
|
const newSelected = produce(data, (draft) => {
|
||||||
|
data.forEach((item, index) => {
|
||||||
|
if (!item.name) {
|
||||||
|
const newItem = dataSets.find(i => i.id === item.id)
|
||||||
|
if (newItem)
|
||||||
|
draft[index] = newItem
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
setDataSets(newSelected)
|
||||||
|
newDatasets = newSelected
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
setDataSets(data)
|
||||||
|
}
|
||||||
|
hideSelectDataSet()
|
||||||
|
|
||||||
|
// 40 more lines of logic...
|
||||||
|
}
|
||||||
|
|
||||||
|
return <div>...</div>
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**After** (complexity: lower):
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// Extract to hook or utility
|
||||||
|
const useDatasetSelection = (dataSets: DataSet[], setDataSets: SetState<DataSet[]>) => {
|
||||||
|
const normalizeSelection = (data: DataSet[]) => {
|
||||||
|
const hasUnloadedItem = data.some(item => !item.name)
|
||||||
|
if (!hasUnloadedItem) return data
|
||||||
|
|
||||||
|
return produce(data, (draft) => {
|
||||||
|
data.forEach((item, index) => {
|
||||||
|
if (!item.name) {
|
||||||
|
const existing = dataSets.find(i => i.id === item.id)
|
||||||
|
if (existing) draft[index] = existing
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
const hasSelectionChanged = (newData: DataSet[]) => {
|
||||||
|
return !isEqual(
|
||||||
|
newData.map(item => item.id),
|
||||||
|
dataSets.map(item => item.id)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return { normalizeSelection, hasSelectionChanged }
|
||||||
|
}
|
||||||
|
|
||||||
|
// Component becomes cleaner
|
||||||
|
const Component = () => {
|
||||||
|
const { normalizeSelection, hasSelectionChanged } = useDatasetSelection(dataSets, setDataSets)
|
||||||
|
|
||||||
|
const handleSelect = (data: DataSet[]) => {
|
||||||
|
if (!hasSelectionChanged(data)) {
|
||||||
|
hideSelectDataSet()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
formattingChangedDispatcher()
|
||||||
|
const normalized = normalizeSelection(data)
|
||||||
|
setDataSets(normalized)
|
||||||
|
hideSelectDataSet()
|
||||||
|
}
|
||||||
|
|
||||||
|
return <div>...</div>
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Pattern 7: Reduce Boolean Logic Complexity
|
||||||
|
|
||||||
|
**Before** (complexity: ~8):
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
const toggleDisabled = hasInsufficientPermissions
|
||||||
|
|| appUnpublished
|
||||||
|
|| missingStartNode
|
||||||
|
|| triggerModeDisabled
|
||||||
|
|| (isAdvancedApp && !currentWorkflow?.graph)
|
||||||
|
|| (isBasicApp && !basicAppConfig.updated_at)
|
||||||
|
```
|
||||||
|
|
||||||
|
**After** (complexity: ~3):
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// Extract meaningful boolean functions
|
||||||
|
const isAppReady = () => {
|
||||||
|
if (isAdvancedApp) return !!currentWorkflow?.graph
|
||||||
|
return !!basicAppConfig.updated_at
|
||||||
|
}
|
||||||
|
|
||||||
|
const hasRequiredPermissions = () => {
|
||||||
|
return isCurrentWorkspaceEditor && !hasInsufficientPermissions
|
||||||
|
}
|
||||||
|
|
||||||
|
const canToggle = () => {
|
||||||
|
if (!hasRequiredPermissions()) return false
|
||||||
|
if (!isAppReady()) return false
|
||||||
|
if (missingStartNode) return false
|
||||||
|
if (triggerModeDisabled) return false
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
const toggleDisabled = !canToggle()
|
||||||
|
```
|
||||||
|
|
||||||
|
## Pattern 8: Simplify useMemo/useCallback Dependencies
|
||||||
|
|
||||||
|
**Before** (complexity: multiple recalculations):
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
const payload = useMemo(() => {
|
||||||
|
let parameters: Parameter[] = []
|
||||||
|
let outputParameters: OutputParameter[] = []
|
||||||
|
|
||||||
|
if (!published) {
|
||||||
|
parameters = (inputs || []).map((item) => ({
|
||||||
|
name: item.variable,
|
||||||
|
description: '',
|
||||||
|
form: 'llm',
|
||||||
|
required: item.required,
|
||||||
|
type: item.type,
|
||||||
|
}))
|
||||||
|
outputParameters = (outputs || []).map((item) => ({
|
||||||
|
name: item.variable,
|
||||||
|
description: '',
|
||||||
|
type: item.value_type,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
else if (detail && detail.tool) {
|
||||||
|
parameters = (inputs || []).map((item) => ({
|
||||||
|
// Complex transformation...
|
||||||
|
}))
|
||||||
|
outputParameters = (outputs || []).map((item) => ({
|
||||||
|
// Complex transformation...
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
icon: detail?.icon || icon,
|
||||||
|
label: detail?.label || name,
|
||||||
|
// ...more fields
|
||||||
|
}
|
||||||
|
}, [detail, published, workflowAppId, icon, name, description, inputs, outputs])
|
||||||
|
```
|
||||||
|
|
||||||
|
**After** (complexity: separated concerns):
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// Separate transformations
|
||||||
|
const useParameterTransform = (inputs: InputVar[], detail?: ToolDetail, published?: boolean) => {
|
||||||
|
return useMemo(() => {
|
||||||
|
if (!published) {
|
||||||
|
return inputs.map(item => ({
|
||||||
|
name: item.variable,
|
||||||
|
description: '',
|
||||||
|
form: 'llm',
|
||||||
|
required: item.required,
|
||||||
|
type: item.type,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!detail?.tool) return []
|
||||||
|
|
||||||
|
return inputs.map(item => ({
|
||||||
|
name: item.variable,
|
||||||
|
required: item.required,
|
||||||
|
type: item.type === 'paragraph' ? 'string' : item.type,
|
||||||
|
description: detail.tool.parameters.find(p => p.name === item.variable)?.llm_description || '',
|
||||||
|
form: detail.tool.parameters.find(p => p.name === item.variable)?.form || 'llm',
|
||||||
|
}))
|
||||||
|
}, [inputs, detail, published])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Component uses hook
|
||||||
|
const parameters = useParameterTransform(inputs, detail, published)
|
||||||
|
const outputParameters = useOutputTransform(outputs, detail, published)
|
||||||
|
|
||||||
|
const payload = useMemo(() => ({
|
||||||
|
icon: detail?.icon || icon,
|
||||||
|
label: detail?.label || name,
|
||||||
|
parameters,
|
||||||
|
outputParameters,
|
||||||
|
// ...
|
||||||
|
}), [detail, icon, name, parameters, outputParameters])
|
||||||
|
```
|
||||||
|
|
||||||
|
## Target Metrics After Refactoring
|
||||||
|
|
||||||
|
| Metric | Target |
|
||||||
|
|--------|--------|
|
||||||
|
| Total Complexity | < 50 |
|
||||||
|
| Max Function Complexity | < 30 |
|
||||||
|
| Function Length | < 30 lines |
|
||||||
|
| Nesting Depth | ≤ 3 levels |
|
||||||
|
| Conditional Chains | ≤ 3 conditions |
|
||||||
|
|
@ -0,0 +1,477 @@
|
||||||
|
# Component Splitting Patterns
|
||||||
|
|
||||||
|
This document provides detailed guidance on splitting large components into smaller, focused components in Dify.
|
||||||
|
|
||||||
|
## When to Split Components
|
||||||
|
|
||||||
|
Split a component when you identify:
|
||||||
|
|
||||||
|
1. **Multiple UI sections** - Distinct visual areas with minimal coupling that can be composed independently
|
||||||
|
1. **Conditional rendering blocks** - Large `{condition && <JSX />}` blocks
|
||||||
|
1. **Repeated patterns** - Similar UI structures used multiple times
|
||||||
|
1. **300+ lines** - Component exceeds manageable size
|
||||||
|
1. **Modal clusters** - Multiple modals rendered in one component
|
||||||
|
|
||||||
|
## Splitting Strategies
|
||||||
|
|
||||||
|
### Strategy 1: Section-Based Splitting
|
||||||
|
|
||||||
|
Identify visual sections and extract each as a component.
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// ❌ Before: Monolithic component (500+ lines)
|
||||||
|
const ConfigurationPage = () => {
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
{/* Header Section - 50 lines */}
|
||||||
|
<div className="header">
|
||||||
|
<h1>{t('configuration.title')}</h1>
|
||||||
|
<div className="actions">
|
||||||
|
{isAdvancedMode && <Badge>Advanced</Badge>}
|
||||||
|
<ModelParameterModal ... />
|
||||||
|
<AppPublisher ... />
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Config Section - 200 lines */}
|
||||||
|
<div className="config">
|
||||||
|
<Config />
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Debug Section - 150 lines */}
|
||||||
|
<div className="debug">
|
||||||
|
<Debug ... />
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Modals Section - 100 lines */}
|
||||||
|
{showSelectDataSet && <SelectDataSet ... />}
|
||||||
|
{showHistoryModal && <EditHistoryModal ... />}
|
||||||
|
{showUseGPT4Confirm && <Confirm ... />}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ✅ After: Split into focused components
|
||||||
|
// configuration/
|
||||||
|
// ├── index.tsx (orchestration)
|
||||||
|
// ├── configuration-header.tsx
|
||||||
|
// ├── configuration-content.tsx
|
||||||
|
// ├── configuration-debug.tsx
|
||||||
|
// └── configuration-modals.tsx
|
||||||
|
|
||||||
|
// configuration-header.tsx
|
||||||
|
interface ConfigurationHeaderProps {
|
||||||
|
isAdvancedMode: boolean
|
||||||
|
onPublish: () => void
|
||||||
|
}
|
||||||
|
|
||||||
|
const ConfigurationHeader: FC<ConfigurationHeaderProps> = ({
|
||||||
|
isAdvancedMode,
|
||||||
|
onPublish,
|
||||||
|
}) => {
|
||||||
|
const { t } = useTranslation()
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="header">
|
||||||
|
<h1>{t('configuration.title')}</h1>
|
||||||
|
<div className="actions">
|
||||||
|
{isAdvancedMode && <Badge>Advanced</Badge>}
|
||||||
|
<ModelParameterModal ... />
|
||||||
|
<AppPublisher onPublish={onPublish} />
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// index.tsx (orchestration only)
|
||||||
|
const ConfigurationPage = () => {
|
||||||
|
const { modelConfig, setModelConfig } = useModelConfig()
|
||||||
|
const { activeModal, openModal, closeModal } = useModalState()
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
<ConfigurationHeader
|
||||||
|
isAdvancedMode={isAdvancedMode}
|
||||||
|
onPublish={handlePublish}
|
||||||
|
/>
|
||||||
|
<ConfigurationContent
|
||||||
|
modelConfig={modelConfig}
|
||||||
|
onConfigChange={setModelConfig}
|
||||||
|
/>
|
||||||
|
{!isMobile && (
|
||||||
|
<ConfigurationDebug
|
||||||
|
inputs={inputs}
|
||||||
|
onSetting={handleSetting}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
<ConfigurationModals
|
||||||
|
activeModal={activeModal}
|
||||||
|
onClose={closeModal}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Strategy 2: Conditional Block Extraction
|
||||||
|
|
||||||
|
Extract large conditional rendering blocks.
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// ❌ Before: Large conditional blocks
|
||||||
|
const AppInfo = () => {
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
{expand ? (
|
||||||
|
<div className="expanded">
|
||||||
|
{/* 100 lines of expanded view */}
|
||||||
|
</div>
|
||||||
|
) : (
|
||||||
|
<div className="collapsed">
|
||||||
|
{/* 50 lines of collapsed view */}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ✅ After: Separate view components
|
||||||
|
const AppInfoExpanded: FC<AppInfoViewProps> = ({ appDetail, onAction }) => {
|
||||||
|
return (
|
||||||
|
<div className="expanded">
|
||||||
|
{/* Clean, focused expanded view */}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
const AppInfoCollapsed: FC<AppInfoViewProps> = ({ appDetail, onAction }) => {
|
||||||
|
return (
|
||||||
|
<div className="collapsed">
|
||||||
|
{/* Clean, focused collapsed view */}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
const AppInfo = () => {
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
{expand
|
||||||
|
? <AppInfoExpanded appDetail={appDetail} onAction={handleAction} />
|
||||||
|
: <AppInfoCollapsed appDetail={appDetail} onAction={handleAction} />
|
||||||
|
}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Strategy 3: Modal Extraction
|
||||||
|
|
||||||
|
Extract modals with their trigger logic.
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// ❌ Before: Multiple modals in one component
|
||||||
|
const AppInfo = () => {
|
||||||
|
const [showEdit, setShowEdit] = useState(false)
|
||||||
|
const [showDuplicate, setShowDuplicate] = useState(false)
|
||||||
|
const [showDelete, setShowDelete] = useState(false)
|
||||||
|
const [showSwitch, setShowSwitch] = useState(false)
|
||||||
|
|
||||||
|
const onEdit = async (data) => { /* 20 lines */ }
|
||||||
|
const onDuplicate = async (data) => { /* 20 lines */ }
|
||||||
|
const onDelete = async () => { /* 15 lines */ }
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
{/* Main content */}
|
||||||
|
|
||||||
|
{showEdit && <EditModal onConfirm={onEdit} onClose={() => setShowEdit(false)} />}
|
||||||
|
{showDuplicate && <DuplicateModal onConfirm={onDuplicate} onClose={() => setShowDuplicate(false)} />}
|
||||||
|
{showDelete && <DeleteConfirm onConfirm={onDelete} onClose={() => setShowDelete(false)} />}
|
||||||
|
{showSwitch && <SwitchModal ... />}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ✅ After: Modal manager component
|
||||||
|
// app-info-modals.tsx
|
||||||
|
type ModalType = 'edit' | 'duplicate' | 'delete' | 'switch' | null
|
||||||
|
|
||||||
|
interface AppInfoModalsProps {
|
||||||
|
appDetail: AppDetail
|
||||||
|
activeModal: ModalType
|
||||||
|
onClose: () => void
|
||||||
|
onSuccess: () => void
|
||||||
|
}
|
||||||
|
|
||||||
|
const AppInfoModals: FC<AppInfoModalsProps> = ({
|
||||||
|
appDetail,
|
||||||
|
activeModal,
|
||||||
|
onClose,
|
||||||
|
onSuccess,
|
||||||
|
}) => {
|
||||||
|
const handleEdit = async (data) => { /* logic */ }
|
||||||
|
const handleDuplicate = async (data) => { /* logic */ }
|
||||||
|
const handleDelete = async () => { /* logic */ }
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
{activeModal === 'edit' && (
|
||||||
|
<EditModal
|
||||||
|
appDetail={appDetail}
|
||||||
|
onConfirm={handleEdit}
|
||||||
|
onClose={onClose}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{activeModal === 'duplicate' && (
|
||||||
|
<DuplicateModal
|
||||||
|
appDetail={appDetail}
|
||||||
|
onConfirm={handleDuplicate}
|
||||||
|
onClose={onClose}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{activeModal === 'delete' && (
|
||||||
|
<DeleteConfirm
|
||||||
|
onConfirm={handleDelete}
|
||||||
|
onClose={onClose}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{activeModal === 'switch' && (
|
||||||
|
<SwitchModal
|
||||||
|
appDetail={appDetail}
|
||||||
|
onClose={onClose}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parent component
|
||||||
|
const AppInfo = () => {
|
||||||
|
const { activeModal, openModal, closeModal } = useModalState()
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
{/* Main content with openModal triggers */}
|
||||||
|
<Button onClick={() => openModal('edit')}>Edit</Button>
|
||||||
|
|
||||||
|
<AppInfoModals
|
||||||
|
appDetail={appDetail}
|
||||||
|
activeModal={activeModal}
|
||||||
|
onClose={closeModal}
|
||||||
|
onSuccess={handleSuccess}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Strategy 4: List Item Extraction
|
||||||
|
|
||||||
|
Extract repeated item rendering.
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// ❌ Before: Inline item rendering
|
||||||
|
const OperationsList = () => {
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
{operations.map(op => (
|
||||||
|
<div key={op.id} className="operation-item">
|
||||||
|
<span className="icon">{op.icon}</span>
|
||||||
|
<span className="title">{op.title}</span>
|
||||||
|
<span className="description">{op.description}</span>
|
||||||
|
<button onClick={() => op.onClick()}>
|
||||||
|
{op.actionLabel}
|
||||||
|
</button>
|
||||||
|
{op.badge && <Badge>{op.badge}</Badge>}
|
||||||
|
{/* More complex rendering... */}
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ✅ After: Extracted item component
|
||||||
|
interface OperationItemProps {
|
||||||
|
operation: Operation
|
||||||
|
onAction: (id: string) => void
|
||||||
|
}
|
||||||
|
|
||||||
|
const OperationItem: FC<OperationItemProps> = ({ operation, onAction }) => {
|
||||||
|
return (
|
||||||
|
<div className="operation-item">
|
||||||
|
<span className="icon">{operation.icon}</span>
|
||||||
|
<span className="title">{operation.title}</span>
|
||||||
|
<span className="description">{operation.description}</span>
|
||||||
|
<button onClick={() => onAction(operation.id)}>
|
||||||
|
{operation.actionLabel}
|
||||||
|
</button>
|
||||||
|
{operation.badge && <Badge>{operation.badge}</Badge>}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
const OperationsList = () => {
|
||||||
|
const handleAction = useCallback((id: string) => {
|
||||||
|
const op = operations.find(o => o.id === id)
|
||||||
|
op?.onClick()
|
||||||
|
}, [operations])
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
{operations.map(op => (
|
||||||
|
<OperationItem
|
||||||
|
key={op.id}
|
||||||
|
operation={op}
|
||||||
|
onAction={handleAction}
|
||||||
|
/>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Directory Structure Patterns
|
||||||
|
|
||||||
|
### Pattern A: Flat Structure (Simple Components)
|
||||||
|
|
||||||
|
For components with 2-3 sub-components:
|
||||||
|
|
||||||
|
```
|
||||||
|
component-name/
|
||||||
|
├── index.tsx # Main component
|
||||||
|
├── sub-component-a.tsx
|
||||||
|
├── sub-component-b.tsx
|
||||||
|
└── types.ts # Shared types
|
||||||
|
```
|
||||||
|
|
||||||
|
### Pattern B: Nested Structure (Complex Components)
|
||||||
|
|
||||||
|
For components with many sub-components:
|
||||||
|
|
||||||
|
```
|
||||||
|
component-name/
|
||||||
|
├── index.tsx # Main orchestration
|
||||||
|
├── types.ts # Shared types
|
||||||
|
├── hooks/
|
||||||
|
│ ├── use-feature-a.ts
|
||||||
|
│ └── use-feature-b.ts
|
||||||
|
├── components/
|
||||||
|
│ ├── header/
|
||||||
|
│ │ └── index.tsx
|
||||||
|
│ ├── content/
|
||||||
|
│ │ └── index.tsx
|
||||||
|
│ └── modals/
|
||||||
|
│ └── index.tsx
|
||||||
|
└── utils/
|
||||||
|
└── helpers.ts
|
||||||
|
```
|
||||||
|
|
||||||
|
### Pattern C: Feature-Based Structure (Dify Standard)
|
||||||
|
|
||||||
|
Following Dify's existing patterns:
|
||||||
|
|
||||||
|
```
|
||||||
|
configuration/
|
||||||
|
├── index.tsx # Main page component
|
||||||
|
├── base/ # Base/shared components
|
||||||
|
│ ├── feature-panel/
|
||||||
|
│ ├── group-name/
|
||||||
|
│ └── operation-btn/
|
||||||
|
├── config/ # Config section
|
||||||
|
│ ├── index.tsx
|
||||||
|
│ ├── agent/
|
||||||
|
│ └── automatic/
|
||||||
|
├── dataset-config/ # Dataset section
|
||||||
|
│ ├── index.tsx
|
||||||
|
│ ├── card-item/
|
||||||
|
│ └── params-config/
|
||||||
|
├── debug/ # Debug section
|
||||||
|
│ ├── index.tsx
|
||||||
|
│ └── hooks.tsx
|
||||||
|
└── hooks/ # Shared hooks
|
||||||
|
└── use-advanced-prompt-config.ts
|
||||||
|
```
|
||||||
|
|
||||||
|
## Props Design
|
||||||
|
|
||||||
|
### Minimal Props Principle
|
||||||
|
|
||||||
|
Pass only what's needed:
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// ❌ Bad: Passing entire objects when only some fields needed
|
||||||
|
<ConfigHeader appDetail={appDetail} modelConfig={modelConfig} />
|
||||||
|
|
||||||
|
// ✅ Good: Destructure to minimum required
|
||||||
|
<ConfigHeader
|
||||||
|
appName={appDetail.name}
|
||||||
|
isAdvancedMode={modelConfig.isAdvanced}
|
||||||
|
onPublish={handlePublish}
|
||||||
|
/>
|
||||||
|
```
|
||||||
|
|
||||||
|
### Callback Props Pattern
|
||||||
|
|
||||||
|
Use callbacks for child-to-parent communication:
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// Parent
|
||||||
|
const Parent = () => {
|
||||||
|
const [value, setValue] = useState('')
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Child
|
||||||
|
value={value}
|
||||||
|
onChange={setValue}
|
||||||
|
onSubmit={handleSubmit}
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Child
|
||||||
|
interface ChildProps {
|
||||||
|
value: string
|
||||||
|
onChange: (value: string) => void
|
||||||
|
onSubmit: () => void
|
||||||
|
}
|
||||||
|
|
||||||
|
const Child: FC<ChildProps> = ({ value, onChange, onSubmit }) => {
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
<input value={value} onChange={e => onChange(e.target.value)} />
|
||||||
|
<button onClick={onSubmit}>Submit</button>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Render Props for Flexibility
|
||||||
|
|
||||||
|
When sub-components need parent context:
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
interface ListProps<T> {
|
||||||
|
items: T[]
|
||||||
|
renderItem: (item: T, index: number) => React.ReactNode
|
||||||
|
renderEmpty?: () => React.ReactNode
|
||||||
|
}
|
||||||
|
|
||||||
|
function List<T>({ items, renderItem, renderEmpty }: ListProps<T>) {
|
||||||
|
if (items.length === 0 && renderEmpty) {
|
||||||
|
return <>{renderEmpty()}</>
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
{items.map((item, index) => renderItem(item, index))}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Usage
|
||||||
|
<List
|
||||||
|
items={operations}
|
||||||
|
renderItem={(op, i) => <OperationItem key={i} operation={op} />}
|
||||||
|
renderEmpty={() => <EmptyState message="No operations" />}
|
||||||
|
/>
|
||||||
|
```
|
||||||
|
|
@ -0,0 +1,317 @@
|
||||||
|
# Hook Extraction Patterns
|
||||||
|
|
||||||
|
This document provides detailed guidance on extracting custom hooks from complex components in Dify.
|
||||||
|
|
||||||
|
## When to Extract Hooks
|
||||||
|
|
||||||
|
Extract a custom hook when you identify:
|
||||||
|
|
||||||
|
1. **Coupled state groups** - Multiple `useState` hooks that are always used together
|
||||||
|
1. **Complex effects** - `useEffect` with multiple dependencies or cleanup logic
|
||||||
|
1. **Business logic** - Data transformations, validations, or calculations
|
||||||
|
1. **Reusable patterns** - Logic that appears in multiple components
|
||||||
|
|
||||||
|
## Extraction Process
|
||||||
|
|
||||||
|
### Step 1: Identify State Groups
|
||||||
|
|
||||||
|
Look for state variables that are logically related:
|
||||||
|
|
||||||
|
```typescript
// ❌ These belong together - extract to hook
const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
const [completionParams, setCompletionParams] = useState<FormValue>({})
const [modelModeType, setModelModeType] = useState<ModelModeType>(...)

// These are model-related state that should be in useModelConfig()
```

### Step 2: Identify Related Effects

Find effects that modify the grouped state:

```typescript
// ❌ These effects belong with the state above
useEffect(() => {
  if (hasFetchedDetail && !modelModeType) {
    const mode = currModel?.model_properties.mode
    if (mode) {
      const newModelConfig = produce(modelConfig, (draft) => {
        draft.mode = mode
      })
      setModelConfig(newModelConfig)
    }
  }
}, [textGenerationModelList, hasFetchedDetail, modelModeType, currModel])
```

### Step 3: Create the Hook

```typescript
// hooks/use-model-config.ts
import type { FormValue } from '@/app/components/header/account-setting/model-provider-page/declarations'
import type { ModelConfig } from '@/models/debug'
import { produce } from 'immer'
import { useEffect, useState } from 'react'
import { ModelModeType } from '@/types/app'

interface UseModelConfigParams {
  initialConfig?: Partial<ModelConfig>
  currModel?: { model_properties?: { mode?: ModelModeType } }
  hasFetchedDetail: boolean
}

interface UseModelConfigReturn {
  modelConfig: ModelConfig
  setModelConfig: (config: ModelConfig) => void
  completionParams: FormValue
  setCompletionParams: (params: FormValue) => void
  modelModeType: ModelModeType
}

export const useModelConfig = ({
  initialConfig,
  currModel,
  hasFetchedDetail,
}: UseModelConfigParams): UseModelConfigReturn => {
  const [modelConfig, setModelConfig] = useState<ModelConfig>({
    provider: 'langgenius/openai/openai',
    model_id: 'gpt-3.5-turbo',
    mode: ModelModeType.unset,
    // ... default values
    ...initialConfig,
  })

  const [completionParams, setCompletionParams] = useState<FormValue>({})

  const modelModeType = modelConfig.mode

  // Fill old app data missing model mode
  useEffect(() => {
    if (hasFetchedDetail && !modelModeType) {
      const mode = currModel?.model_properties?.mode
      if (mode) {
        setModelConfig(produce(modelConfig, (draft) => {
          draft.mode = mode
        }))
      }
    }
  }, [hasFetchedDetail, modelModeType, currModel])

  return {
    modelConfig,
    setModelConfig,
    completionParams,
    setCompletionParams,
    modelModeType,
  }
}
```

### Step 4: Update Component

```typescript
// Before: 50+ lines of state management
const Configuration: FC = () => {
  const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
  // ... lots of related state and effects
}

// After: Clean component
const Configuration: FC = () => {
  const {
    modelConfig,
    setModelConfig,
    completionParams,
    setCompletionParams,
    modelModeType,
  } = useModelConfig({
    currModel,
    hasFetchedDetail,
  })

  // Component now focuses on UI
}
```

## Naming Conventions

### Hook Names

- Use `use` prefix: `useModelConfig`, `useDatasetConfig`
- Be specific: `useAdvancedPromptConfig` not `usePrompt`
- Include domain: `useWorkflowVariables`, `useMCPServer`

### File Names

- Kebab-case: `use-model-config.ts`
- Place in `hooks/` subdirectory when multiple hooks exist
- Place alongside component for single-use hooks

### Return Type Names

- Suffix with `Return`: `UseModelConfigReturn`
- Suffix params with `Params`: `UseModelConfigParams`
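
Applied together, these conventions give a predictable file layout. The sketch below is a hypothetical `use-dataset-config.ts` (the `DatasetConfig` shape and its defaults are illustrative placeholders, not an existing Dify type):

```typescript
// hooks/use-dataset-config.ts - hypothetical sketch; DatasetConfig is a placeholder type
import { useState } from 'react'

type DatasetConfig = { datasetIds: string[]; topK: number }

interface UseDatasetConfigParams {
  initialConfig?: Partial<DatasetConfig>
}

interface UseDatasetConfigReturn {
  datasetConfig: DatasetConfig
  setDatasetConfig: (config: DatasetConfig) => void
}

export const useDatasetConfig = ({ initialConfig }: UseDatasetConfigParams = {}): UseDatasetConfigReturn => {
  // Default values here are illustrative only
  const [datasetConfig, setDatasetConfig] = useState<DatasetConfig>({
    datasetIds: [],
    topK: 4,
    ...initialConfig,
  })

  return { datasetConfig, setDatasetConfig }
}
```
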
## Common Hook Patterns in Dify

### 1. Data Fetching Hook (React Query)

```typescript
// Pattern: Use @tanstack/react-query for data fetching
import { useQuery, useQueryClient } from '@tanstack/react-query'
import { get } from '@/service/base'
import { useInvalid } from '@/service/use-base'

const NAME_SPACE = 'appConfig'

// Query keys for cache management
export const appConfigQueryKeys = {
  detail: (appId: string) => [NAME_SPACE, 'detail', appId] as const,
}

// Main data hook
export const useAppConfig = (appId: string) => {
  return useQuery({
    enabled: !!appId,
    queryKey: appConfigQueryKeys.detail(appId),
    queryFn: () => get<AppDetailResponse>(`/apps/${appId}`),
    select: data => data?.model_config || null,
  })
}

// Invalidation hook for refreshing data
export const useInvalidAppConfig = () => {
  return useInvalid([NAME_SPACE])
}

// Usage in component
const Component = () => {
  const { data: config, isLoading, error, refetch } = useAppConfig(appId)
  const invalidAppConfig = useInvalidAppConfig()

  const handleRefresh = () => {
    invalidAppConfig() // Invalidates cache and triggers refetch
  }

  return <div>...</div>
}
```

### 2. Form State Hook

```typescript
// Pattern: Form state + validation + submission
export const useConfigForm = (initialValues: ConfigFormValues) => {
  const [values, setValues] = useState(initialValues)
  const [errors, setErrors] = useState<Record<string, string>>({})
  const [isSubmitting, setIsSubmitting] = useState(false)

  const validate = useCallback(() => {
    const newErrors: Record<string, string> = {}
    if (!values.name) newErrors.name = 'Name is required'
    setErrors(newErrors)
    return Object.keys(newErrors).length === 0
  }, [values])

  const handleChange = useCallback((field: string, value: any) => {
    setValues(prev => ({ ...prev, [field]: value }))
  }, [])

  const handleSubmit = useCallback(async (onSubmit: (values: ConfigFormValues) => Promise<void>) => {
    if (!validate()) return
    setIsSubmitting(true)
    try {
      await onSubmit(values)
    } finally {
      setIsSubmitting(false)
    }
  }, [values, validate])

  return { values, errors, isSubmitting, handleChange, handleSubmit }
}
```

### 3. Modal State Hook

```typescript
// Pattern: Multiple modal management
type ModalType = 'edit' | 'delete' | 'duplicate' | null

export const useModalState = () => {
  const [activeModal, setActiveModal] = useState<ModalType>(null)
  const [modalData, setModalData] = useState<any>(null)

  const openModal = useCallback((type: ModalType, data?: any) => {
    setActiveModal(type)
    setModalData(data)
  }, [])

  const closeModal = useCallback(() => {
    setActiveModal(null)
    setModalData(null)
  }, [])

  return {
    activeModal,
    modalData,
    openModal,
    closeModal,
    isOpen: useCallback((type: ModalType) => activeModal === type, [activeModal]),
  }
}
```
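
A minimal usage sketch for the modal hook above (`EditModal`, `DeleteConfirm`, and the `Item` type are placeholders for illustration):

```typescript
// Usage - EditModal / DeleteConfirm are hypothetical components
const ItemActions: FC<{ item: Item }> = ({ item }) => {
  const { modalData, openModal, closeModal, isOpen } = useModalState()

  return (
    <>
      <button onClick={() => openModal('edit', item)}>Edit</button>
      <button onClick={() => openModal('delete', item)}>Delete</button>
      {isOpen('edit') && <EditModal data={modalData} onClose={closeModal} />}
      {isOpen('delete') && <DeleteConfirm data={modalData} onClose={closeModal} />}
    </>
  )
}
```
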
### 4. Toggle/Boolean Hook

```typescript
// Pattern: Boolean state with convenience methods
export const useToggle = (initialValue = false) => {
  const [value, setValue] = useState(initialValue)

  const toggle = useCallback(() => setValue(v => !v), [])
  const setTrue = useCallback(() => setValue(true), [])
  const setFalse = useCallback(() => setValue(false), [])

  return [value, { toggle, setTrue, setFalse, set: setValue }] as const
}

// Usage
const [isExpanded, { toggle, setTrue: expand, setFalse: collapse }] = useToggle()
```

## Testing Extracted Hooks

After extraction, test hooks in isolation:

```typescript
// use-model-config.spec.ts
import { renderHook, act } from '@testing-library/react'
import { useModelConfig } from './use-model-config'
import { ModelModeType } from '@/types/app'

describe('useModelConfig', () => {
  it('should initialize with default values', () => {
    const { result } = renderHook(() => useModelConfig({
      hasFetchedDetail: false,
    }))

    expect(result.current.modelConfig.provider).toBe('langgenius/openai/openai')
    expect(result.current.modelModeType).toBe(ModelModeType.unset)
  })

  it('should update model config', () => {
    const { result } = renderHook(() => useModelConfig({
      hasFetchedDetail: true,
    }))

    act(() => {
      result.current.setModelConfig({
        ...result.current.modelConfig,
        model_id: 'gpt-4',
      })
    })

    expect(result.current.modelConfig.model_id).toBe('gpt-4')
  })
})
```
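
If an extracted hook reads from React context (for example a React Query client), pass a `wrapper` to `renderHook` so the hook renders inside the providers it expects. A sketch assuming the data-fetching hook from pattern 1 above:

```typescript
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import { renderHook, waitFor } from '@testing-library/react'
import type { ReactNode } from 'react'

// Fresh client per test so cached data does not leak between cases
const createWrapper = () => {
  const queryClient = new QueryClient({ defaultOptions: { queries: { retry: false } } })
  return ({ children }: { children: ReactNode }) => (
    <QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
  )
}

it('loads app config', async () => {
  const { result } = renderHook(() => useAppConfig('app-id'), { wrapper: createWrapper() })

  await waitFor(() => expect(result.current.isSuccess).toBe(true))
})
```
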
@ -318,5 +318,5 @@ For more detailed information, refer to:

 - `web/vitest.config.ts` - Vitest configuration
 - `web/vitest.setup.ts` - Test environment setup
-- `web/testing/analyze-component.js` - Component analysis tool
+- `web/scripts/analyze-component.js` - Component analysis tool
 - Modules are not mocked automatically. Global mocks live in `web/vitest.setup.ts` (for example `react-i18next`, `next/image`); mock other modules like `ky` or `mime` locally in test files.

@ -28,17 +28,14 @@ import userEvent from '@testing-library/user-event'

 // i18n (automatically mocked)
 // WHY: Global mock in web/vitest.setup.ts is auto-loaded by Vitest setup
-// No explicit mock needed - it returns translation keys as-is
+// The global mock provides: useTranslation, Trans, useMixedTranslation, useGetLanguage
+// No explicit mock needed for most tests
+//
 // Override only if custom translations are required:
-// vi.mock('react-i18next', () => ({
-//   useTranslation: () => ({
-//     t: (key: string) => {
-//       const customTranslations: Record<string, string> = {
-//         'my.custom.key': 'Custom Translation',
-//       }
-//       return customTranslations[key] || key
-//     },
-//   }),
+// import { createReactI18nextMock } from '@/test/i18n-mock'
+// vi.mock('react-i18next', () => createReactI18nextMock({
+//   'my.custom.key': 'Custom Translation',
+//   'button.save': 'Save',
 // }))

 // Router (if component uses useRouter, usePathname, useSearchParams)

@ -52,23 +52,29 @@ Modules are not mocked automatically. Use `vi.mock` in test files, or add global

 ### 1. i18n (Auto-loaded via Global Mock)

 A global mock is defined in `web/vitest.setup.ts` and is auto-loaded by Vitest setup.
-**No explicit mock needed** for most tests - it returns translation keys as-is.
-
-For tests requiring custom translations, override the mock:
+
+The global mock provides:
+
+- `useTranslation` - returns translation keys with namespace prefix
+- `Trans` component - renders i18nKey and components
+- `useMixedTranslation` (from `@/app/components/plugins/marketplace/hooks`)
+- `useGetLanguage` (from `@/context/i18n`) - returns `'en-US'`
+
+**Default behavior**: Most tests should use the global mock (no local override needed).
+
+**For custom translations**: Use the helper function from `@/test/i18n-mock`:

 ```typescript
-vi.mock('react-i18next', () => ({
-  useTranslation: () => ({
-    t: (key: string) => {
-      const translations: Record<string, string> = {
-        'my.custom.key': 'Custom translation',
-        'button.save': 'Save',
-      }
-      return translations[key] || key
-    },
-  }),
+import { createReactI18nextMock } from '@/test/i18n-mock'
+
+vi.mock('react-i18next', () => createReactI18nextMock({
+  'my.custom.key': 'Custom translation',
+  'button.save': 'Save',
 }))
 ```

+**Avoid**: Manually defining `useTranslation` mocks that just return the key - the global mock already does this.
+
 ### 2. Next.js Router

 ```typescript
@ -22,12 +22,12 @@ jobs:

     steps:
       - name: Checkout code
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
         with:
           persist-credentials: false

       - name: Setup UV and Python
-        uses: astral-sh/setup-uv@v6
+        uses: astral-sh/setup-uv@v7
         with:
           enable-cache: true
           python-version: ${{ matrix.python-version }}

@ -57,7 +57,7 @@ jobs:
         run: sh .github/workflows/expose_service_ports.sh

       - name: Set up Sandbox
-        uses: hoverkraft-tech/compose-action@v2.0.2
+        uses: hoverkraft-tech/compose-action@v2
         with:
           compose-file: |
             docker/docker-compose.middleware.yaml

@ -12,7 +12,7 @@ jobs:
     if: github.repository == 'langgenius/dify'
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@v6

       - name: Check Docker Compose inputs
         id: docker-compose-changes

@ -27,7 +27,7 @@ jobs:
         with:
           python-version: "3.11"

-      - uses: astral-sh/setup-uv@v6
+      - uses: astral-sh/setup-uv@v7

       - name: Generate Docker Compose
         if: steps.docker-compose-changes.outputs.any_changed == 'true'

@ -90,7 +90,7 @@ jobs:
           touch "/tmp/digests/${sanitized_digest}"

       - name: Upload digest
-        uses: actions/upload-artifact@v4
+        uses: actions/upload-artifact@v6
         with:
           name: digests-${{ matrix.context }}-${{ env.PLATFORM_PAIR }}
           path: /tmp/digests/*

@ -13,13 +13,13 @@ jobs:

     steps:
       - name: Checkout code
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
         with:
           fetch-depth: 0
           persist-credentials: false

       - name: Setup UV and Python
-        uses: astral-sh/setup-uv@v6
+        uses: astral-sh/setup-uv@v7
         with:
           enable-cache: true
           python-version: "3.12"

@ -63,13 +63,13 @@ jobs:

     steps:
       - name: Checkout code
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
         with:
           fetch-depth: 0
           persist-credentials: false

       - name: Setup UV and Python
-        uses: astral-sh/setup-uv@v6
+        uses: astral-sh/setup-uv@v7
         with:
           enable-cache: true
           python-version: "3.12"

@ -27,7 +27,7 @@ jobs:
       vdb-changed: ${{ steps.changes.outputs.vdb }}
       migration-changed: ${{ steps.changes.outputs.migration }}
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@v6
       - uses: dorny/paths-filter@v3
         id: changes
         with:

@ -38,6 +38,7 @@ jobs:
             - '.github/workflows/api-tests.yml'
           web:
             - 'web/**'
+            - '.github/workflows/web-tests.yml'
           vdb:
             - 'api/core/rag/datasource/**'
             - 'docker/**'

@ -19,13 +19,13 @@ jobs:

     steps:
       - name: Checkout code
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
         with:
           persist-credentials: false

       - name: Check changed files
         id: changed-files
-        uses: tj-actions/changed-files@v46
+        uses: tj-actions/changed-files@v47
         with:
           files: |
             api/**

@ -33,7 +33,7 @@ jobs:

       - name: Setup UV and Python
         if: steps.changed-files.outputs.any_changed == 'true'
-        uses: astral-sh/setup-uv@v6
+        uses: astral-sh/setup-uv@v7
         with:
           enable-cache: false
           python-version: "3.12"

@ -68,15 +68,17 @@ jobs:

     steps:
       - name: Checkout code
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
         with:
           persist-credentials: false

       - name: Check changed files
         id: changed-files
-        uses: tj-actions/changed-files@v46
+        uses: tj-actions/changed-files@v47
         with:
-          files: web/**
+          files: |
+            web/**
+            .github/workflows/style.yml

       - name: Install pnpm
         uses: pnpm/action-setup@v4

@ -85,7 +87,7 @@ jobs:
           run_install: false

       - name: Setup NodeJS
-        uses: actions/setup-node@v4
+        uses: actions/setup-node@v6
         if: steps.changed-files.outputs.any_changed == 'true'
         with:
           node-version: 22

@ -108,20 +110,30 @@ jobs:
         working-directory: ./web
         run: pnpm run type-check:tsgo

+      - name: Web dead code check
+        if: steps.changed-files.outputs.any_changed == 'true'
+        working-directory: ./web
+        run: pnpm run knip
+
+      - name: Web build check
+        if: steps.changed-files.outputs.any_changed == 'true'
+        working-directory: ./web
+        run: pnpm run build
+
   superlinter:
     name: SuperLinter
     runs-on: ubuntu-latest

     steps:
       - name: Checkout code
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
         with:
           fetch-depth: 0
           persist-credentials: false

       - name: Check changed files
         id: changed-files
-        uses: tj-actions/changed-files@v46
+        uses: tj-actions/changed-files@v47
         with:
           files: |
             **.sh

@ -25,12 +25,12 @@ jobs:
         working-directory: sdks/nodejs-client

     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@v6
         with:
           persist-credentials: false

       - name: Use Node.js ${{ matrix.node-version }}
-        uses: actions/setup-node@v4
+        uses: actions/setup-node@v6
         with:
           node-version: ${{ matrix.node-version }}
           cache: ''

@ -4,7 +4,7 @@ on:
   push:
     branches: [main]
     paths:
-      - 'web/i18n/en-US/*.ts'
+      - 'web/i18n/en-US/*.json'

 permissions:
   contents: write

@ -18,7 +18,7 @@ jobs:
       run:
         working-directory: web
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@v6
        with:
          fetch-depth: 0
          token: ${{ secrets.GITHUB_TOKEN }}

@ -28,13 +28,13 @@ jobs:
         run: |
           git fetch origin "${{ github.event.before }}" || true
           git fetch origin "${{ github.sha }}" || true
-          changed_files=$(git diff --name-only "${{ github.event.before }}" "${{ github.sha }}" -- 'i18n/en-US/*.ts')
+          changed_files=$(git diff --name-only "${{ github.event.before }}" "${{ github.sha }}" -- 'i18n/en-US/*.json')
           echo "Changed files: $changed_files"
           if [ -n "$changed_files" ]; then
             echo "FILES_CHANGED=true" >> $GITHUB_ENV
             file_args=""
             for file in $changed_files; do
-              filename=$(basename "$file" .ts)
+              filename=$(basename "$file" .json)
               file_args="$file_args --file $filename"
             done
             echo "FILE_ARGS=$file_args" >> $GITHUB_ENV

@ -51,7 +51,7 @@ jobs:

       - name: Set up Node.js
         if: env.FILES_CHANGED == 'true'
-        uses: actions/setup-node@v4
+        uses: actions/setup-node@v6
         with:
           node-version: 'lts/*'
           cache: pnpm

@ -65,7 +65,7 @@ jobs:
       - name: Generate i18n translations
         if: env.FILES_CHANGED == 'true'
         working-directory: ./web
-        run: pnpm run auto-gen-i18n ${{ env.FILE_ARGS }}
+        run: pnpm run i18n:gen ${{ env.FILE_ARGS }}

       - name: Create Pull Request
         if: env.FILES_CHANGED == 'true'

@ -19,19 +19,19 @@ jobs:

     steps:
       - name: Checkout code
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
         with:
           persist-credentials: false

       - name: Free Disk Space
-        uses: endersonmenezes/free-disk-space@v2
+        uses: endersonmenezes/free-disk-space@v3
         with:
           remove_dotnet: true
           remove_haskell: true
           remove_tool_cache: true

       - name: Setup UV and Python
-        uses: astral-sh/setup-uv@v6
+        uses: astral-sh/setup-uv@v7
         with:
           enable-cache: true
           python-version: ${{ matrix.python-version }}

@ -18,7 +18,7 @@ jobs:

     steps:
       - name: Checkout code
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
         with:
           persist-credentials: false

@ -29,7 +29,7 @@ jobs:
         run_install: false

       - name: Setup Node.js
-        uses: actions/setup-node@v4
+        uses: actions/setup-node@v6
         with:
           node-version: 22
           cache: pnpm

@ -360,7 +360,7 @@ jobs:

       - name: Upload Coverage Artifact
         if: steps.coverage-summary.outputs.has_coverage == 'true'
-        uses: actions/upload-artifact@v4
+        uses: actions/upload-artifact@v6
         with:
           name: web-coverage-report
           path: web/coverage
34 .mcp.json
@ -1,34 +0,0 @@
{
  "mcpServers": {
    "context7": {
      "type": "http",
      "url": "https://mcp.context7.com/mcp"
    },
    "sequential-thinking": {
      "type": "stdio",
      "command": "npx",
      "args": ["-y", "@modelcontextprotocol/server-sequential-thinking"],
      "env": {}
    },
    "github": {
      "type": "stdio",
      "command": "npx",
      "args": ["-y", "@modelcontextprotocol/server-github"],
      "env": {
        "GITHUB_PERSONAL_ACCESS_TOKEN": "${GITHUB_PERSONAL_ACCESS_TOKEN}"
      }
    },
    "fetch": {
      "type": "stdio",
      "command": "uvx",
      "args": ["mcp-server-fetch"],
      "env": {}
    },
    "playwright": {
      "type": "stdio",
      "command": "npx",
      "args": ["-y", "@playwright/mcp@latest"],
      "env": {}
    }
  }
}
@ -101,6 +101,15 @@ S3_ACCESS_KEY=your-access-key
 S3_SECRET_KEY=your-secret-key
 S3_REGION=your-region

+# Workflow run and Conversation archive storage (S3-compatible)
+ARCHIVE_STORAGE_ENABLED=false
+ARCHIVE_STORAGE_ENDPOINT=
+ARCHIVE_STORAGE_ARCHIVE_BUCKET=
+ARCHIVE_STORAGE_EXPORT_BUCKET=
+ARCHIVE_STORAGE_ACCESS_KEY=
+ARCHIVE_STORAGE_SECRET_KEY=
+ARCHIVE_STORAGE_REGION=auto
+
 # Azure Blob Storage configuration
 AZURE_BLOB_ACCOUNT_NAME=your-account-name
 AZURE_BLOB_ACCOUNT_KEY=your-account-key

@ -128,6 +137,7 @@ TENCENT_COS_SECRET_KEY=your-secret-key
 TENCENT_COS_SECRET_ID=your-secret-id
 TENCENT_COS_REGION=your-region
 TENCENT_COS_SCHEME=your-scheme
+TENCENT_COS_CUSTOM_DOMAIN=your-custom-domain

 # Huawei OBS Storage Configuration
 HUAWEI_OBS_BUCKET_NAME=your-bucket-name

@ -1,4 +1,8 @@
-exclude = ["migrations/*"]
+exclude = [
+    "migrations/*",
+    ".git",
+    ".git/**",
+]
 line-length = 120

 [format]
@ -1,9 +1,11 @@
+from configs.extra.archive_config import ArchiveStorageConfig
 from configs.extra.notion_config import NotionConfig
 from configs.extra.sentry_config import SentryConfig


 class ExtraServiceConfig(
     # place the configs in alphabet order
+    ArchiveStorageConfig,
     NotionConfig,
     SentryConfig,
 ):
@ -0,0 +1,43 @@
from pydantic import Field
from pydantic_settings import BaseSettings


class ArchiveStorageConfig(BaseSettings):
    """
    Configuration settings for workflow run logs archiving storage.
    """

    ARCHIVE_STORAGE_ENABLED: bool = Field(
        description="Enable workflow run logs archiving to S3-compatible storage",
        default=False,
    )

    ARCHIVE_STORAGE_ENDPOINT: str | None = Field(
        description="URL of the S3-compatible storage endpoint (e.g., 'https://storage.example.com')",
        default=None,
    )

    ARCHIVE_STORAGE_ARCHIVE_BUCKET: str | None = Field(
        description="Name of the bucket to store archived workflow logs",
        default=None,
    )

    ARCHIVE_STORAGE_EXPORT_BUCKET: str | None = Field(
        description="Name of the bucket to store exported workflow runs",
        default=None,
    )

    ARCHIVE_STORAGE_ACCESS_KEY: str | None = Field(
        description="Access key ID for authenticating with storage",
        default=None,
    )

    ARCHIVE_STORAGE_SECRET_KEY: str | None = Field(
        description="Secret access key for authenticating with storage",
        default=None,
    )

    ARCHIVE_STORAGE_REGION: str = Field(
        description="Region for storage (use 'auto' if the provider supports it)",
        default="auto",
    )
@ -31,3 +31,8 @@ class TencentCloudCOSStorageConfig(BaseSettings):
         description="Protocol scheme for COS requests: 'https' (recommended) or 'http'",
         default=None,
     )
+
+    TENCENT_COS_CUSTOM_DOMAIN: str | None = Field(
+        description="Tencent Cloud COS custom domain setting",
+        default=None,
+    )
@ -1,62 +1,59 @@
-from flask_restx import Api, Namespace, fields
-
-from libs.helper import AppIconUrlField
+from __future__ import annotations
+
+from typing import Any, TypeAlias
+
+from pydantic import BaseModel, ConfigDict, computed_field
+
+from core.file import helpers as file_helpers
+from models.model import IconType
+
+JSONValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any]
+JSONObject: TypeAlias = dict[str, Any]

-parameters__system_parameters = {
-    "image_file_size_limit": fields.Integer,
-    "video_file_size_limit": fields.Integer,
-    "audio_file_size_limit": fields.Integer,
-    "file_size_limit": fields.Integer,
-    "workflow_file_upload_limit": fields.Integer,
-}
-
-
-def build_system_parameters_model(api_or_ns: Api | Namespace):
-    """Build the system parameters model for the API or Namespace."""
-    return api_or_ns.model("SystemParameters", parameters__system_parameters)
+class SystemParameters(BaseModel):
+    image_file_size_limit: int
+    video_file_size_limit: int
+    audio_file_size_limit: int
+    file_size_limit: int
+    workflow_file_upload_limit: int


-parameters_fields = {
-    "opening_statement": fields.String,
-    "suggested_questions": fields.Raw,
-    "suggested_questions_after_answer": fields.Raw,
-    "speech_to_text": fields.Raw,
-    "text_to_speech": fields.Raw,
-    "retriever_resource": fields.Raw,
-    "annotation_reply": fields.Raw,
-    "more_like_this": fields.Raw,
-    "user_input_form": fields.Raw,
-    "sensitive_word_avoidance": fields.Raw,
-    "file_upload": fields.Raw,
-    "system_parameters": fields.Nested(parameters__system_parameters),
-}
-
-
-def build_parameters_model(api_or_ns: Api | Namespace):
-    """Build the parameters model for the API or Namespace."""
-    copied_fields = parameters_fields.copy()
-    copied_fields["system_parameters"] = fields.Nested(build_system_parameters_model(api_or_ns))
-    return api_or_ns.model("Parameters", copied_fields)
+class Parameters(BaseModel):
+    opening_statement: str | None = None
+    suggested_questions: list[str]
+    suggested_questions_after_answer: JSONObject
+    speech_to_text: JSONObject
+    text_to_speech: JSONObject
+    retriever_resource: JSONObject
+    annotation_reply: JSONObject
+    more_like_this: JSONObject
+    user_input_form: list[JSONObject]
+    sensitive_word_avoidance: JSONObject
+    file_upload: JSONObject
+    system_parameters: SystemParameters


-site_fields = {
-    "title": fields.String,
-    "chat_color_theme": fields.String,
-    "chat_color_theme_inverted": fields.Boolean,
-    "icon_type": fields.String,
-    "icon": fields.String,
-    "icon_background": fields.String,
-    "icon_url": AppIconUrlField,
-    "description": fields.String,
-    "copyright": fields.String,
-    "privacy_policy": fields.String,
-    "custom_disclaimer": fields.String,
-    "default_language": fields.String,
-    "show_workflow_steps": fields.Boolean,
-    "use_icon_as_answer_icon": fields.Boolean,
-}
-
-
-def build_site_model(api_or_ns: Api | Namespace):
-    """Build the site model for the API or Namespace."""
-    return api_or_ns.model("Site", site_fields)
+class Site(BaseModel):
+    model_config = ConfigDict(from_attributes=True)
+
+    title: str
+    chat_color_theme: str | None = None
+    chat_color_theme_inverted: bool
+    icon_type: str | None = None
+    icon: str | None = None
+    icon_background: str | None = None
+    description: str | None = None
+    copyright: str | None = None
+    privacy_policy: str | None = None
+    custom_disclaimer: str | None = None
+    default_language: str
+    show_workflow_steps: bool
+    use_icon_as_answer_icon: bool
+
+    @computed_field(return_type=str | None)  # type: ignore
+    @property
+    def icon_url(self) -> str | None:
+        if self.icon and self.icon_type == IconType.IMAGE:
+            return file_helpers.get_signed_file_url(self.icon)
+        return None
@ -1,3 +1,4 @@
+import re
 import uuid
 from typing import Literal

@ -73,6 +74,48 @@ class AppListQuery(BaseModel):
             raise ValueError("Invalid UUID format in tag_ids.") from exc


+# XSS prevention: patterns that could lead to XSS attacks
+# Includes: script tags, iframe tags, javascript: protocol, SVG with onload, etc.
+_XSS_PATTERNS = [
+    r"<script[^>]*>.*?</script>",  # Script tags
+    r"<iframe\b[^>]*?(?:/>|>.*?</iframe>)",  # Iframe tags (including self-closing)
+    r"javascript:",  # JavaScript protocol
+    r"<svg[^>]*?\s+onload\s*=[^>]*>",  # SVG with onload handler (attribute-aware, flexible whitespace)
+    r"<.*?on\s*\w+\s*=",  # Event handlers like onclick, onerror, etc.
+    r"<object\b[^>]*(?:\s*/>|>.*?</object\s*>)",  # Object tags (opening tag)
+    r"<embed[^>]*>",  # Embed tags (self-closing)
+    r"<link[^>]*>",  # Link tags with javascript
+]
+
+
+def _validate_xss_safe(value: str | None, field_name: str = "Field") -> str | None:
+    """
+    Validate that a string value doesn't contain potential XSS payloads.
+
+    Args:
+        value: The string value to validate
+        field_name: Name of the field for error messages
+
+    Returns:
+        The original value if safe
+
+    Raises:
+        ValueError: If the value contains XSS patterns
+    """
+    if value is None:
+        return None
+
+    value_lower = value.lower()
+    for pattern in _XSS_PATTERNS:
+        if re.search(pattern, value_lower, re.DOTALL | re.IGNORECASE):
+            raise ValueError(
+                f"{field_name} contains invalid characters or patterns. "
+                "HTML tags, JavaScript, and other potentially dangerous content are not allowed."
+            )
+
+    return value
+
+
 class CreateAppPayload(BaseModel):
     name: str = Field(..., min_length=1, description="App name")
     description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)

@ -81,6 +124,11 @@ class CreateAppPayload(BaseModel):
     icon: str | None = Field(default=None, description="Icon")
     icon_background: str | None = Field(default=None, description="Icon background color")

+    @field_validator("name", "description", mode="before")
+    @classmethod
+    def validate_xss_safe(cls, value: str | None, info) -> str | None:
+        return _validate_xss_safe(value, info.field_name)
+

 class UpdateAppPayload(BaseModel):
     name: str = Field(..., min_length=1, description="App name")

@ -91,6 +139,11 @@ class UpdateAppPayload(BaseModel):
     use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon")
     max_active_requests: int | None = Field(default=None, description="Maximum active requests")

+    @field_validator("name", "description", mode="before")
+    @classmethod
+    def validate_xss_safe(cls, value: str | None, info) -> str | None:
+        return _validate_xss_safe(value, info.field_name)
+

 class CopyAppPayload(BaseModel):
     name: str | None = Field(default=None, description="Name for the copied app")

@ -99,6 +152,11 @@ class CopyAppPayload(BaseModel):
     icon: str | None = Field(default=None, description="Icon")
     icon_background: str | None = Field(default=None, description="Icon background color")

+    @field_validator("name", "description", mode="before")
+    @classmethod
+    def validate_xss_safe(cls, value: str | None, info) -> str | None:
+        return _validate_xss_safe(value, info.field_name)
+

 class AppExportQuery(BaseModel):
     include_secret: bool = Field(default=False, description="Include secrets in export")
@ -124,7 +124,7 @@ class OAuthCallback(Resource):
             return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}")

         try:
-            account = _generate_account(provider, user_info)
+            account, oauth_new_user = _generate_account(provider, user_info)
         except AccountNotFoundError:
             return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account not found.")
         except (WorkSpaceNotFoundError, WorkSpaceNotAllowedCreateError):

@ -159,7 +159,10 @@ class OAuthCallback(Resource):
             ip_address=extract_remote_ip(request),
         )

-        response = redirect(f"{dify_config.CONSOLE_WEB_URL}")
+        base_url = dify_config.CONSOLE_WEB_URL
+        query_char = "&" if "?" in base_url else "?"
+        target_url = f"{base_url}{query_char}oauth_new_user={str(oauth_new_user).lower()}"
+        response = redirect(target_url)

         set_access_token_to_cookie(request, response, token_pair.access_token)
         set_refresh_token_to_cookie(request, response, token_pair.refresh_token)

@ -177,9 +180,10 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
     return account


-def _generate_account(provider: str, user_info: OAuthUserInfo):
+def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account, bool]:
     # Get account by openid or email.
     account = _get_account_by_openid_or_email(provider, user_info)
+    oauth_new_user = False

     if account:
         tenants = TenantService.get_join_tenants(account)

@ -193,6 +197,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
         tenant_was_created.send(new_tenant)

     if not account:
+        oauth_new_user = True
         if not FeatureService.get_system_features().is_allow_register:
             if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(user_info.email):
                 raise AccountRegisterError(

@ -220,4 +225,4 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
     # Link account
     AccountService.link_account_integrate(provider, user_info.id, account)

-    return account
+    return account, oauth_new_user
@ -1,8 +1,9 @@
 import base64
+from typing import Literal

 from flask import request
 from flask_restx import Resource, fields
-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, Field
 from werkzeug.exceptions import BadRequest

 from controllers.console import console_ns

@ -15,22 +16,8 @@ DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"


 class SubscriptionQuery(BaseModel):
-    plan: str = Field(..., description="Subscription plan")
-    interval: str = Field(..., description="Billing interval")
-
-    @field_validator("plan")
-    @classmethod
-    def validate_plan(cls, value: str) -> str:
-        if value not in [CloudPlan.PROFESSIONAL, CloudPlan.TEAM]:
-            raise ValueError("Invalid plan")
-        return value
-
-    @field_validator("interval")
-    @classmethod
-    def validate_interval(cls, value: str) -> str:
-        if value not in {"month", "year"}:
-            raise ValueError("Invalid interval")
-        return value
+    plan: Literal[CloudPlan.PROFESSIONAL, CloudPlan.TEAM] = Field(..., description="Subscription plan")
+    interval: Literal["month", "year"] = Field(..., description="Billing interval")


 class PartnerTenantsPayload(BaseModel):
@ -3,10 +3,12 @@ import uuid
 from flask import request
 from flask_restx import Resource, marshal
 from pydantic import BaseModel, Field
-from sqlalchemy import select
+from sqlalchemy import String, cast, func, or_, select
+from sqlalchemy.dialects.postgresql import JSONB
 from werkzeug.exceptions import Forbidden, NotFound

 import services
+from configs import dify_config
 from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from controllers.console.app.error import ProviderNotInitializeError

@ -143,7 +145,29 @@ class DatasetDocumentSegmentListApi(Resource):
             query = query.where(DocumentSegment.hit_count >= hit_count_gte)

         if keyword:
-            query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
+            # Search in both content and keywords fields
+            # Use database-specific methods for JSON array search
+            if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql":
+                # PostgreSQL: Use jsonb_array_elements_text to properly handle Unicode/Chinese text
+                keywords_condition = func.array_to_string(
+                    func.array(
+                        select(func.jsonb_array_elements_text(cast(DocumentSegment.keywords, JSONB)))
+                        .correlate(DocumentSegment)
+                        .scalar_subquery()
+                    ),
+                    ",",
+                ).ilike(f"%{keyword}%")
+            else:
+                # MySQL: Cast JSON to string for pattern matching
+                # MySQL stores Chinese text directly in JSON without Unicode escaping
+                keywords_condition = cast(DocumentSegment.keywords, String).ilike(f"%{keyword}%")
+
+            query = query.where(
+                or_(
+                    DocumentSegment.content.ilike(f"%{keyword}%"),
+                    keywords_condition,
+                )
+            )

         if args.enabled.lower() != "all":
             if args.enabled.lower() == "true":
|
||||||
import logging
|
import logging
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import marshal_with
|
from flask_restx import marshal_with
|
||||||
|
|
@ -26,6 +25,7 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
|
||||||
from core.model_runtime.errors.invoke import InvokeError
|
from core.model_runtime.errors.invoke import InvokeError
|
||||||
from fields.message_fields import message_infinite_scroll_pagination_fields
|
from fields.message_fields import message_infinite_scroll_pagination_fields
|
||||||
from libs import helper
|
from libs import helper
|
||||||
|
from libs.helper import UUIDStrOrEmpty
|
||||||
from libs.login import current_account_with_tenant
|
from libs.login import current_account_with_tenant
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
from services.app_generate_service import AppGenerateService
|
from services.app_generate_service import AppGenerateService
|
||||||
|
|
@ -44,8 +44,8 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class MessageListQuery(BaseModel):
|
class MessageListQuery(BaseModel):
|
||||||
conversation_id: UUID
|
conversation_id: UUIDStrOrEmpty
|
||||||
first_id: UUID | None = None
|
first_id: UUIDStrOrEmpty | None = None
|
||||||
limit: int = Field(default=20, ge=1, le=100)
|
limit: int = Field(default=20, ge=1, le=100)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
-from flask_restx import marshal_with
-
 from controllers.common import fields
 from controllers.console import console_ns
 from controllers.console.app.error import AppUnavailableError

@ -13,7 +11,6 @@ from services.app_service import AppService
 class AppParameterApi(InstalledAppResource):
     """Resource for app variables."""

-    @marshal_with(fields.parameters_fields)
     def get(self, installed_app: InstalledApp):
         """Retrieve app parameters."""
         app_model = installed_app.app

@ -37,7 +34,8 @@ class AppParameterApi(InstalledAppResource):

         user_input_form = features_dict.get("user_input_form", [])

-        return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
+        parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
+        return fields.Parameters.model_validate(parameters).model_dump(mode="json")


 @console_ns.route("/installed-apps/<uuid:installed_app_id>/meta", endpoint="installed_app_meta")
@ -1,5 +1,3 @@
-from uuid import UUID
-
 from flask import request
 from flask_restx import fields, marshal_with
 from pydantic import BaseModel, Field

@ -10,19 +8,19 @@ from controllers.console import console_ns
 from controllers.console.explore.error import NotCompletionAppError
 from controllers.console.explore.wraps import InstalledAppResource
 from fields.conversation_fields import message_file_fields
-from libs.helper import TimestampField
+from libs.helper import TimestampField, UUIDStrOrEmpty
 from libs.login import current_account_with_tenant
 from services.errors.message import MessageNotExistsError
 from services.saved_message_service import SavedMessageService


 class SavedMessageListQuery(BaseModel):
-    last_id: UUID | None = None
+    last_id: UUIDStrOrEmpty | None = None
     limit: int = Field(default=20, ge=1, le=100)


 class SavedMessageCreatePayload(BaseModel):
-    message_id: UUID
+    message_id: UUIDStrOrEmpty


 register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload)
@ -1,6 +1,8 @@
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource
|
||||||
|
from pydantic import BaseModel
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
|
from controllers.common.schema import register_schema_models
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
 from controllers.console.wraps import account_initialization_required, setup_required
 from core.model_runtime.entities.model_entities import ModelType
@@ -10,10 +12,20 @@ from models import TenantAccountRole
 from services.model_load_balancing_service import ModelLoadBalancingService
 
 
+class LoadBalancingCredentialPayload(BaseModel):
+    model: str
+    model_type: ModelType
+    credentials: dict[str, object]
+
+
+register_schema_models(console_ns, LoadBalancingCredentialPayload)
+
+
 @console_ns.route(
     "/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/credentials-validate"
 )
 class LoadBalancingCredentialsValidateApi(Resource):
+    @console_ns.expect(console_ns.models[LoadBalancingCredentialPayload.__name__])
     @setup_required
     @login_required
     @account_initialization_required
@@ -24,20 +36,7 @@ class LoadBalancingCredentialsValidateApi(Resource):
 
         tenant_id = current_tenant_id
 
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("model", type=str, required=True, nullable=False, location="json")
-            .add_argument(
-                "model_type",
-                type=str,
-                required=True,
-                nullable=False,
-                choices=[mt.value for mt in ModelType],
-                location="json",
-            )
-            .add_argument("credentials", type=dict, required=True, nullable=False, location="json")
-        )
-        args = parser.parse_args()
+        payload = LoadBalancingCredentialPayload.model_validate(console_ns.payload or {})
 
         # validate model load balancing credentials
         model_load_balancing_service = ModelLoadBalancingService()
@@ -49,9 +48,9 @@ class LoadBalancingCredentialsValidateApi(Resource):
             model_load_balancing_service.validate_load_balancing_credentials(
                 tenant_id=tenant_id,
                 provider=provider,
-                model=args["model"],
-                model_type=args["model_type"],
-                credentials=args["credentials"],
+                model=payload.model,
+                model_type=payload.model_type,
+                credentials=payload.credentials,
             )
         except CredentialsValidateFailedError as ex:
             result = False
@@ -69,6 +68,7 @@ class LoadBalancingCredentialsValidateApi(Resource):
     "/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate"
 )
 class LoadBalancingConfigCredentialsValidateApi(Resource):
+    @console_ns.expect(console_ns.models[LoadBalancingCredentialPayload.__name__])
     @setup_required
     @login_required
     @account_initialization_required
@@ -79,20 +79,7 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
 
         tenant_id = current_tenant_id
 
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("model", type=str, required=True, nullable=False, location="json")
-            .add_argument(
-                "model_type",
-                type=str,
-                required=True,
-                nullable=False,
-                choices=[mt.value for mt in ModelType],
-                location="json",
-            )
-            .add_argument("credentials", type=dict, required=True, nullable=False, location="json")
-        )
-        args = parser.parse_args()
+        payload = LoadBalancingCredentialPayload.model_validate(console_ns.payload or {})
 
         # validate model load balancing config credentials
         model_load_balancing_service = ModelLoadBalancingService()
@@ -104,9 +91,9 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
             model_load_balancing_service.validate_load_balancing_credentials(
                 tenant_id=tenant_id,
                 provider=provider,
-                model=args["model"],
-                model_type=args["model_type"],
-                credentials=args["credentials"],
+                model=payload.model,
+                model_type=payload.model_type,
+                credentials=payload.credentials,
                 config_id=config_id,
             )
         except CredentialsValidateFailedError as ex:
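Note: the hunks above replace Flask-RESTX `reqparse` parsing with a Pydantic payload model. A minimal sketch of the same pattern follows; the enum values and request body are illustrative, not Dify's actual schema.

```python
from enum import StrEnum

from pydantic import BaseModel, ValidationError


class ModelType(StrEnum):
    LLM = "llm"
    TEXT_EMBEDDING = "text-embedding"


class CredentialPayload(BaseModel):
    model: str
    model_type: ModelType
    credentials: dict[str, object]


# Instead of parser.parse_args(), validate the raw JSON body in a single call.
raw_body = {"model": "gpt-4o", "model_type": "llm", "credentials": {"api_key": "..."}}
try:
    payload = CredentialPayload.model_validate(raw_body)
except ValidationError as e:
    print(e.errors())  # field-level errors, analogous to reqparse's 400 response
else:
    print(payload.model_type)  # ModelType.LLM, already coerced to the enum
```

One payload class also doubles as the OpenAPI schema registered via `register_schema_models`, which is what the added `@console_ns.expect(...)` decorators reference.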
@@ -1,4 +1,5 @@
 import io
+import logging
 from urllib.parse import urlparse
 
 from flask import make_response, redirect, request, send_file
@@ -17,8 +18,8 @@ from controllers.console.wraps import (
     is_admin_or_owner_required,
     setup_required,
 )
+from core.db.session_factory import session_factory
 from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
-from core.helper.tool_provider_cache import ToolProviderListCache
 from core.mcp.auth.auth_flow import auth, handle_callback
 from core.mcp.error import MCPAuthError, MCPError, MCPRefreshTokenError
 from core.mcp.mcp_client import MCPClient
@@ -40,6 +41,8 @@ from services.tools.tools_manage_service import ToolCommonService
 from services.tools.tools_transform_service import ToolTransformService
 from services.tools.workflow_tools_manage_service import WorkflowToolManageService
 
+logger = logging.getLogger(__name__)
+
 
 def is_valid_url(url: str) -> bool:
     if not url:
@@ -945,8 +948,8 @@ class ToolProviderMCPApi(Resource):
         configuration = MCPConfiguration.model_validate(args["configuration"])
         authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
 
-        # Create provider in transaction
-        with Session(db.engine) as session, session.begin():
+        # 1) Create provider in a short transaction (no network I/O inside)
+        with session_factory.create_session() as session, session.begin():
             service = MCPToolManageService(session=session)
             result = service.create_provider(
                 tenant_id=tenant_id,
@@ -962,8 +965,26 @@ class ToolProviderMCPApi(Resource):
                 authentication=authentication,
             )
 
-        # Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
-        ToolProviderListCache.invalidate_cache(tenant_id)
+        # 2) Try to fetch tools immediately after creation so they appear without a second save.
+        # Perform network I/O outside any DB session to avoid holding locks.
+        try:
+            reconnect = MCPToolManageService.reconnect_with_url(
+                server_url=args["server_url"],
+                headers=args.get("headers") or {},
+                timeout=configuration.timeout,
+                sse_read_timeout=configuration.sse_read_timeout,
+            )
+            # Update just-created provider with authed/tools in a new short transaction
+            with session_factory.create_session() as session, session.begin():
+                service = MCPToolManageService(session=session)
+                db_provider = service.get_provider(provider_id=result.id, tenant_id=tenant_id)
+                db_provider.authed = reconnect.authed
+                db_provider.tools = reconnect.tools
+
+                result = ToolTransformService.mcp_provider_to_user_provider(db_provider, for_list=True)
+        except Exception:
+            # Best-effort: if initial fetch fails (e.g., auth required), return created provider as-is
+            logger.warning("Failed to fetch MCP tools after creation", exc_info=True)
 
         return jsonable_encoder(result)
 
@@ -1011,9 +1032,6 @@ class ToolProviderMCPApi(Resource):
                 validation_result=validation_result,
             )
 
-        # Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
-        ToolProviderListCache.invalidate_cache(current_tenant_id)
-
         return {"result": "success"}
 
     @console_ns.expect(parser_mcp_delete)
@@ -1028,9 +1046,6 @@ class ToolProviderMCPApi(Resource):
             service = MCPToolManageService(session=session)
             service.delete_provider(tenant_id=current_tenant_id, provider_id=args["provider_id"])
 
-        # Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
-        ToolProviderListCache.invalidate_cache(current_tenant_id)
-
         return {"result": "success"}
 
 
@@ -1081,8 +1096,6 @@ class ToolMCPAuthApi(Resource):
                     credentials=provider_entity.credentials,
                     authed=True,
                 )
-            # Invalidate cache after updating credentials
-            ToolProviderListCache.invalidate_cache(tenant_id)
             return {"result": "success"}
         except MCPAuthError as e:
             try:
@@ -1096,22 +1109,16 @@ class ToolMCPAuthApi(Resource):
             with Session(db.engine) as session, session.begin():
                 service = MCPToolManageService(session=session)
                 response = service.execute_auth_actions(auth_result)
-                # Invalidate cache after auth actions may have updated provider state
-                ToolProviderListCache.invalidate_cache(tenant_id)
                 return response
         except MCPRefreshTokenError as e:
             with Session(db.engine) as session, session.begin():
                 service = MCPToolManageService(session=session)
                 service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
-                # Invalidate cache after clearing credentials
-                ToolProviderListCache.invalidate_cache(tenant_id)
             raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
         except (MCPError, ValueError) as e:
             with Session(db.engine) as session, session.begin():
                 service = MCPToolManageService(session=session)
                 service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
-                # Invalidate cache after clearing credentials
-                ToolProviderListCache.invalidate_cache(tenant_id)
             raise ValueError(f"Failed to connect to MCP server: {e}") from e
 
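Note: the MCP provider change follows a general pattern worth calling out: keep database transactions short and do network I/O between transactions rather than inside them. A rough, self-contained sketch of that shape; the session helper and the fetch function here are stand-ins, not Dify APIs.

```python
from contextlib import contextmanager


@contextmanager
def short_session():
    # Stand-in for an SQLAlchemy session factory: commit on success, rollback on error.
    print("open session")
    try:
        yield "session"
        print("commit")
    except Exception:
        print("rollback")
        raise
    finally:
        print("close session")


def fetch_remote_tools(url: str) -> list[str]:
    # Stand-in for the network call; the real code talks to the MCP server.
    return ["tool_a", "tool_b"]


# 1) Create the record in a short transaction.
with short_session():
    provider_id = "new-provider"

# 2) Network I/O happens outside any session, so no DB locks are held while waiting.
try:
    tools = fetch_remote_tools("https://example.com/mcp")
except Exception:
    tools = []  # best-effort: the provider still exists, tools can be fetched later

# 3) Persist the fetched result in a second short transaction.
with short_session():
    print(f"update {provider_id} with {tools}")
```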
@@ -1,7 +1,7 @@
 from typing import Literal
 
 from flask import request
-from flask_restx import Api, Namespace, Resource, fields
+from flask_restx import Namespace, Resource, fields
 from flask_restx.api import HTTPStatus
 from pydantic import BaseModel, Field
 
@@ -92,7 +92,7 @@ annotation_list_fields = {
 }
 
 
-def build_annotation_list_model(api_or_ns: Api | Namespace):
+def build_annotation_list_model(api_or_ns: Namespace):
     """Build the annotation list model for the API or Namespace."""
     copied_annotation_list_fields = annotation_list_fields.copy()
     copied_annotation_list_fields["data"] = fields.List(fields.Nested(build_annotation_model(api_or_ns)))
@@ -1,6 +1,6 @@
 from flask_restx import Resource
 
-from controllers.common.fields import build_parameters_model
+from controllers.common.fields import Parameters
 from controllers.service_api import service_api_ns
 from controllers.service_api.app.error import AppUnavailableError
 from controllers.service_api.wraps import validate_app_token
@@ -23,7 +23,6 @@ class AppParameterApi(Resource):
         }
     )
     @validate_app_token
-    @service_api_ns.marshal_with(build_parameters_model(service_api_ns))
     def get(self, app_model: App):
         """Retrieve app parameters.
 
@@ -45,7 +44,8 @@ class AppParameterApi(Resource):
 
         user_input_form = features_dict.get("user_input_form", [])
 
-        return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
+        parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
+        return Parameters.model_validate(parameters).model_dump(mode="json")
 
 
 @service_api_ns.route("/meta")
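Note: this hunk and several that follow replace `marshal_with` response marshalling with a Pydantic model dumped to JSON-safe types. A small sketch of the idea with an invented `Parameters` model (the real field set lives in the shared fields module):

```python
from pydantic import BaseModel, ConfigDict


class Parameters(BaseModel):
    # from_attributes lets model_validate() read ORM/attribute objects as well as dicts
    model_config = ConfigDict(from_attributes=True)

    opening_statement: str = ""
    suggested_questions: list[str] = []


raw = {"opening_statement": "Hi!", "suggested_questions": ["What can you do?"]}
body = Parameters.model_validate(raw).model_dump(mode="json")
print(body)  # plain dict of JSON-serializable values, ready to return from a Flask view
```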
@@ -1,7 +1,7 @@
 from flask_restx import Resource
 from werkzeug.exceptions import Forbidden
 
-from controllers.common.fields import build_site_model
+from controllers.common.fields import Site as SiteResponse
 from controllers.service_api import service_api_ns
 from controllers.service_api.wraps import validate_app_token
 from extensions.ext_database import db
@@ -23,7 +23,6 @@ class AppSiteApi(Resource):
         }
     )
     @validate_app_token
-    @service_api_ns.marshal_with(build_site_model(service_api_ns))
     def get(self, app_model: App):
         """Retrieve app site info.
 
@@ -38,4 +37,4 @@ class AppSiteApi(Resource):
         if app_model.tenant.status == TenantStatus.ARCHIVE:
             raise Forbidden()
 
-        return site
+        return SiteResponse.model_validate(site).model_dump(mode="json")
@@ -3,7 +3,7 @@ from typing import Any, Literal
 
 from dateutil.parser import isoparse
 from flask import request
-from flask_restx import Api, Namespace, Resource, fields
+from flask_restx import Namespace, Resource, fields
 from pydantic import BaseModel, Field
 from sqlalchemy.orm import Session, sessionmaker
 from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
@@ -78,7 +78,7 @@ workflow_run_fields = {
 }
 
 
-def build_workflow_run_model(api_or_ns: Api | Namespace):
+def build_workflow_run_model(api_or_ns: Namespace):
     """Build the workflow run model for the API or Namespace."""
     return api_or_ns.model("WorkflowRun", workflow_run_fields)
 
@@ -13,7 +13,6 @@ from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError
 from controllers.service_api.wraps import (
     DatasetApiResource,
     cloud_edition_billing_rate_limit_check,
-    validate_dataset_token,
 )
 from core.model_runtime.entities.model_entities import ModelType
 from core.provider_manager import ProviderManager
@@ -460,9 +459,8 @@ class DatasetTagsApi(DatasetApiResource):
             401: "Unauthorized - invalid API token",
         }
     )
-    @validate_dataset_token
     @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
-    def get(self, _, dataset_id):
+    def get(self, _):
         """Get all knowledge type tags."""
         assert isinstance(current_user, Account)
         cid = current_user.current_tenant_id
@@ -482,8 +480,7 @@ class DatasetTagsApi(DatasetApiResource):
         }
     )
     @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
-    @validate_dataset_token
-    def post(self, _, dataset_id):
+    def post(self, _):
         """Add a knowledge type tag."""
         assert isinstance(current_user, Account)
         if not (current_user.has_edit_permission or current_user.is_dataset_editor):
@@ -506,8 +503,7 @@ class DatasetTagsApi(DatasetApiResource):
         }
     )
     @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
-    @validate_dataset_token
-    def patch(self, _, dataset_id):
+    def patch(self, _):
         assert isinstance(current_user, Account)
         if not (current_user.has_edit_permission or current_user.is_dataset_editor):
             raise Forbidden()
@@ -533,9 +529,8 @@ class DatasetTagsApi(DatasetApiResource):
             403: "Forbidden - insufficient permissions",
         }
     )
-    @validate_dataset_token
     @edit_permission_required
-    def delete(self, _, dataset_id):
+    def delete(self, _):
         """Delete a knowledge type tag."""
         payload = TagDeletePayload.model_validate(service_api_ns.payload or {})
         TagService.delete_tag(payload.tag_id)
@@ -555,8 +550,7 @@ class DatasetTagBindingApi(DatasetApiResource):
             403: "Forbidden - insufficient permissions",
         }
     )
-    @validate_dataset_token
-    def post(self, _, dataset_id):
+    def post(self, _):
         # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
         assert isinstance(current_user, Account)
         if not (current_user.has_edit_permission or current_user.is_dataset_editor):
@@ -580,8 +574,7 @@ class DatasetTagUnbindingApi(DatasetApiResource):
             403: "Forbidden - insufficient permissions",
         }
     )
-    @validate_dataset_token
-    def post(self, _, dataset_id):
+    def post(self, _):
         # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
         assert isinstance(current_user, Account)
         if not (current_user.has_edit_permission or current_user.is_dataset_editor):
@@ -604,7 +597,6 @@ class DatasetTagsBindingStatusApi(DatasetApiResource):
             401: "Unauthorized - invalid API token",
         }
    )
-    @validate_dataset_token
     def get(self, _, *args, **kwargs):
         """Get all knowledge type tags."""
         dataset_id = kwargs.get("dataset_id")
@@ -1,7 +1,7 @@
 import logging
 
 from flask import request
-from flask_restx import Resource, marshal_with
+from flask_restx import Resource
 from pydantic import BaseModel, ConfigDict, Field
 from werkzeug.exceptions import Unauthorized
 
@@ -50,7 +50,6 @@ class AppParameterApi(WebApiResource):
             500: "Internal Server Error",
         }
     )
-    @marshal_with(fields.parameters_fields)
     def get(self, app_model: App, end_user):
         """Retrieve app parameters."""
         if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
@@ -69,7 +68,8 @@ class AppParameterApi(WebApiResource):
 
         user_input_form = features_dict.get("user_input_form", [])
 
-        return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
+        parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
+        return fields.Parameters.model_validate(parameters).model_dump(mode="json")
 
 
 @web_ns.route("/meta")
@@ -22,6 +22,7 @@ from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
 from core.tools.__base.tool import Tool
 from core.tools.entities.tool_entities import ToolInvokeMeta
 from core.tools.tool_engine import ToolEngine
+from core.workflow.nodes.agent.exc import AgentMaxIterationError
 from models.model import Message
 
 logger = logging.getLogger(__name__)
@@ -165,6 +166,11 @@ class CotAgentRunner(BaseAgentRunner, ABC):
             scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
             self._agent_scratchpad.append(scratchpad)
 
+            # Check if max iteration is reached and model still wants to call tools
+            if iteration_step == max_iteration_steps and scratchpad.action:
+                if scratchpad.action.action_name.lower() != "final answer":
+                    raise AgentMaxIterationError(app_config.agent.max_iteration)
+
             # get llm usage
             if "usage" in usage_dict:
                 if usage_dict["usage"] is not None:
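Note: both agent runners now raise instead of silently truncating when the final allowed iteration still wants to call a tool. The guard reduces to something like the following; `AgentMaxIterationError` here is a local stand-in for the imported exception.

```python
class AgentMaxIterationError(Exception):
    def __init__(self, max_iteration: int):
        super().__init__(f"Agent exceeded the maximum of {max_iteration} iterations")


def check_iteration(iteration_step: int, max_iteration_steps: int, wants_tool_call: bool) -> None:
    # On the last permitted step, a pending tool call means the run cannot finish:
    # fail loudly instead of returning a half-finished answer.
    if iteration_step == max_iteration_steps and wants_tool_call:
        raise AgentMaxIterationError(max_iteration_steps)


check_iteration(iteration_step=3, max_iteration_steps=5, wants_tool_call=True)   # fine
try:
    check_iteration(iteration_step=5, max_iteration_steps=5, wants_tool_call=True)
except AgentMaxIterationError as e:
    print(e)
```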
@@ -25,6 +25,7 @@ from core.model_runtime.entities.message_entities import ImagePromptMessageContent
 from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
 from core.tools.entities.tool_entities import ToolInvokeMeta
 from core.tools.tool_engine import ToolEngine
+from core.workflow.nodes.agent.exc import AgentMaxIterationError
 from models.model import Message
 
 logger = logging.getLogger(__name__)
@@ -222,6 +223,10 @@ class FunctionCallAgentRunner(BaseAgentRunner):
 
             final_answer += response + "\n"
 
+            # Check if max iteration is reached and model still wants to call tools
+            if iteration_step == max_iteration_steps and tool_calls:
+                raise AgentMaxIterationError(app_config.agent.max_iteration)
+
             # call tools
             tool_responses = []
             for tool_call_id, tool_call_name, tool_call_args in tool_calls:
@@ -90,6 +90,7 @@ class AppQueueManager:
         """
         self._clear_task_belong_cache()
         self._q.put(None)
+        self._graph_runtime_state = None  # Release reference to allow GC to reclaim memory
 
     def _clear_task_belong_cache(self) -> None:
         """
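Note: dropping the reference when the queue stops lets Python's garbage collector reclaim the potentially large runtime state as soon as nothing else points at it. An illustrative sketch with made-up class names:

```python
import gc


class BigState:
    def __init__(self):
        self.buffer = bytearray(50 * 1024 * 1024)  # ~50 MB kept alive by this object


class QueueManager:
    def __init__(self):
        self._graph_runtime_state = BigState()

    def stop_listen(self):
        # Without this assignment the manager would keep BigState alive for its whole lifetime.
        self._graph_runtime_state = None


manager = QueueManager()
manager.stop_listen()
gc.collect()  # the BigState instance is now unreachable and can be collected
```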
@@ -1,9 +1,14 @@
+from collections.abc import Mapping
 from textwrap import dedent
+from typing import Any
 
 from core.helper.code_executor.template_transformer import TemplateTransformer
 
 
 class Jinja2TemplateTransformer(TemplateTransformer):
+    # Use separate placeholder for base64-encoded template to avoid confusion
+    _template_b64_placeholder: str = "{{template_b64}}"
+
     @classmethod
     def transform_response(cls, response: str):
         """
@@ -13,18 +18,35 @@ class Jinja2TemplateTransformer(TemplateTransformer):
         """
         return {"result": cls.extract_result_str_from_response(response)}
 
+    @classmethod
+    def assemble_runner_script(cls, code: str, inputs: Mapping[str, Any]) -> str:
+        """
+        Override base class to use base64 encoding for template code.
+        This prevents issues with special characters (quotes, newlines) in templates
+        breaking the generated Python script. Fixes #26818.
+        """
+        script = cls.get_runner_script()
+        # Encode template as base64 to safely embed any content including quotes
+        code_b64 = cls.serialize_code(code)
+        script = script.replace(cls._template_b64_placeholder, code_b64)
+        inputs_str = cls.serialize_inputs(inputs)
+        script = script.replace(cls._inputs_placeholder, inputs_str)
+        return script
+
     @classmethod
     def get_runner_script(cls) -> str:
         runner_script = dedent(f"""
-            # declare main function
-            def main(**inputs):
-                import jinja2
-                template = jinja2.Template('''{cls._code_placeholder}''')
-                return template.render(**inputs)
+            import jinja2
 
             import json
             from base64 import b64decode
 
+            # declare main function
+            def main(**inputs):
+                # Decode base64-encoded template to handle special characters safely
+                template_code = b64decode('{cls._template_b64_placeholder}').decode('utf-8')
+                template = jinja2.Template(template_code)
+                return template.render(**inputs)
+
             # decode and prepare input dict
             inputs_obj = json.loads(b64decode('{cls._inputs_placeholder}').decode('utf-8'))
 
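Note: the transformer change avoids string-interpolation bugs by shipping the Jinja2 template as base64 and decoding it inside the generated runner, so no quoting rules can corrupt the payload. The core trick, shown standalone (requires the jinja2 package):

```python
from base64 import b64decode, b64encode

import jinja2

# A template containing quotes and newlines that would break a naive '''...''' embedding.
template_source = "Hello '''{{ name }}'''!\nBye \"{{ name }}\""

# Sender side: encode once and substitute the base64 string into the generated script.
template_b64 = b64encode(template_source.encode("utf-8")).decode("utf-8")

# Runner side: decode and render; the template arrives byte-for-byte intact.
template_code = b64decode(template_b64).decode("utf-8")
print(jinja2.Template(template_code).render(name="Dify"))
```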
@@ -13,6 +13,15 @@ class TemplateTransformer(ABC):
     _inputs_placeholder: str = "{{inputs}}"
     _result_tag: str = "<<RESULT>>"
 
+    @classmethod
+    def serialize_code(cls, code: str) -> str:
+        """
+        Serialize template code to base64 to safely embed in generated script.
+        This prevents issues with special characters like quotes breaking the script.
+        """
+        code_bytes = code.encode("utf-8")
+        return b64encode(code_bytes).decode("utf-8")
+
     @classmethod
     def transform_caller(cls, code: str, inputs: Mapping[str, Any]) -> tuple[str, str]:
         """
@@ -1,56 +0,0 @@
-import json
-import logging
-from typing import Any
-
-from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
-from extensions.ext_redis import redis_client, redis_fallback
-
-logger = logging.getLogger(__name__)
-
-
-class ToolProviderListCache:
-    """Cache for tool provider lists"""
-
-    CACHE_TTL = 300  # 5 minutes
-
-    @staticmethod
-    def _generate_cache_key(tenant_id: str, typ: ToolProviderTypeApiLiteral = None) -> str:
-        """Generate cache key for tool providers list"""
-        type_filter = typ or "all"
-        return f"tool_providers:tenant_id:{tenant_id}:type:{type_filter}"
-
-    @staticmethod
-    @redis_fallback(default_return=None)
-    def get_cached_providers(tenant_id: str, typ: ToolProviderTypeApiLiteral = None) -> list[dict[str, Any]] | None:
-        """Get cached tool providers"""
-        cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
-        cached_data = redis_client.get(cache_key)
-        if cached_data:
-            try:
-                return json.loads(cached_data.decode("utf-8"))
-            except (json.JSONDecodeError, UnicodeDecodeError):
-                logger.warning("Failed to decode cached tool providers data")
-                return None
-        return None
-
-    @staticmethod
-    @redis_fallback()
-    def set_cached_providers(tenant_id: str, typ: ToolProviderTypeApiLiteral, providers: list[dict[str, Any]]):
-        """Cache tool providers"""
-        cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
-        redis_client.setex(cache_key, ToolProviderListCache.CACHE_TTL, json.dumps(providers))
-
-    @staticmethod
-    @redis_fallback()
-    def invalidate_cache(tenant_id: str, typ: ToolProviderTypeApiLiteral = None):
-        """Invalidate cache for tool providers"""
-        if typ:
-            # Invalidate specific type cache
-            cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
-            redis_client.delete(cache_key)
-        else:
-            # Invalidate all caches for this tenant
-            pattern = f"tool_providers:tenant_id:{tenant_id}:*"
-            keys = list(redis_client.scan_iter(pattern))
-            if keys:
-                redis_client.delete(*keys)
@@ -313,17 +313,20 @@ class StreamableHTTPTransport:
             if is_initialization:
                 self._maybe_extract_session_id_from_response(response)
 
-            content_type = cast(str, response.headers.get(CONTENT_TYPE, "").lower())
+            # Per https://modelcontextprotocol.io/specification/2025-06-18/basic#notifications:
+            # The server MUST NOT send a response to notifications.
+            if isinstance(message.root, JSONRPCRequest):
+                content_type = cast(str, response.headers.get(CONTENT_TYPE, "").lower())
 
-            if content_type.startswith(JSON):
-                self._handle_json_response(response, ctx.server_to_client_queue)
-            elif content_type.startswith(SSE):
-                self._handle_sse_response(response, ctx)
-            else:
-                self._handle_unexpected_content_type(
-                    content_type,
-                    ctx.server_to_client_queue,
-                )
+                if content_type.startswith(JSON):
+                    self._handle_json_response(response, ctx.server_to_client_queue)
+                elif content_type.startswith(SSE):
+                    self._handle_sse_response(response, ctx)
+                else:
+                    self._handle_unexpected_content_type(
+                        content_type,
+                        ctx.server_to_client_queue,
+                    )
 
     def _handle_json_response(
         self,
@@ -76,7 +76,7 @@ class PluginParameter(BaseModel):
     auto_generate: PluginParameterAutoGenerate | None = None
     template: PluginParameterTemplate | None = None
     required: bool = False
-    default: Union[float, int, str, bool] | None = None
+    default: Union[float, int, str, bool, list, dict] | None = None
     min: Union[float, int] | None = None
     max: Union[float, int] | None = None
     precision: int | None = None
@@ -27,26 +27,44 @@ class CleanProcessor:
         pattern = r"([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)"
         text = re.sub(pattern, "", text)
 
-        # Remove URL but keep Markdown image URLs
-        # First, temporarily replace Markdown image URLs with a placeholder
-        markdown_image_pattern = r"!\[.*?\]\((https?://[^\s)]+)\)"
-        placeholders: list[str] = []
+        # Remove URL but keep Markdown image URLs and link URLs
+        # Replace the ENTIRE markdown link/image with a single placeholder to protect
+        # the link text (which might also be a URL) from being removed
+        markdown_link_pattern = r"\[([^\]]*)\]\((https?://[^)]+)\)"
+        markdown_image_pattern = r"!\[.*?\]\((https?://[^)]+)\)"
+        placeholders: list[tuple[str, str, str]] = []  # (type, text, url)
 
-        def replace_with_placeholder(match, placeholders=placeholders):
+        def replace_markdown_with_placeholder(match, placeholders=placeholders):
+            link_type = "link"
+            link_text = match.group(1)
+            url = match.group(2)
+            placeholder = f"__MARKDOWN_PLACEHOLDER_{len(placeholders)}__"
+            placeholders.append((link_type, link_text, url))
+            return placeholder
+
+        def replace_image_with_placeholder(match, placeholders=placeholders):
+            link_type = "image"
             url = match.group(1)
-            placeholder = f"__MARKDOWN_IMAGE_URL_{len(placeholders)}__"
-            placeholders.append(url)
-            return f"![image]({placeholder})"
+            placeholder = f"__MARKDOWN_PLACEHOLDER_{len(placeholders)}__"
+            placeholders.append((link_type, "image", url))
+            return placeholder
 
-        text = re.sub(markdown_image_pattern, replace_with_placeholder, text)
+        # Protect markdown links first
+        text = re.sub(markdown_link_pattern, replace_markdown_with_placeholder, text)
+        # Then protect markdown images
+        text = re.sub(markdown_image_pattern, replace_image_with_placeholder, text)
 
         # Now remove all remaining URLs
-        url_pattern = r"https?://[^\s)]+"
+        url_pattern = r"https?://\S+"
         text = re.sub(url_pattern, "", text)
 
-        # Finally, restore the Markdown image URLs
-        for i, url in enumerate(placeholders):
-            text = text.replace(f"__MARKDOWN_IMAGE_URL_{i}__", url)
+        # Restore the Markdown links and images
+        for i, (link_type, text_or_alt, url) in enumerate(placeholders):
+            placeholder = f"__MARKDOWN_PLACEHOLDER_{i}__"
+            if link_type == "link":
+                text = text.replace(placeholder, f"[{text_or_alt}]({url})")
+            else:  # image
+                text = text.replace(placeholder, f"![image]({url})")
         return text
 
     def filter_string(self, text):
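Note: the cleaner change is a protect-strip-restore pass: whole markdown links and images are swapped for placeholders, bare URLs are removed, then the protected constructs are put back. A condensed, self-contained variation of the same technique (not the exact Dify code):

```python
import re

IMAGE = re.compile(r"!\[.*?\]\((https?://[^)]+)\)")
LINK = re.compile(r"\[([^\]]*)\]\((https?://[^)]+)\)")
BARE_URL = re.compile(r"https?://\S+")


def strip_bare_urls(text: str) -> str:
    saved: list[str] = []

    def protect(match: re.Match) -> str:
        saved.append(match.group(0))       # keep the whole markdown construct
        return f"__MD_{len(saved) - 1}__"  # unique placeholder per match

    text = IMAGE.sub(protect, text)        # images first, since they start with '!'
    text = LINK.sub(protect, text)
    text = BARE_URL.sub("", text)          # only unprotected URLs remain at this point
    for i, original in enumerate(saved):   # restore the protected constructs
        text = text.replace(f"__MD_{i}__", original)
    return text


print(strip_bare_urls("See [docs](https://example.com/docs) or https://example.com/raw"))
# -> "See [docs](https://example.com/docs) or "
```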
@@ -1,4 +1,5 @@
 import concurrent.futures
+import logging
 from concurrent.futures import ThreadPoolExecutor
 from typing import Any
 
@@ -13,7 +14,7 @@ from core.model_runtime.entities.model_entities import ModelType
 from core.rag.data_post_processor.data_post_processor import DataPostProcessor
 from core.rag.datasource.keyword.keyword_factory import Keyword
 from core.rag.datasource.vdb.vector_factory import Vector
-from core.rag.embedding.retrieval import RetrievalSegments
+from core.rag.embedding.retrieval import RetrievalChildChunk, RetrievalSegments
 from core.rag.entities.metadata_entities import MetadataCondition
 from core.rag.index_processor.constant.doc_type import DocType
 from core.rag.index_processor.constant.index_type import IndexStructureType
@@ -36,6 +37,8 @@ default_retrieval_model = {
     "score_threshold_enabled": False,
 }
 
+logger = logging.getLogger(__name__)
+
 
 class RetrievalService:
     # Cache precompiled regular expressions to avoid repeated compilation
@@ -106,7 +109,12 @@ class RetrievalService:
                     )
                 )
 
-            concurrent.futures.wait(futures, timeout=3600, return_when=concurrent.futures.ALL_COMPLETED)
+            if futures:
+                for future in concurrent.futures.as_completed(futures, timeout=3600):
+                    if exceptions:
+                        for f in futures:
+                            f.cancel()
+                        break
 
         if exceptions:
             raise ValueError(";\n".join(exceptions))
@@ -210,6 +218,7 @@ class RetrievalService:
             )
             all_documents.extend(documents)
         except Exception as e:
+            logger.error(e, exc_info=True)
            exceptions.append(str(e))
 
     @classmethod
@@ -303,6 +312,7 @@ class RetrievalService:
             else:
                 all_documents.extend(documents)
         except Exception as e:
+            logger.error(e, exc_info=True)
            exceptions.append(str(e))
 
     @classmethod
@@ -351,6 +361,7 @@ class RetrievalService:
            else:
                 all_documents.extend(documents)
         except Exception as e:
+            logger.error(e, exc_info=True)
            exceptions.append(str(e))
 
     @staticmethod
@@ -381,10 +392,9 @@ class RetrievalService:
             records = []
             include_segment_ids = set()
             segment_child_map = {}
-            segment_file_map = {}
 
             valid_dataset_documents = {}
-            image_doc_ids = []
+            image_doc_ids: list[Any] = []
             child_index_node_ids = []
             index_node_ids = []
             doc_to_document_map = {}
@@ -417,28 +427,39 @@ class RetrievalService:
             child_index_node_ids = [i for i in child_index_node_ids if i]
             index_node_ids = [i for i in index_node_ids if i]
 
-            segment_ids = []
+            segment_ids: list[str] = []
             index_node_segments: list[DocumentSegment] = []
             segments: list[DocumentSegment] = []
-            attachment_map = {}
-            child_chunk_map = {}
-            doc_segment_map = {}
+            attachment_map: dict[str, list[dict[str, Any]]] = {}
+            child_chunk_map: dict[str, list[ChildChunk]] = {}
+            doc_segment_map: dict[str, list[str]] = {}
 
             with session_factory.create_session() as session:
                 attachments = cls.get_segment_attachment_infos(image_doc_ids, session)
 
                 for attachment in attachments:
                     segment_ids.append(attachment["segment_id"])
-                    attachment_map[attachment["segment_id"]] = attachment
-                    doc_segment_map[attachment["segment_id"]] = attachment["attachment_id"]
+                    if attachment["segment_id"] in attachment_map:
+                        attachment_map[attachment["segment_id"]].append(attachment["attachment_info"])
+                    else:
+                        attachment_map[attachment["segment_id"]] = [attachment["attachment_info"]]
+                    if attachment["segment_id"] in doc_segment_map:
+                        doc_segment_map[attachment["segment_id"]].append(attachment["attachment_id"])
+                    else:
+                        doc_segment_map[attachment["segment_id"]] = [attachment["attachment_id"]]
                 child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id.in_(child_index_node_ids))
                 child_index_nodes = session.execute(child_chunk_stmt).scalars().all()
 
                 for i in child_index_nodes:
                     segment_ids.append(i.segment_id)
-                    child_chunk_map[i.segment_id] = i
-                    doc_segment_map[i.segment_id] = i.index_node_id
+                    if i.segment_id in child_chunk_map:
+                        child_chunk_map[i.segment_id].append(i)
+                    else:
+                        child_chunk_map[i.segment_id] = [i]
+                    if i.segment_id in doc_segment_map:
+                        doc_segment_map[i.segment_id].append(i.index_node_id)
+                    else:
+                        doc_segment_map[i.segment_id] = [i.index_node_id]
 
             if index_node_ids:
                 document_segment_stmt = select(DocumentSegment).where(
@@ -448,7 +469,7 @@ class RetrievalService:
                 )
                 index_node_segments = session.execute(document_segment_stmt).scalars().all()  # type: ignore
                 for index_node_segment in index_node_segments:
-                    doc_segment_map[index_node_segment.id] = index_node_segment.index_node_id
+                    doc_segment_map[index_node_segment.id] = [index_node_segment.index_node_id]
             if segment_ids:
                 document_segment_stmt = select(DocumentSegment).where(
                     DocumentSegment.enabled == True,
@@ -461,95 +482,86 @@ class RetrievalService:
             segments.extend(index_node_segments)
 
             for segment in segments:
-                doc_id = doc_segment_map.get(segment.id)
-                child_chunk = child_chunk_map.get(segment.id)
-                attachment_info = attachment_map.get(segment.id)
+                child_chunks: list[ChildChunk] = child_chunk_map.get(segment.id, [])
+                attachment_infos: list[dict[str, Any]] = attachment_map.get(segment.id, [])
+                ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get(segment.document_id)
 
-                if doc_id:
-                    document = doc_to_document_map[doc_id]
-                    ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get(
-                        document.metadata.get("document_id")
-                    )
-
-                    if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
-                        if segment.id not in include_segment_ids:
-                            include_segment_ids.add(segment.id)
-                            if child_chunk:
+                if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
+                    if segment.id not in include_segment_ids:
+                        include_segment_ids.add(segment.id)
+                        if child_chunks or attachment_infos:
+                            child_chunk_details = []
+                            max_score = 0.0
+                            for child_chunk in child_chunks:
+                                document = doc_to_document_map[child_chunk.index_node_id]
                                 child_chunk_detail = {
                                     "id": child_chunk.id,
                                     "content": child_chunk.content,
                                     "position": child_chunk.position,
                                     "score": document.metadata.get("score", 0.0) if document else 0.0,
                                 }
-                                map_detail = {
-                                    "max_score": document.metadata.get("score", 0.0) if document else 0.0,
-                                    "child_chunks": [child_chunk_detail],
-                                }
-                                segment_child_map[segment.id] = map_detail
-                            record = {
-                                "segment": segment,
-                            }
-                            if attachment_info:
-                                segment_file_map[segment.id] = [attachment_info]
-                            records.append(record)
-                        else:
-                            if child_chunk:
-                                child_chunk_detail = {
-                                    "id": child_chunk.id,
-                                    "content": child_chunk.content,
-                                    "position": child_chunk.position,
-                                    "score": document.metadata.get("score", 0.0),
-                                }
-                                if segment.id in segment_child_map:
-                                    segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)  # type: ignore
-                                    segment_child_map[segment.id]["max_score"] = max(
-                                        segment_child_map[segment.id]["max_score"],
-                                        document.metadata.get("score", 0.0) if document else 0.0,
-                                    )
-                                else:
-                                    segment_child_map[segment.id] = {
-                                        "max_score": document.metadata.get("score", 0.0) if document else 0.0,
-                                        "child_chunks": [child_chunk_detail],
-                                    }
-                            if attachment_info:
-                                if segment.id in segment_file_map:
-                                    segment_file_map[segment.id].append(attachment_info)
-                                else:
-                                    segment_file_map[segment.id] = [attachment_info]
-                    else:
-                        if segment.id not in include_segment_ids:
-                            include_segment_ids.add(segment.id)
-                            record = {
-                                "segment": segment,
-                                "score": document.metadata.get("score", 0.0),  # type: ignore
-                            }
-                            if attachment_info:
-                                segment_file_map[segment.id] = [attachment_info]
-                            records.append(record)
-                        else:
-                            if attachment_info:
-                                attachment_infos = segment_file_map.get(segment.id, [])
-                                if attachment_info not in attachment_infos:
-                                    attachment_infos.append(attachment_info)
-                                segment_file_map[segment.id] = attachment_infos
+                                child_chunk_details.append(child_chunk_detail)
+                                max_score = max(max_score, document.metadata.get("score", 0.0) if document else 0.0)
+                            for attachment_info in attachment_infos:
+                                file_document = doc_to_document_map[attachment_info["id"]]
+                                max_score = max(
+                                    max_score, file_document.metadata.get("score", 0.0) if file_document else 0.0
+                                )
+
+                            map_detail = {
+                                "max_score": max_score,
+                                "child_chunks": child_chunk_details,
+                            }
+                            segment_child_map[segment.id] = map_detail
+                            record: dict[str, Any] = {
+                                "segment": segment,
+                            }
+                            records.append(record)
+                else:
+                    if segment.id not in include_segment_ids:
+                        include_segment_ids.add(segment.id)
+                        max_score = 0.0
+                        segment_document = doc_to_document_map.get(segment.index_node_id)
+                        if segment_document:
+                            max_score = max(max_score, segment_document.metadata.get("score", 0.0))
+                        for attachment_info in attachment_infos:
+                            file_doc = doc_to_document_map.get(attachment_info["id"])
+                            if file_doc:
+                                max_score = max(max_score, file_doc.metadata.get("score", 0.0))
+                        record = {
+                            "segment": segment,
+                            "score": max_score,
+                        }
+                        records.append(record)
 
             # Add child chunks information to records
             for record in records:
                 if record["segment"].id in segment_child_map:
                     record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks")  # type: ignore
                     record["score"] = segment_child_map[record["segment"].id]["max_score"]  # type: ignore
-                if record["segment"].id in segment_file_map:
-                    record["files"] = segment_file_map[record["segment"].id]  # type: ignore[assignment]
+                if record["segment"].id in attachment_map:
+                    record["files"] = attachment_map[record["segment"].id]  # type: ignore[assignment]
 
-            result = []
+            result: list[RetrievalSegments] = []
             for record in records:
                 # Extract segment
                 segment = record["segment"]
 
                 # Extract child_chunks, ensuring it's a list or None
-                child_chunks = record.get("child_chunks")
-                if not isinstance(child_chunks, list):
-                    child_chunks = None
+                raw_child_chunks = record.get("child_chunks")
+                child_chunks_list: list[RetrievalChildChunk] | None = None
+                if isinstance(raw_child_chunks, list):
+                    # Sort by score descending
+                    sorted_chunks = sorted(raw_child_chunks, key=lambda x: x.get("score", 0.0), reverse=True)
+                    child_chunks_list = [
+                        RetrievalChildChunk(
+                            id=chunk["id"],
+                            content=chunk["content"],
+                            score=chunk.get("score", 0.0),
+                            position=chunk["position"],
+                        )
+                        for chunk in sorted_chunks
+                    ]
 
                 # Extract files, ensuring it's a list or None
                 files = record.get("files")
@@ -566,11 +578,11 @@ class RetrievalService:
 
             # Create RetrievalSegments object
             retrieval_segment = RetrievalSegments(
-                segment=segment, child_chunks=child_chunks, score=score, files=files
+                segment=segment, child_chunks=child_chunks_list, score=score, files=files
            )
            result.append(retrieval_segment)
 
-            return result
+            return sorted(result, key=lambda x: x.score if x.score is not None else 0.0, reverse=True)
         except Exception as e:
             db.session.rollback()
             raise e
@@ -662,7 +674,14 @@ class RetrievalService:
                         document_ids_filter=document_ids_filter,
                     )
                 )
-            concurrent.futures.wait(futures, timeout=300, return_when=concurrent.futures.ALL_COMPLETED)
+            # Use as_completed for early error propagation - cancel remaining futures on first error
+            if futures:
+                for future in concurrent.futures.as_completed(futures, timeout=300):
+                    if future.exception():
+                        # Cancel remaining futures to avoid unnecessary waiting
+                        for f in futures:
+                            f.cancel()
+                        break
 
         if exceptions:
             raise ValueError(";\n".join(exceptions))
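Note: the retrieval changes move from `wait(..., return_when=ALL_COMPLETED)` to `as_completed` so the first failure can cancel work that has not started yet instead of waiting for every future. A generic, runnable sketch of that pattern (task names and timings are illustrative):

```python
import concurrent.futures
import time
from concurrent.futures import ThreadPoolExecutor


def task(i: int) -> int:
    time.sleep(0.1 * i)
    if i == 1:
        raise RuntimeError("boom")
    return i


with ThreadPoolExecutor(max_workers=2) as executor:
    futures = [executor.submit(task, i) for i in range(6)]
    for future in concurrent.futures.as_completed(futures, timeout=30):
        if future.exception():
            # First failure observed: cancel anything not yet running and stop waiting.
            for f in futures:
                f.cancel()
            break

errors = [str(f.exception()) for f in futures if not f.cancelled() and f.exception()]
if errors:
    print(";\n".join(errors))  # -> "boom"
```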
@@ -255,7 +255,10 @@ class PGVector(BaseVector):
             return
 
         with self._get_cursor() as cur:
-            cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
+            cur.execute("SELECT 1 FROM pg_extension WHERE extname = 'vector'")
+            if not cur.fetchone():
+                cur.execute("CREATE EXTENSION vector")
+
             cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
             # PG hnsw index only support 2000 dimension or less
             # ref: https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing
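Note: querying `pg_extension` first means `CREATE EXTENSION` is only issued when the pgvector extension is actually missing, presumably so routine startup does not need extension-creation privileges on databases where it is already installed. A hedged sketch using psycopg2-style cursor calls; the connection string is a placeholder.

```python
import psycopg2  # assumed driver; the real class wraps its own cursor helper

conn = psycopg2.connect("dbname=dify user=postgres password=... host=localhost")
with conn, conn.cursor() as cur:
    # Check the catalog first, then create the extension only when it is absent.
    cur.execute("SELECT 1 FROM pg_extension WHERE extname = 'vector'")
    if cur.fetchone() is None:
        cur.execute("CREATE EXTENSION vector")
```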
@@ -112,7 +112,7 @@ class ExtractProcessor:
                 if file_extension in {".xlsx", ".xls"}:
                     extractor = ExcelExtractor(file_path)
                 elif file_extension == ".pdf":
-                    extractor = PdfExtractor(file_path)
+                    extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
                 elif file_extension in {".md", ".markdown", ".mdx"}:
                     extractor = (
                         UnstructuredMarkdownExtractor(file_path, unstructured_api_url, unstructured_api_key)
@@ -148,7 +148,7 @@ class ExtractProcessor:
                 if file_extension in {".xlsx", ".xls"}:
                     extractor = ExcelExtractor(file_path)
                 elif file_extension == ".pdf":
-                    extractor = PdfExtractor(file_path)
+                    extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
                 elif file_extension in {".md", ".markdown", ".mdx"}:
                     extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
                 elif file_extension in {".htm", ".html"}:
@@ -1,25 +1,57 @@
 """Abstract interface for document loader implementations."""

 import contextlib
+import io
+import logging
+import uuid
 from collections.abc import Iterator

+import pypdfium2
+import pypdfium2.raw as pdfium_c
+
+from configs import dify_config
 from core.rag.extractor.blob.blob import Blob
 from core.rag.extractor.extractor_base import BaseExtractor
 from core.rag.models.document import Document
+from extensions.ext_database import db
 from extensions.ext_storage import storage
+from libs.datetime_utils import naive_utc_now
+from models.enums import CreatorUserRole
+from models.model import UploadFile
+
+logger = logging.getLogger(__name__)


 class PdfExtractor(BaseExtractor):
-    """Load pdf files.
+    """
+    PdfExtractor is used to extract text and images from PDF files.
+
     Args:
-        file_path: Path to the file to load.
+        file_path: Path to the PDF file.
+        tenant_id: Workspace ID.
+        user_id: ID of the user performing the extraction.
+        file_cache_key: Optional cache key for the extracted text.
     """

-    def __init__(self, file_path: str, file_cache_key: str | None = None):
-        """Initialize with file path."""
+    # Magic bytes for image format detection: (magic_bytes, extension, mime_type)
+    IMAGE_FORMATS = [
+        (b"\xff\xd8\xff", "jpg", "image/jpeg"),
+        (b"\x89PNG\r\n\x1a\n", "png", "image/png"),
+        (b"\x00\x00\x00\x0c\x6a\x50\x20\x20\x0d\x0a\x87\x0a", "jp2", "image/jp2"),
+        (b"GIF8", "gif", "image/gif"),
+        (b"BM", "bmp", "image/bmp"),
+        (b"II*\x00", "tiff", "image/tiff"),
+        (b"MM\x00*", "tiff", "image/tiff"),
+        (b"II+\x00", "tiff", "image/tiff"),
+        (b"MM\x00+", "tiff", "image/tiff"),
+    ]
+    MAX_MAGIC_LEN = max(len(m) for m, _, _ in IMAGE_FORMATS)
+
+    def __init__(self, file_path: str, tenant_id: str, user_id: str, file_cache_key: str | None = None):
+        """Initialize PdfExtractor."""
         self._file_path = file_path
+        self._tenant_id = tenant_id
+        self._user_id = user_id
         self._file_cache_key = file_cache_key

     def extract(self) -> list[Document]:
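Note: the `IMAGE_FORMATS` table above drives simple magic-byte sniffing of the image payloads pulled out of a PDF. A standalone sketch of that detection (table trimmed to a few entries from the hunk; the helper name is illustrative):

```python
IMAGE_FORMATS = [
    (b"\xff\xd8\xff", "jpg", "image/jpeg"),
    (b"\x89PNG\r\n\x1a\n", "png", "image/png"),
    (b"GIF8", "gif", "image/gif"),
    (b"BM", "bmp", "image/bmp"),
    (b"II*\x00", "tiff", "image/tiff"),
    (b"MM\x00*", "tiff", "image/tiff"),
]
MAX_MAGIC_LEN = max(len(m) for m, _, _ in IMAGE_FORMATS)


def sniff_image(data: bytes) -> tuple[str, str] | None:
    """Return (extension, mime_type) if the payload starts with a known magic sequence."""
    header = data[:MAX_MAGIC_LEN]
    for magic, ext, mime in IMAGE_FORMATS:
        if header.startswith(magic):
            return ext, mime
    return None  # unknown format: the caller simply skips this image


print(sniff_image(b"\x89PNG\r\n\x1a\n" + b"\x00" * 8))  # ('png', 'image/png')
```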
@@ -50,7 +82,6 @@ class PdfExtractor(BaseExtractor):

     def parse(self, blob: Blob) -> Iterator[Document]:
         """Lazily parse the blob."""
-        import pypdfium2  # type: ignore

         with blob.as_bytes_io() as file_path:
             pdf_reader = pypdfium2.PdfDocument(file_path, autoclose=True)
@@ -59,8 +90,87 @@ class PdfExtractor(BaseExtractor):
                 text_page = page.get_textpage()
                 content = text_page.get_text_range()
                 text_page.close()
+
+                image_content = self._extract_images(page)
+                if image_content:
+                    content += "\n" + image_content
+
                 page.close()
                 metadata = {"source": blob.source, "page": page_number}
                 yield Document(page_content=content, metadata=metadata)
         finally:
             pdf_reader.close()

+    def _extract_images(self, page) -> str:
+        """
+        Extract images from a PDF page, save them to storage and database,
+        and return markdown image links.
+
+        Args:
+            page: pypdfium2 page object.
+
+        Returns:
+            Markdown string containing links to the extracted images.
+        """
+        image_content = []
+        upload_files = []
+        base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
+
+        try:
+            image_objects = page.get_objects(filter=(pdfium_c.FPDF_PAGEOBJ_IMAGE,))
+            for obj in image_objects:
+                try:
+                    # Extract image bytes
+                    img_byte_arr = io.BytesIO()
+                    # Extract DCTDecode (JPEG) and JPXDecode (JPEG 2000) images directly
+                    # Fallback to png for other formats
+                    obj.extract(img_byte_arr, fb_format="png")
+                    img_bytes = img_byte_arr.getvalue()
+
+                    if not img_bytes:
+                        continue
+
+                    header = img_bytes[: self.MAX_MAGIC_LEN]
+                    image_ext = None
+                    mime_type = None
+                    for magic, ext, mime in self.IMAGE_FORMATS:
+                        if header.startswith(magic):
+                            image_ext = ext
+                            mime_type = mime
+                            break
+
+                    if not image_ext or not mime_type:
+                        continue
+
+                    file_uuid = str(uuid.uuid4())
+                    file_key = "image_files/" + self._tenant_id + "/" + file_uuid + "." + image_ext
+
+                    storage.save(file_key, img_bytes)
+
+                    # save file to db
+                    upload_file = UploadFile(
+                        tenant_id=self._tenant_id,
+                        storage_type=dify_config.STORAGE_TYPE,
+                        key=file_key,
+                        name=file_key,
+                        size=len(img_bytes),
+                        extension=image_ext,
+                        mime_type=mime_type,
+                        created_by=self._user_id,
+                        created_by_role=CreatorUserRole.ACCOUNT,
+                        created_at=naive_utc_now(),
+                        used=True,
+                        used_by=self._user_id,
+                        used_at=naive_utc_now(),
+                    )
+                    upload_files.append(upload_file)
+                    image_content.append(f"")
+                except Exception as e:
+                    logger.warning("Failed to extract image from PDF: %s", e)
+                    continue
+        except Exception as e:
+            logger.warning("Failed to get objects from PDF page: %s", e)
+        if upload_files:
+            db.session.add_all(upload_files)
+            db.session.commit()
+        return "\n".join(image_content)
@@ -7,7 +7,7 @@ from collections.abc import Generator, Mapping
 from typing import Any, Union, cast

 from flask import Flask, current_app
-from sqlalchemy import and_, or_, select
+from sqlalchemy import and_, literal, or_, select
 from sqlalchemy.orm import Session

 from core.app.app_config.entities import (
@@ -516,6 +516,9 @@ class DatasetRetrieval:
             ].embedding_model_provider
             weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
         with measure_time() as timer:
+            cancel_event = threading.Event()
+            thread_exceptions: list[Exception] = []
+
             if query:
                 query_thread = threading.Thread(
                     target=self._multiple_retrieve_thread,
@@ -534,6 +537,8 @@ class DatasetRetrieval:
                         "score_threshold": score_threshold,
                         "query": query,
                         "attachment_id": None,
+                        "cancel_event": cancel_event,
+                        "thread_exceptions": thread_exceptions,
                     },
                 )
                 all_threads.append(query_thread)
@@ -557,12 +562,25 @@ class DatasetRetrieval:
                         "score_threshold": score_threshold,
                         "query": None,
                         "attachment_id": attachment_id,
+                        "cancel_event": cancel_event,
+                        "thread_exceptions": thread_exceptions,
                     },
                 )
                 all_threads.append(attachment_thread)
                 attachment_thread.start()
-            for thread in all_threads:
-                thread.join()
+            # Poll threads with short timeout to detect errors quickly (fail-fast)
+            while any(t.is_alive() for t in all_threads):
+                for thread in all_threads:
+                    thread.join(timeout=0.1)
+                    if thread_exceptions:
+                        cancel_event.set()
+                        break
+                if thread_exceptions:
+                    break
+
+            if thread_exceptions:
+                raise thread_exceptions[0]
         self._on_query(query, attachment_ids, dataset_ids, app_id, user_from, user_id)

         if all_documents:
@@ -1036,7 +1054,7 @@ class DatasetRetrieval:
         if automatic_metadata_filters:
             conditions = []
             for sequence, filter in enumerate(automatic_metadata_filters):
-                self._process_metadata_filter_func(
+                self.process_metadata_filter_func(
                     sequence,
                     filter.get("condition"),  # type: ignore
                     filter.get("metadata_name"),  # type: ignore

@@ -1072,7 +1090,7 @@ class DatasetRetrieval:
                             value=expected_value,
                         )
                     )
-            filters = self._process_metadata_filter_func(
+            filters = self.process_metadata_filter_func(
                 sequence,
                 condition.comparison_operator,
                 metadata_name,
@@ -1168,8 +1186,9 @@ class DatasetRetrieval:
                 return None
         return automatic_metadata_filters

-    def _process_metadata_filter_func(
-        self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
+    @classmethod
+    def process_metadata_filter_func(
+        cls, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
     ):
         if value is None and condition not in ("empty", "not empty"):
             return filters
@@ -1218,6 +1237,20 @@ class DatasetRetrieval:

             case "≥" | ">=":
                 filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() >= value)
+            case "in" | "not in":
+                if isinstance(value, str):
+                    value_list = [v.strip() for v in value.split(",") if v.strip()]
+                elif isinstance(value, (list, tuple)):
+                    value_list = [str(v) for v in value if v is not None]
+                else:
+                    value_list = [str(value)] if value is not None else []
+
+                if not value_list:
+                    # `field in []` is False, `field not in []` is True
+                    filters.append(literal(condition == "not in"))
+                else:
+                    op = json_field.in_ if condition == "in" else json_field.notin_
+                    filters.append(op(value_list))
             case _:
                 pass
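Note: the new `in` / `not in` branch normalizes the value into a list of strings and collapses an empty list into a constant predicate, since `x IN ()` is not valid SQL. A self-contained SQLAlchemy sketch of the same idea against a hypothetical table (column and table names are illustrative):

```python
from sqlalchemy import Integer, String, create_engine, literal, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Doc(Base):
    __tablename__ = "docs"
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    author: Mapped[str] = mapped_column(String)


def author_filter(values, negate: bool = False):
    """Build an IN/NOT IN clause; an empty list collapses to a constant boolean."""
    value_list = [str(v) for v in values if v is not None]
    if not value_list:
        # `IN ()` is always false, `NOT IN ()` is always true
        return literal(negate)
    return Doc.author.notin_(value_list) if negate else Doc.author.in_(value_list)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as s:
    s.add_all([Doc(author="alice"), Doc(author="bob")])
    s.commit()
    rows = s.execute(select(Doc).where(author_filter(["alice"]))).scalars().all()
    print([r.author for r in rows])  # ['alice']
```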
@@ -1389,40 +1422,53 @@ class DatasetRetrieval:
         score_threshold: float,
         query: str | None,
         attachment_id: str | None,
+        cancel_event: threading.Event | None = None,
+        thread_exceptions: list[Exception] | None = None,
     ):
-        with flask_app.app_context():
-            threads = []
-            all_documents_item: list[Document] = []
-            index_type = None
-            for dataset in available_datasets:
-                index_type = dataset.indexing_technique
-                document_ids_filter = None
-                if dataset.provider != "external":
-                    if metadata_condition and not metadata_filter_document_ids:
-                        continue
-                    if metadata_filter_document_ids:
-                        document_ids = metadata_filter_document_ids.get(dataset.id, [])
-                        if document_ids:
-                            document_ids_filter = document_ids
-                        else:
-                            continue
-                retrieval_thread = threading.Thread(
-                    target=self._retriever,
-                    kwargs={
-                        "flask_app": flask_app,
-                        "dataset_id": dataset.id,
-                        "query": query,
-                        "top_k": top_k,
-                        "all_documents": all_documents_item,
-                        "document_ids_filter": document_ids_filter,
-                        "metadata_condition": metadata_condition,
-                        "attachment_ids": [attachment_id] if attachment_id else None,
-                    },
-                )
-                threads.append(retrieval_thread)
-                retrieval_thread.start()
-            for thread in threads:
-                thread.join()
+        try:
+            with flask_app.app_context():
+                threads = []
+                all_documents_item: list[Document] = []
+                index_type = None
+                for dataset in available_datasets:
+                    # Check for cancellation signal
+                    if cancel_event and cancel_event.is_set():
+                        break
+                    index_type = dataset.indexing_technique
+                    document_ids_filter = None
+                    if dataset.provider != "external":
+                        if metadata_condition and not metadata_filter_document_ids:
+                            continue
+                        if metadata_filter_document_ids:
+                            document_ids = metadata_filter_document_ids.get(dataset.id, [])
+                            if document_ids:
+                                document_ids_filter = document_ids
+                            else:
+                                continue
+                    retrieval_thread = threading.Thread(
+                        target=self._retriever,
+                        kwargs={
+                            "flask_app": flask_app,
+                            "dataset_id": dataset.id,
+                            "query": query,
+                            "top_k": top_k,
+                            "all_documents": all_documents_item,
+                            "document_ids_filter": document_ids_filter,
+                            "metadata_condition": metadata_condition,
+                            "attachment_ids": [attachment_id] if attachment_id else None,
+                        },
+                    )
+                    threads.append(retrieval_thread)
+                    retrieval_thread.start()
+
+                # Poll threads with short timeout to respond quickly to cancellation
+                while any(t.is_alive() for t in threads):
+                    for thread in threads:
+                        thread.join(timeout=0.1)
+                        if cancel_event and cancel_event.is_set():
+                            break
+                    if cancel_event and cancel_event.is_set():
+                        break

                 if reranking_enable:
                     # do rerank for searched documents

@@ -1455,3 +1501,8 @@ class DatasetRetrieval:
                 all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item
                 if all_documents_item:
                     all_documents.extend(all_documents_item)
+        except Exception as e:
+            if cancel_event:
+                cancel_event.set()
+            if thread_exceptions is not None:
+                thread_exceptions.append(e)
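Note: both retrieval paths now share the same cooperative-cancellation scheme — worker threads append their exception to a shared list and set an `Event`, while the coordinator polls `join(timeout=0.1)` so it stops waiting as soon as something fails. A stripped-down sketch of that coordination (worker bodies are placeholders, not Dify code):

```python
import threading
import time


def worker(n: int, cancel_event: threading.Event, errors: list[Exception]) -> None:
    """Placeholder worker: sleeps in small steps and fails for n == 2."""
    try:
        for _ in range(20):
            if cancel_event.is_set():
                return  # another worker failed; stop early
            time.sleep(0.05)
            if n == 2:
                raise RuntimeError("worker 2 failed")
    except Exception as e:
        errors.append(e)
        cancel_event.set()


cancel_event = threading.Event()
errors: list[Exception] = []
threads = [threading.Thread(target=worker, args=(n, cancel_event, errors)) for n in range(4)]
for t in threads:
    t.start()

# Poll with a short join timeout instead of blocking, so the first error stops the wait quickly.
while any(t.is_alive() for t in threads):
    for t in threads:
        t.join(timeout=0.1)
        if errors:
            break
    if errors:
        break

print("first error:" if errors else "all finished:", errors[:1])
```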
@@ -6,7 +6,15 @@ from typing import Any

 from core.mcp.auth_client import MCPClientWithAuthRetry
 from core.mcp.error import MCPConnectionError
-from core.mcp.types import AudioContent, CallToolResult, ImageContent, TextContent
+from core.mcp.types import (
+    AudioContent,
+    BlobResourceContents,
+    CallToolResult,
+    EmbeddedResource,
+    ImageContent,
+    TextContent,
+    TextResourceContents,
+)
 from core.tools.__base.tool import Tool
 from core.tools.__base.tool_runtime import ToolRuntime
 from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
@@ -53,10 +61,19 @@ class MCPTool(Tool):
         for content in result.content:
             if isinstance(content, TextContent):
                 yield from self._process_text_content(content)
-            elif isinstance(content, ImageContent):
-                yield self._process_image_content(content)
-            elif isinstance(content, AudioContent):
-                yield self._process_audio_content(content)
+            elif isinstance(content, ImageContent | AudioContent):
+                yield self.create_blob_message(
+                    blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType}
+                )
+            elif isinstance(content, EmbeddedResource):
+                resource = content.resource
+                if isinstance(resource, TextResourceContents):
+                    yield self.create_text_message(resource.text)
+                elif isinstance(resource, BlobResourceContents):
+                    mime_type = resource.mimeType or "application/octet-stream"
+                    yield self.create_blob_message(blob=base64.b64decode(resource.blob), meta={"mime_type": mime_type})
+                else:
+                    raise ToolInvokeError(f"Unsupported embedded resource type: {type(resource)}")
             else:
                 logger.warning("Unsupported content type=%s", type(content))
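Note: image and audio content now share a single blob path, and embedded resources are unpacked into text or blob messages. A rough standalone sketch of that dispatch, using plain dataclasses instead of the real `core.mcp.types` models (names here are stand-ins):

```python
import base64
from dataclasses import dataclass


@dataclass
class TextContent:
    text: str


@dataclass
class BinaryContent:  # stand-in for ImageContent / AudioContent
    data: str  # base64-encoded payload
    mimeType: str


def dispatch(content):
    """Turn an MCP-style content item into a (kind, payload, ...) tuple."""
    if isinstance(content, TextContent):
        return ("text", content.text)
    if isinstance(content, BinaryContent):
        return ("blob", base64.b64decode(content.data), content.mimeType)
    raise ValueError(f"Unsupported content type: {type(content)}")


print(dispatch(TextContent("hello")))
print(dispatch(BinaryContent(base64.b64encode(b"\x89PNG").decode(), "image/png")))
```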
@@ -101,14 +118,6 @@ class MCPTool(Tool):
         for item in json_list:
             yield self.create_json_message(item)

-    def _process_image_content(self, content: ImageContent) -> ToolInvokeMessage:
-        """Process image content and return a blob message."""
-        return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType})
-
-    def _process_audio_content(self, content: AudioContent) -> ToolInvokeMessage:
-        """Process audio content and return a blob message."""
-        return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType})
-
     def fork_tool_runtime(self, runtime: ToolRuntime) -> "MCPTool":
         return MCPTool(
             entity=self.entity,
@@ -378,7 +378,7 @@ class ApiBasedToolSchemaParser:
     @staticmethod
     def auto_parse_to_tool_bundle(
         content: str, extra_info: dict | None = None, warning: dict | None = None
-    ) -> tuple[list[ApiToolBundle], str]:
+    ) -> tuple[list[ApiToolBundle], ApiProviderSchemaType]:
         """
         auto parse to tool bundle
@@ -4,6 +4,7 @@ import re
 def remove_leading_symbols(text: str) -> str:
     """
     Remove leading punctuation or symbols from the given text.
+    Preserves markdown links like [text](url) at the start.

     Args:
         text (str): The input text to process.

@@ -11,6 +12,11 @@ def remove_leading_symbols(text: str) -> str:
     Returns:
         str: The text with leading punctuation or symbols removed.
     """
+    # Check if text starts with a markdown link - preserve it
+    markdown_link_pattern = r"^\[([^\]]+)\]\((https?://[^)]+)\)"
+    if re.match(markdown_link_pattern, text):
+        return text
+
     # Match Unicode ranges for punctuation and symbols
     # FIXME this pattern is confused quick fix for #11868 maybe refactor it later
     pattern = r'^[\[\]\u2000-\u2025\u2027-\u206F\u2E00-\u2E7F\u3000-\u300F\u3011-\u303F"#$%&\'()*+,./:;<=>?@^_`~]+'
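Note: the early return keeps strings that begin with a markdown link intact, since the symbol-stripping regex would otherwise eat the leading `[`. A small sketch of the behaviour (link pattern copied from the hunk; the fallback strip is a simplified subset of the real pattern):

```python
import re

MARKDOWN_LINK = r"^\[([^\]]+)\]\((https?://[^)]+)\)"
LEADING_SYMBOLS = r"^[\[\]\"#$%&'()*+,./:;<=>?@^_`~]+"  # simplified subset of the real pattern


def strip_leading_symbols(text: str) -> str:
    if re.match(MARKDOWN_LINK, text):
        return text  # keep markdown links such as [title](https://example.com)
    return re.sub(LEADING_SYMBOLS, "", text)


print(strip_leading_symbols("[Dify](https://dify.ai) docs"))  # unchanged
print(strip_leading_symbols("...hello"))                      # 'hello'
```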
@@ -5,6 +5,7 @@ from sqlalchemy.orm import Session

 from core.app.app_config.entities import VariableEntity, VariableEntityType
 from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
+from core.db.session_factory import session_factory
 from core.plugin.entities.parameters import PluginParameterOption
 from core.tools.__base.tool_provider import ToolProviderController
 from core.tools.__base.tool_runtime import ToolRuntime

@@ -47,33 +48,29 @@ class WorkflowToolProviderController(ToolProviderController):

     @classmethod
     def from_db(cls, db_provider: WorkflowToolProvider) -> "WorkflowToolProviderController":
-        with Session(db.engine, expire_on_commit=False) as session, session.begin():
-            provider = session.get(WorkflowToolProvider, db_provider.id) if db_provider.id else None
-            if not provider:
-                raise ValueError("workflow provider not found")
-            app = session.get(App, provider.app_id)
+        with session_factory.create_session() as session, session.begin():
+            app = session.get(App, db_provider.app_id)
             if not app:
                 raise ValueError("app not found")

-            user = session.get(Account, provider.user_id) if provider.user_id else None
+            user = session.get(Account, db_provider.user_id) if db_provider.user_id else None

             controller = WorkflowToolProviderController(
                 entity=ToolProviderEntity(
                     identity=ToolProviderIdentity(
                         author=user.name if user else "",
-                        name=provider.label,
-                        label=I18nObject(en_US=provider.label, zh_Hans=provider.label),
-                        description=I18nObject(en_US=provider.description, zh_Hans=provider.description),
-                        icon=provider.icon,
+                        name=db_provider.label,
+                        label=I18nObject(en_US=db_provider.label, zh_Hans=db_provider.label),
+                        description=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description),
+                        icon=db_provider.icon,
                     ),
                     credentials_schema=[],
                     plugin_id=None,
                 ),
-                provider_id=provider.id or "",
+                provider_id=db_provider.id,
             )

             controller.tools = [
-                controller._get_db_provider_tool(provider, app, session=session, user=user),
+                controller._get_db_provider_tool(db_provider, app, session=session, user=user),
             ]

             return controller
@@ -60,6 +60,7 @@ class SkipPropagator:
         if edge_states["has_taken"]:
             # Enqueue node
             self._state_manager.enqueue_node(downstream_node_id)
+            self._state_manager.start_execution(downstream_node_id)
             return

         # All edges are skipped, propagate skip to this node
@@ -119,3 +119,14 @@ class AgentVariableTypeError(AgentNodeError):
         self.expected_type = expected_type
         self.actual_type = actual_type
         super().__init__(message)
+
+
+class AgentMaxIterationError(AgentNodeError):
+    """Exception raised when the agent exceeds the maximum iteration limit."""
+
+    def __init__(self, max_iteration: int):
+        self.max_iteration = max_iteration
+        super().__init__(
+            f"Agent exceeded the maximum iteration limit of {max_iteration}. "
+            f"The agent was unable to complete the task within the allowed number of iterations."
+        )
@@ -6,7 +6,7 @@ from collections import defaultdict
 from collections.abc import Mapping, Sequence
 from typing import TYPE_CHECKING, Any, cast

-from sqlalchemy import and_, func, literal, or_, select
+from sqlalchemy import and_, func, or_, select
 from sqlalchemy.orm import sessionmaker

 from core.app.app_config.entities import DatasetRetrieveConfigEntity

@@ -460,7 +460,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
         if automatic_metadata_filters:
             conditions = []
             for sequence, filter in enumerate(automatic_metadata_filters):
-                self._process_metadata_filter_func(
+                DatasetRetrieval.process_metadata_filter_func(
                     sequence,
                     filter.get("condition", ""),
                     filter.get("metadata_name", ""),

@@ -504,7 +504,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
                             value=expected_value,
                         )
                     )
-            filters = self._process_metadata_filter_func(
+            filters = DatasetRetrieval.process_metadata_filter_func(
                 sequence,
                 condition.comparison_operator,
                 metadata_name,

@@ -603,87 +603,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
                 return [], usage
         return automatic_metadata_filters, usage

-    def _process_metadata_filter_func(
-        self, sequence: int, condition: str, metadata_name: str, value: Any, filters: list[Any]
-    ) -> list[Any]:
-        if value is None and condition not in ("empty", "not empty"):
-            return filters
-
-        json_field = Document.doc_metadata[metadata_name].as_string()
-
-        match condition:
-            case "contains":
-                filters.append(json_field.like(f"%{value}%"))
-
-            case "not contains":
-                filters.append(json_field.notlike(f"%{value}%"))
-
-            case "start with":
-                filters.append(json_field.like(f"{value}%"))
-
-            case "end with":
-                filters.append(json_field.like(f"%{value}"))
-            case "in":
-                if isinstance(value, str):
-                    value_list = [v.strip() for v in value.split(",") if v.strip()]
-                elif isinstance(value, (list, tuple)):
-                    value_list = [str(v) for v in value if v is not None]
-                else:
-                    value_list = [str(value)] if value is not None else []
-
-                if not value_list:
-                    filters.append(literal(False))
-                else:
-                    filters.append(json_field.in_(value_list))
-
-            case "not in":
-                if isinstance(value, str):
-                    value_list = [v.strip() for v in value.split(",") if v.strip()]
-                elif isinstance(value, (list, tuple)):
-                    value_list = [str(v) for v in value if v is not None]
-                else:
-                    value_list = [str(value)] if value is not None else []
-
-                if not value_list:
-                    filters.append(literal(True))
-                else:
-                    filters.append(json_field.notin_(value_list))
-
-            case "is" | "=":
-                if isinstance(value, str):
-                    filters.append(json_field == value)
-                elif isinstance(value, (int, float)):
-                    filters.append(Document.doc_metadata[metadata_name].as_float() == value)
-
-            case "is not" | "≠":
-                if isinstance(value, str):
-                    filters.append(json_field != value)
-                elif isinstance(value, (int, float)):
-                    filters.append(Document.doc_metadata[metadata_name].as_float() != value)
-
-            case "empty":
-                filters.append(Document.doc_metadata[metadata_name].is_(None))
-
-            case "not empty":
-                filters.append(Document.doc_metadata[metadata_name].isnot(None))
-
-            case "before" | "<":
-                filters.append(Document.doc_metadata[metadata_name].as_float() < value)
-
-            case "after" | ">":
-                filters.append(Document.doc_metadata[metadata_name].as_float() > value)
-
-            case "≤" | "<=":
-                filters.append(Document.doc_metadata[metadata_name].as_float() <= value)
-
-            case "≥" | ">=":
-                filters.append(Document.doc_metadata[metadata_name].as_float() >= value)
-
-            case _:
-                pass
-
-        return filters
-
     @classmethod
     def _extract_variable_selector_to_variable_mapping(
         cls,
@@ -12,9 +12,8 @@ from dify_app import DifyApp

 def _get_celery_ssl_options() -> dict[str, Any] | None:
     """Get SSL configuration for Celery broker/backend connections."""
-    # Use REDIS_USE_SSL for consistency with the main Redis client
     # Only apply SSL if we're using Redis as broker/backend
-    if not dify_config.REDIS_USE_SSL:
+    if not dify_config.BROKER_USE_SSL:
         return None

     # Check if Celery is actually using Redis
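Note: switching the guard from `REDIS_USE_SSL` to `BROKER_USE_SSL` decouples broker TLS from the main Redis client. A hedged sketch of how such an option dict is typically handed to Celery for a Redis broker (the config names and helper are illustrative, not necessarily Dify's):

```python
import ssl

from celery import Celery


def celery_ssl_options(broker_use_ssl: bool, ca_certs: str | None = None) -> dict | None:
    """Return kwargs for Celery's broker_use_ssl setting, or None when TLS is off."""
    if not broker_use_ssl:
        return None
    return {
        "ssl_cert_reqs": ssl.CERT_REQUIRED if ca_certs else ssl.CERT_NONE,
        "ssl_ca_certs": ca_certs,
    }


app = Celery("worker", broker="rediss://localhost:6380/1")
opts = celery_ssl_options(True, ca_certs=None)
if opts:
    app.conf.broker_use_ssl = opts  # Celery passes these through to the Redis connection
```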
@@ -13,12 +13,20 @@ class TencentCosStorage(BaseStorage):
         super().__init__()

         self.bucket_name = dify_config.TENCENT_COS_BUCKET_NAME
-        config = CosConfig(
-            Region=dify_config.TENCENT_COS_REGION,
-            SecretId=dify_config.TENCENT_COS_SECRET_ID,
-            SecretKey=dify_config.TENCENT_COS_SECRET_KEY,
-            Scheme=dify_config.TENCENT_COS_SCHEME,
-        )
+        if dify_config.TENCENT_COS_CUSTOM_DOMAIN:
+            config = CosConfig(
+                Domain=dify_config.TENCENT_COS_CUSTOM_DOMAIN,
+                SecretId=dify_config.TENCENT_COS_SECRET_ID,
+                SecretKey=dify_config.TENCENT_COS_SECRET_KEY,
+                Scheme=dify_config.TENCENT_COS_SCHEME,
+            )
+        else:
+            config = CosConfig(
+                Region=dify_config.TENCENT_COS_REGION,
+                SecretId=dify_config.TENCENT_COS_SECRET_ID,
+                SecretKey=dify_config.TENCENT_COS_SECRET_KEY,
+                Scheme=dify_config.TENCENT_COS_SCHEME,
+            )
         self.client = CosS3Client(config)

     def save(self, filename, data):
@@ -1,4 +1,4 @@
-from flask_restx import Api, Namespace, fields
+from flask_restx import Namespace, fields

 from libs.helper import TimestampField

@@ -12,7 +12,7 @@ annotation_fields = {
 }


-def build_annotation_model(api_or_ns: Api | Namespace):
+def build_annotation_model(api_or_ns: Namespace):
     """Build the annotation model for the API or Namespace."""
     return api_or_ns.model("Annotation", annotation_fields)
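Note: narrowing all of these builders from `Api | Namespace` to `Namespace` (here and in the field modules below) matches how they are actually used — each namespace registers its own copy of the model. A minimal flask-restx sketch of that pattern (field names are illustrative):

```python
from flask import Flask
from flask_restx import Api, Namespace, Resource, fields

annotation_fields = {
    "id": fields.String,
    "question": fields.String,
    "answer": fields.String,
}


def build_annotation_model(api_or_ns: Namespace):
    """Register the Annotation model on the given namespace."""
    return api_or_ns.model("Annotation", annotation_fields)


app = Flask(__name__)
api = Api(app)
ns = Namespace("annotations", path="/annotations")
annotation_model = build_annotation_model(ns)


@ns.route("/<string:annotation_id>")
class AnnotationResource(Resource):
    @ns.marshal_with(annotation_model)
    def get(self, annotation_id):
        return {"id": annotation_id, "question": "q", "answer": "a"}


api.add_namespace(ns)
```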
@@ -1,4 +1,4 @@
-from flask_restx import Api, Namespace, fields
+from flask_restx import Namespace, fields

 from fields.member_fields import simple_account_fields
 from libs.helper import TimestampField

@@ -46,7 +46,7 @@ message_file_fields = {
 }


-def build_message_file_model(api_or_ns: Api | Namespace):
+def build_message_file_model(api_or_ns: Namespace):
     """Build the message file fields for the API or Namespace."""
     return api_or_ns.model("MessageFile", message_file_fields)

@@ -217,7 +217,7 @@ conversation_infinite_scroll_pagination_fields = {
 }


-def build_conversation_infinite_scroll_pagination_model(api_or_ns: Api | Namespace):
+def build_conversation_infinite_scroll_pagination_model(api_or_ns: Namespace):
     """Build the conversation infinite scroll pagination model for the API or Namespace."""
     simple_conversation_model = build_simple_conversation_model(api_or_ns)

@@ -226,11 +226,11 @@ def build_conversation_infinite_scroll_pagination_model(api_or_ns: Api | Namespa
     return api_or_ns.model("ConversationInfiniteScrollPagination", copied_fields)


-def build_conversation_delete_model(api_or_ns: Api | Namespace):
+def build_conversation_delete_model(api_or_ns: Namespace):
     """Build the conversation delete model for the API or Namespace."""
     return api_or_ns.model("ConversationDelete", conversation_delete_fields)


-def build_simple_conversation_model(api_or_ns: Api | Namespace):
+def build_simple_conversation_model(api_or_ns: Namespace):
     """Build the simple conversation model for the API or Namespace."""
     return api_or_ns.model("SimpleConversation", simple_conversation_fields)
@@ -1,4 +1,4 @@
-from flask_restx import Api, Namespace, fields
+from flask_restx import Namespace, fields

 from libs.helper import TimestampField

@@ -29,12 +29,12 @@ conversation_variable_infinite_scroll_pagination_fields = {
 }


-def build_conversation_variable_model(api_or_ns: Api | Namespace):
+def build_conversation_variable_model(api_or_ns: Namespace):
     """Build the conversation variable model for the API or Namespace."""
     return api_or_ns.model("ConversationVariable", conversation_variable_fields)


-def build_conversation_variable_infinite_scroll_pagination_model(api_or_ns: Api | Namespace):
+def build_conversation_variable_infinite_scroll_pagination_model(api_or_ns: Namespace):
     """Build the conversation variable infinite scroll pagination model for the API or Namespace."""
     # Build the nested variable model first
     conversation_variable_model = build_conversation_variable_model(api_or_ns)
@@ -1,4 +1,4 @@
-from flask_restx import Api, Namespace, fields
+from flask_restx import Namespace, fields

 simple_end_user_fields = {
     "id": fields.String,

@@ -8,5 +8,5 @@ simple_end_user_fields = {
 }


-def build_simple_end_user_model(api_or_ns: Api | Namespace):
+def build_simple_end_user_model(api_or_ns: Namespace):
     return api_or_ns.model("SimpleEndUser", simple_end_user_fields)
@@ -1,4 +1,4 @@
-from flask_restx import Api, Namespace, fields
+from flask_restx import Namespace, fields

 from libs.helper import TimestampField

@@ -14,7 +14,7 @@ upload_config_fields = {
 }


-def build_upload_config_model(api_or_ns: Api | Namespace):
+def build_upload_config_model(api_or_ns: Namespace):
     """Build the upload config model for the API or Namespace.

     Args:

@@ -39,7 +39,7 @@ file_fields = {
 }


-def build_file_model(api_or_ns: Api | Namespace):
+def build_file_model(api_or_ns: Namespace):
     """Build the file model for the API or Namespace.

     Args:

@@ -57,7 +57,7 @@ remote_file_info_fields = {
 }


-def build_remote_file_info_model(api_or_ns: Api | Namespace):
+def build_remote_file_info_model(api_or_ns: Namespace):
     """Build the remote file info model for the API or Namespace.

     Args:

@@ -81,7 +81,7 @@ file_fields_with_signed_url = {
 }


-def build_file_with_signed_url_model(api_or_ns: Api | Namespace):
+def build_file_with_signed_url_model(api_or_ns: Namespace):
     """Build the file with signed URL model for the API or Namespace.

     Args:
@@ -1,4 +1,4 @@
-from flask_restx import Api, Namespace, fields
+from flask_restx import Namespace, fields

 from libs.helper import AvatarUrlField, TimestampField

@@ -9,7 +9,7 @@ simple_account_fields = {
 }


-def build_simple_account_model(api_or_ns: Api | Namespace):
+def build_simple_account_model(api_or_ns: Namespace):
     return api_or_ns.model("SimpleAccount", simple_account_fields)
@@ -1,4 +1,4 @@
-from flask_restx import Api, Namespace, fields
+from flask_restx import Namespace, fields

 from fields.conversation_fields import message_file_fields
 from libs.helper import TimestampField

@@ -10,7 +10,7 @@ feedback_fields = {
 }


-def build_feedback_model(api_or_ns: Api | Namespace):
+def build_feedback_model(api_or_ns: Namespace):
     """Build the feedback model for the API or Namespace."""
     return api_or_ns.model("Feedback", feedback_fields)

@@ -30,7 +30,7 @@ agent_thought_fields = {
 }


-def build_agent_thought_model(api_or_ns: Api | Namespace):
+def build_agent_thought_model(api_or_ns: Namespace):
     """Build the agent thought model for the API or Namespace."""
     return api_or_ns.model("AgentThought", agent_thought_fields)
@@ -1,4 +1,4 @@
-from flask_restx import Api, Namespace, fields
+from flask_restx import Namespace, fields

 dataset_tag_fields = {
     "id": fields.String,

@@ -8,5 +8,5 @@ dataset_tag_fields = {
 }


-def build_dataset_tag_fields(api_or_ns: Api | Namespace):
+def build_dataset_tag_fields(api_or_ns: Namespace):
     return api_or_ns.model("DataSetTag", dataset_tag_fields)
@@ -1,4 +1,4 @@
-from flask_restx import Api, Namespace, fields
+from flask_restx import Namespace, fields

 from fields.end_user_fields import build_simple_end_user_model, simple_end_user_fields
 from fields.member_fields import build_simple_account_model, simple_account_fields

@@ -17,7 +17,7 @@ workflow_app_log_partial_fields = {
 }


-def build_workflow_app_log_partial_model(api_or_ns: Api | Namespace):
+def build_workflow_app_log_partial_model(api_or_ns: Namespace):
     """Build the workflow app log partial model for the API or Namespace."""
     workflow_run_model = build_workflow_run_for_log_model(api_or_ns)
     simple_account_model = build_simple_account_model(api_or_ns)

@@ -43,7 +43,7 @@ workflow_app_log_pagination_fields = {
 }


-def build_workflow_app_log_pagination_model(api_or_ns: Api | Namespace):
+def build_workflow_app_log_pagination_model(api_or_ns: Namespace):
     """Build the workflow app log pagination model for the API or Namespace."""
     # Build the nested partial model first
     workflow_app_log_partial_model = build_workflow_app_log_partial_model(api_or_ns)
@@ -1,4 +1,4 @@
-from flask_restx import Api, Namespace, fields
+from flask_restx import Namespace, fields

 from fields.end_user_fields import simple_end_user_fields
 from fields.member_fields import simple_account_fields

@@ -19,7 +19,7 @@ workflow_run_for_log_fields = {
 }


-def build_workflow_run_for_log_model(api_or_ns: Api | Namespace):
+def build_workflow_run_for_log_model(api_or_ns: Namespace):
     return api_or_ns.model("WorkflowRunForLog", workflow_run_for_log_fields)
@ -0,0 +1,347 @@
|
||||||
|
"""
|
||||||
|
Archive Storage Client for S3-compatible storage.
|
||||||
|
|
||||||
|
This module provides a dedicated storage client for archiving or exporting logs
|
||||||
|
to S3-compatible object storage.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import datetime
|
||||||
|
import gzip
|
||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
from collections.abc import Generator
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
import orjson
|
||||||
|
from botocore.client import Config
|
||||||
|
from botocore.exceptions import ClientError
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ArchiveStorageError(Exception):
|
||||||
|
"""Base exception for archive storage operations."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ArchiveStorageNotConfiguredError(ArchiveStorageError):
|
||||||
|
"""Raised when archive storage is not properly configured."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ArchiveStorage:
|
||||||
|
"""
|
||||||
|
S3-compatible storage client for archiving or exporting.
|
||||||
|
|
||||||
|
This client provides methods for storing and retrieving archived data in JSONL+gzip format.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, bucket: str):
|
||||||
|
if not dify_config.ARCHIVE_STORAGE_ENABLED:
|
||||||
|
raise ArchiveStorageNotConfiguredError("Archive storage is not enabled")
|
||||||
|
|
||||||
|
if not bucket:
|
||||||
|
raise ArchiveStorageNotConfiguredError("Archive storage bucket is not configured")
|
||||||
|
if not all(
|
||||||
|
[
|
||||||
|
dify_config.ARCHIVE_STORAGE_ENDPOINT,
|
||||||
|
bucket,
|
||||||
|
dify_config.ARCHIVE_STORAGE_ACCESS_KEY,
|
||||||
|
dify_config.ARCHIVE_STORAGE_SECRET_KEY,
|
||||||
|
]
|
||||||
|
):
|
||||||
|
raise ArchiveStorageNotConfiguredError(
|
||||||
|
"Archive storage configuration is incomplete. "
|
||||||
|
"Required: ARCHIVE_STORAGE_ENDPOINT, ARCHIVE_STORAGE_ACCESS_KEY, "
|
||||||
|
"ARCHIVE_STORAGE_SECRET_KEY, and a bucket name"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.bucket = bucket
|
||||||
|
self.client = boto3.client(
|
||||||
|
"s3",
|
||||||
|
endpoint_url=dify_config.ARCHIVE_STORAGE_ENDPOINT,
|
||||||
|
aws_access_key_id=dify_config.ARCHIVE_STORAGE_ACCESS_KEY,
|
||||||
|
aws_secret_access_key=dify_config.ARCHIVE_STORAGE_SECRET_KEY,
|
||||||
|
region_name=dify_config.ARCHIVE_STORAGE_REGION,
|
||||||
|
config=Config(s3={"addressing_style": "path"}),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify bucket accessibility
|
||||||
|
try:
|
||||||
|
self.client.head_bucket(Bucket=self.bucket)
|
||||||
|
except ClientError as e:
|
||||||
|
error_code = e.response.get("Error", {}).get("Code")
|
||||||
|
if error_code == "404":
|
||||||
|
raise ArchiveStorageNotConfiguredError(f"Archive bucket '{self.bucket}' does not exist")
|
||||||
|
elif error_code == "403":
|
||||||
|
raise ArchiveStorageNotConfiguredError(f"Access denied to archive bucket '{self.bucket}'")
|
||||||
|
else:
|
||||||
|
raise ArchiveStorageError(f"Failed to access archive bucket: {e}")
|
||||||
|
|
||||||
|
def put_object(self, key: str, data: bytes) -> str:
|
||||||
|
"""
|
||||||
|
Upload an object to the archive storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Object key (path) within the bucket
|
||||||
|
data: Binary data to upload
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MD5 checksum of the uploaded data
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ArchiveStorageError: If upload fails
|
||||||
|
"""
|
||||||
|
checksum = hashlib.md5(data).hexdigest()
|
||||||
|
try:
|
||||||
|
self.client.put_object(
|
||||||
|
Bucket=self.bucket,
|
||||||
|
Key=key,
|
||||||
|
Body=data,
|
||||||
|
ContentMD5=self._content_md5(data),
|
||||||
|
)
|
||||||
|
logger.debug("Uploaded object: %s (size=%d, checksum=%s)", key, len(data), checksum)
|
||||||
|
return checksum
|
||||||
|
except ClientError as e:
|
||||||
|
raise ArchiveStorageError(f"Failed to upload object '{key}': {e}")
|
||||||
|
|
||||||
|
def get_object(self, key: str) -> bytes:
|
||||||
|
"""
|
||||||
|
Download an object from the archive storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Object key (path) within the bucket
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Binary data of the object
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ArchiveStorageError: If download fails
|
||||||
|
FileNotFoundError: If object does not exist
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
response = self.client.get_object(Bucket=self.bucket, Key=key)
|
||||||
|
return response["Body"].read()
|
||||||
|
except ClientError as e:
|
||||||
|
error_code = e.response.get("Error", {}).get("Code")
|
||||||
|
if error_code == "NoSuchKey":
|
||||||
|
raise FileNotFoundError(f"Archive object not found: {key}")
|
||||||
|
raise ArchiveStorageError(f"Failed to download object '{key}': {e}")
|
||||||
|
|
||||||
|
def get_object_stream(self, key: str) -> Generator[bytes, None, None]:
|
||||||
|
"""
|
||||||
|
Stream an object from the archive storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Object key (path) within the bucket
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Chunks of binary data
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ArchiveStorageError: If download fails
|
||||||
|
FileNotFoundError: If object does not exist
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
response = self.client.get_object(Bucket=self.bucket, Key=key)
|
||||||
|
yield from response["Body"].iter_chunks()
|
||||||
|
except ClientError as e:
|
||||||
|
error_code = e.response.get("Error", {}).get("Code")
|
||||||
|
if error_code == "NoSuchKey":
|
||||||
|
raise FileNotFoundError(f"Archive object not found: {key}")
|
||||||
|
raise ArchiveStorageError(f"Failed to stream object '{key}': {e}")
|
||||||
|
|
||||||
|
def object_exists(self, key: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if an object exists in the archive storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Object key (path) within the bucket
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if object exists, False otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self.client.head_object(Bucket=self.bucket, Key=key)
|
||||||
|
return True
|
||||||
|
except ClientError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def delete_object(self, key: str) -> None:
|
||||||
|
"""
|
||||||
|
Delete an object from the archive storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Object key (path) within the bucket
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ArchiveStorageError: If deletion fails
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self.client.delete_object(Bucket=self.bucket, Key=key)
|
||||||
|
logger.debug("Deleted object: %s", key)
|
||||||
|
except ClientError as e:
|
||||||
|
raise ArchiveStorageError(f"Failed to delete object '{key}': {e}")
|
||||||
|
|
||||||
|
def generate_presigned_url(self, key: str, expires_in: int = 3600) -> str:
|
||||||
|
"""
|
||||||
|
Generate a pre-signed URL for downloading an object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Object key (path) within the bucket
|
||||||
|
expires_in: URL validity duration in seconds (default: 1 hour)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Pre-signed URL string.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ArchiveStorageError: If generation fails
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return self.client.generate_presigned_url(
|
||||||
|
ClientMethod="get_object",
|
||||||
|
                Params={"Bucket": self.bucket, "Key": key},
                ExpiresIn=expires_in,
            )
        except ClientError as e:
            raise ArchiveStorageError(f"Failed to generate pre-signed URL for '{key}': {e}")

    def list_objects(self, prefix: str) -> list[str]:
        """
        List objects under a given prefix.

        Args:
            prefix: Object key prefix to filter by

        Returns:
            List of object keys matching the prefix
        """
        keys = []
        paginator = self.client.get_paginator("list_objects_v2")

        try:
            for page in paginator.paginate(Bucket=self.bucket, Prefix=prefix):
                for obj in page.get("Contents", []):
                    keys.append(obj["Key"])
        except ClientError as e:
            raise ArchiveStorageError(f"Failed to list objects with prefix '{prefix}': {e}")

        return keys

    @staticmethod
    def _content_md5(data: bytes) -> str:
        """Calculate base64-encoded MD5 for Content-MD5 header."""
        return base64.b64encode(hashlib.md5(data).digest()).decode()

    @staticmethod
    def serialize_to_jsonl_gz(records: list[dict[str, Any]]) -> bytes:
        """
        Serialize records to gzipped JSONL format.

        Args:
            records: List of dictionaries to serialize

        Returns:
            Gzipped JSONL bytes
        """
        lines = []
        for record in records:
            # Convert datetime objects to ISO format strings
            serialized = ArchiveStorage._serialize_record(record)
            lines.append(orjson.dumps(serialized))

        jsonl_content = b"\n".join(lines)
        if jsonl_content:
            jsonl_content += b"\n"

        return gzip.compress(jsonl_content)

    @staticmethod
    def deserialize_from_jsonl_gz(data: bytes) -> list[dict[str, Any]]:
        """
        Deserialize gzipped JSONL data to records.

        Args:
            data: Gzipped JSONL bytes

        Returns:
            List of dictionaries
        """
        jsonl_content = gzip.decompress(data)
        records = []

        for line in jsonl_content.splitlines():
            if line:
                records.append(orjson.loads(line))

        return records

    @staticmethod
    def _serialize_record(record: dict[str, Any]) -> dict[str, Any]:
        """Serialize a single record, converting special types."""

        def _serialize(item: Any) -> Any:
            if isinstance(item, datetime.datetime):
                return item.isoformat()
            if isinstance(item, dict):
                return {key: _serialize(value) for key, value in item.items()}
            if isinstance(item, list):
                return [_serialize(value) for value in item]
            return item

        return cast(dict[str, Any], _serialize(record))

    @staticmethod
    def compute_checksum(data: bytes) -> str:
        """Compute MD5 checksum of data."""
        return hashlib.md5(data).hexdigest()


# Singleton instance (lazy initialization)
_archive_storage: ArchiveStorage | None = None
_export_storage: ArchiveStorage | None = None


def get_archive_storage() -> ArchiveStorage:
    """
    Get the archive storage singleton instance.

    Returns:
        ArchiveStorage instance

    Raises:
        ArchiveStorageNotConfiguredError: If archive storage is not configured
    """
    global _archive_storage
    if _archive_storage is None:
        archive_bucket = dify_config.ARCHIVE_STORAGE_ARCHIVE_BUCKET
        if not archive_bucket:
            raise ArchiveStorageNotConfiguredError(
                "Archive storage bucket is not configured. Required: ARCHIVE_STORAGE_ARCHIVE_BUCKET"
            )
        _archive_storage = ArchiveStorage(bucket=archive_bucket)
    return _archive_storage


def get_export_storage() -> ArchiveStorage:
    """
    Get the export storage singleton instance.

    Returns:
        ArchiveStorage instance
    """
    global _export_storage
    if _export_storage is None:
        export_bucket = dify_config.ARCHIVE_STORAGE_EXPORT_BUCKET
        if not export_bucket:
            raise ArchiveStorageNotConfiguredError(
                "Archive export bucket is not configured. Required: ARCHIVE_STORAGE_EXPORT_BUCKET"
            )
        _export_storage = ArchiveStorage(bucket=export_bucket)
    return _export_storage
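For reference, a minimal, self-contained sketch of the gzipped-JSONL archive format handled by serialize_to_jsonl_gz / deserialize_from_jsonl_gz and compute_checksum above; it uses the standard-library json module instead of orjson, and the sample records are hypothetical.

import gzip
import hashlib
import json

records = [{"id": "run-1", "created_at": "2025-01-01T00:00:00"}, {"id": "run-2"}]

# Serialize: one JSON object per line, a trailing newline, then gzip
# (mirrors serialize_to_jsonl_gz, which uses orjson for the same layout).
payload = gzip.compress(b"\n".join(json.dumps(r).encode("utf-8") for r in records) + b"\n")

# Integrity check mirrors compute_checksum: hex MD5 of the compressed bytes.
checksum = hashlib.md5(payload).hexdigest()
print(checksum)

# Deserialize mirrors deserialize_from_jsonl_gz: decompress, split lines, parse each.
restored = [json.loads(line) for line in gzip.decompress(payload).splitlines() if line]
assert restored == records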
@@ -8,7 +8,7 @@ from uuid import uuid4
 import sqlalchemy as sa
 from flask_login import UserMixin
 from sqlalchemy import DateTime, String, func, select
-from sqlalchemy.orm import Mapped, Session, mapped_column
+from sqlalchemy.orm import Mapped, Session, mapped_column, validates
 from typing_extensions import deprecated

 from .base import TypeBase

@@ -116,6 +116,12 @@ class Account(UserMixin, TypeBase):
     role: TenantAccountRole | None = field(default=None, init=False)
     _current_tenant: "Tenant | None" = field(default=None, init=False)

+    @validates("status")
+    def _normalize_status(self, _key: str, value: str | AccountStatus) -> str:
+        if isinstance(value, AccountStatus):
+            return value.value
+        return value
+
     @property
     def is_password_set(self):
         return self.password is not None
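The @validates("status") hook added above coerces an AccountStatus member to its plain string value before it is stored. A standalone sketch of the same SQLAlchemy pattern, using a hypothetical model and enum rather than Dify's actual classes:

import enum

from sqlalchemy import String
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, validates


class Status(str, enum.Enum):
    ACTIVE = "active"
    BANNED = "banned"


class Base(DeclarativeBase):
    pass


class User(Base):
    __tablename__ = "users"

    id: Mapped[int] = mapped_column(primary_key=True)
    status: Mapped[str] = mapped_column(String(16), default=Status.ACTIVE.value)

    @validates("status")
    def _normalize_status(self, _key: str, value: "str | Status") -> str:
        # Accept either the enum member or its raw string value.
        if isinstance(value, Status):
            return value.value
        return value


user = User(status=Status.BANNED)
assert user.status == "banned"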
@@ -16,6 +16,11 @@ celery_redis = Redis(
     port=redis_config.get("port") or 6379,
     password=redis_config.get("password") or None,
     db=int(redis_config.get("virtual_host")) if redis_config.get("virtual_host") else 1,
+    ssl=bool(dify_config.BROKER_USE_SSL),
+    ssl_ca_certs=dify_config.REDIS_SSL_CA_CERTS if dify_config.BROKER_USE_SSL else None,
+    ssl_cert_reqs=getattr(dify_config, "REDIS_SSL_CERT_REQS", None) if dify_config.BROKER_USE_SSL else None,
+    ssl_certfile=getattr(dify_config, "REDIS_SSL_CERTFILE", None) if dify_config.BROKER_USE_SSL else None,
+    ssl_keyfile=getattr(dify_config, "REDIS_SSL_KEYFILE", None) if dify_config.BROKER_USE_SSL else None,
 )

 logger = logging.getLogger(__name__)
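For reference, a sketch of what the added options correspond to on a plain redis-py client; the hostname and certificate paths are placeholders, and in Dify the real values come from dify_config:

import ssl

from redis import Redis

# Placeholder values; in Dify these are read from dify_config / the broker settings.
celery_redis = Redis(
    host="redis.internal.example",
    port=6379,
    password=None,
    db=1,
    ssl=True,
    ssl_ca_certs="/etc/ssl/certs/ca.pem",
    ssl_cert_reqs=ssl.CERT_REQUIRED,
    ssl_certfile="/etc/ssl/certs/client.crt",
    ssl_keyfile="/etc/ssl/private/client.key",
)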
@@ -14,7 +14,8 @@ from enums.quota_type import QuotaType, unlimited
 from extensions.otel import AppGenerateHandler, trace_span
 from models.model import Account, App, AppMode, EndUser
 from models.workflow import Workflow
-from services.errors.app import InvokeRateLimitError, QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError
+from services.errors.app import QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError
+from services.errors.llm import InvokeRateLimitError
 from services.workflow_service import WorkflowService
@@ -21,7 +21,7 @@ from models.model import App, EndUser
 from models.trigger import WorkflowTriggerLog
 from models.workflow import Workflow
 from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
-from services.errors.app import InvokeRateLimitError, QuotaExceededError, WorkflowNotFoundError
+from services.errors.app import QuotaExceededError, WorkflowNotFoundError, WorkflowQuotaLimitError
 from services.workflow.entities import AsyncTriggerResponse, TriggerData, WorkflowTaskData
 from services.workflow.queue_dispatcher import QueueDispatcherManager, QueuePriority
 from services.workflow_service import WorkflowService

@@ -141,7 +141,7 @@ class AsyncWorkflowService:
             trigger_log_repo.update(trigger_log)
             session.commit()

-            raise InvokeRateLimitError(
+            raise WorkflowQuotaLimitError(
                 f"Workflow execution quota limit reached for tenant {trigger_data.tenant_id}"
             ) from e
@@ -1,3 +1,4 @@
+import json
 import logging
 import os
 from collections.abc import Sequence

@@ -31,6 +32,11 @@ class BillingService:

     compliance_download_rate_limiter = RateLimiter("compliance_download_rate_limiter", 4, 60)

+    # Redis key prefix for tenant plan cache
+    _PLAN_CACHE_KEY_PREFIX = "tenant_plan:"
+    # Cache TTL: 10 minutes
+    _PLAN_CACHE_TTL = 600
+
     @classmethod
     def get_info(cls, tenant_id: str):
         params = {"tenant_id": tenant_id}

@@ -272,14 +278,110 @@ class BillingService:
                data = resp.get("data", {})

                for tenant_id, plan in data.items():
-                   subscription_plan = subscription_adapter.validate_python(plan)
-                   results[tenant_id] = subscription_plan
+                   try:
+                       subscription_plan = subscription_adapter.validate_python(plan)
+                       results[tenant_id] = subscription_plan
+                   except Exception:
+                       logger.exception(
+                           "get_plan_bulk: failed to validate subscription plan for tenant(%s)", tenant_id
+                       )
+                       continue
            except Exception:
-               logger.exception("Failed to fetch billing info batch for tenants: %s", chunk)
+               logger.exception("get_plan_bulk: failed to fetch billing info batch for tenants: %s", chunk)
                continue

        return results

+   @classmethod
+   def _make_plan_cache_key(cls, tenant_id: str) -> str:
+       return f"{cls._PLAN_CACHE_KEY_PREFIX}{tenant_id}"
+
+   @classmethod
+   def get_plan_bulk_with_cache(cls, tenant_ids: Sequence[str]) -> dict[str, SubscriptionPlan]:
+       """
+       Bulk fetch billing subscription plans with a cache to reduce billing API load in batch job scenarios.
+
+       NOTE: if you need high data consistency, use get_plan_bulk instead.
+
+       Returns:
+           Mapping of tenant_id -> {plan: str, expiration_date: int}
+       """
+       tenant_plans: dict[str, SubscriptionPlan] = {}
+
+       if not tenant_ids:
+           return tenant_plans
+
+       subscription_adapter = TypeAdapter(SubscriptionPlan)
+
+       # Step 1: Batch fetch from Redis cache using mget
+       redis_keys = [cls._make_plan_cache_key(tenant_id) for tenant_id in tenant_ids]
+       try:
+           cached_values = redis_client.mget(redis_keys)
+
+           if len(cached_values) != len(tenant_ids):
+               raise Exception(
+                   "get_plan_bulk_with_cache: unexpected error: redis mget failed: cached values length mismatch"
+               )
+
+           # Map cached values back to tenant_ids
+           cache_misses: list[str] = []
+
+           for tenant_id, cached_value in zip(tenant_ids, cached_values):
+               if cached_value:
+                   try:
+                       # Redis returns bytes, decode to string and parse JSON
+                       json_str = cached_value.decode("utf-8") if isinstance(cached_value, bytes) else cached_value
+                       plan_dict = json.loads(json_str)
+                       subscription_plan = subscription_adapter.validate_python(plan_dict)
+                       tenant_plans[tenant_id] = subscription_plan
+                   except Exception:
+                       logger.exception(
+                           "get_plan_bulk_with_cache: process tenant(%s) failed, add to cache misses", tenant_id
+                       )
+                       cache_misses.append(tenant_id)
+               else:
+                   cache_misses.append(tenant_id)
+
+           logger.info(
+               "get_plan_bulk_with_cache: cache hits=%s, cache misses=%s",
+               len(tenant_plans),
+               len(cache_misses),
+           )
+       except Exception:
+           logger.exception("get_plan_bulk_with_cache: redis mget failed, falling back to API")
+           cache_misses = list(tenant_ids)
+
+       # Step 2: Fetch missing plans from billing API
+       if cache_misses:
+           bulk_plans = BillingService.get_plan_bulk(cache_misses)
+
+           if bulk_plans:
+               plans_to_cache: dict[str, SubscriptionPlan] = {}
+
+               for tenant_id, subscription_plan in bulk_plans.items():
+                   tenant_plans[tenant_id] = subscription_plan
+                   plans_to_cache[tenant_id] = subscription_plan
+
+               # Step 3: Batch update Redis cache using pipeline
+               if plans_to_cache:
+                   try:
+                       pipe = redis_client.pipeline()
+                       for tenant_id, subscription_plan in plans_to_cache.items():
+                           redis_key = cls._make_plan_cache_key(tenant_id)
+                           # Serialize dict to JSON string
+                           json_str = json.dumps(subscription_plan)
+                           pipe.setex(redis_key, cls._PLAN_CACHE_TTL, json_str)
+                       pipe.execute()
+
+                       logger.info(
+                           "get_plan_bulk_with_cache: cached %s new tenant plans to Redis",
+                           len(plans_to_cache),
+                       )
+                   except Exception:
+                       logger.exception("get_plan_bulk_with_cache: redis pipeline failed")
+
+       return tenant_plans
+
    @classmethod
    def get_expired_subscription_cleanup_whitelist(cls) -> Sequence[str]:
        resp = cls._send_request("GET", "/subscription/cleanup/whitelist")
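A short usage sketch of the cached bulk lookup added above, as a batch job might call it; the tenant IDs are hypothetical and the call assumes a configured Flask app context and Redis, as in the integration tests:

from services.billing_service import BillingService

tenant_ids = ["tenant-a", "tenant-b", "tenant-c"]  # hypothetical IDs
plans = BillingService.get_plan_bulk_with_cache(tenant_ids)

for tenant_id, plan in plans.items():
    # Each value validates against SubscriptionPlan, i.e. a plan name plus expiration_date.
    print(tenant_id, plan)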
@@ -110,5 +110,5 @@ class EnterpriseService:
        if not app_id:
            raise ValueError("app_id must be provided.")

-       body = {"appId": app_id}
-       EnterpriseRequest.send_request("DELETE", "/webapp/clean", json=body)
+       params = {"appId": app_id}
+       EnterpriseRequest.send_request("DELETE", "/webapp/clean", params=params)
@@ -18,8 +18,8 @@ class WorkflowIdFormatError(Exception):
     pass


-class InvokeRateLimitError(Exception):
-    """Raised when rate limit is exceeded for workflow invocations."""
+class WorkflowQuotaLimitError(Exception):
+    """Raised when workflow execution quota is exceeded (for async/background workflows)."""

     pass
@@ -146,7 +146,7 @@ class PluginParameterService:
                provider,
                action,
                resolved_credentials,
-               CredentialType.API_KEY.value,
+               original_subscription.credential_type or CredentialType.UNAUTHORIZED.value,
                parameter,
            )
            .options
@ -7,7 +7,6 @@ from httpx import get
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
from core.entities.provider_entities import ProviderConfig
|
from core.entities.provider_entities import ProviderConfig
|
||||||
from core.helper.tool_provider_cache import ToolProviderListCache
|
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.tools.__base.tool_runtime import ToolRuntime
|
from core.tools.__base.tool_runtime import ToolRuntime
|
||||||
from core.tools.custom_tool.provider import ApiToolProviderController
|
from core.tools.custom_tool.provider import ApiToolProviderController
|
||||||
|
|
@ -86,7 +85,9 @@ class ApiToolManageService:
|
||||||
raise ValueError(f"invalid schema: {str(e)}")
|
raise ValueError(f"invalid schema: {str(e)}")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def convert_schema_to_tool_bundles(schema: str, extra_info: dict | None = None) -> tuple[list[ApiToolBundle], str]:
|
def convert_schema_to_tool_bundles(
|
||||||
|
schema: str, extra_info: dict | None = None
|
||||||
|
) -> tuple[list[ApiToolBundle], ApiProviderSchemaType]:
|
||||||
"""
|
"""
|
||||||
convert schema to tool bundles
|
convert schema to tool bundles
|
||||||
|
|
||||||
|
|
@ -104,7 +105,7 @@ class ApiToolManageService:
|
||||||
provider_name: str,
|
provider_name: str,
|
||||||
icon: dict,
|
icon: dict,
|
||||||
credentials: dict,
|
credentials: dict,
|
||||||
schema_type: str,
|
schema_type: ApiProviderSchemaType,
|
||||||
schema: str,
|
schema: str,
|
||||||
privacy_policy: str,
|
privacy_policy: str,
|
||||||
custom_disclaimer: str,
|
custom_disclaimer: str,
|
||||||
|
|
@ -113,9 +114,6 @@ class ApiToolManageService:
|
||||||
"""
|
"""
|
||||||
create api tool provider
|
create api tool provider
|
||||||
"""
|
"""
|
||||||
if schema_type not in [member.value for member in ApiProviderSchemaType]:
|
|
||||||
raise ValueError(f"invalid schema type {schema}")
|
|
||||||
|
|
||||||
provider_name = provider_name.strip()
|
provider_name = provider_name.strip()
|
||||||
|
|
||||||
# check if the provider exists
|
# check if the provider exists
|
||||||
|
|
@ -178,9 +176,6 @@ class ApiToolManageService:
|
||||||
# update labels
|
# update labels
|
||||||
ToolLabelManager.update_tool_labels(provider_controller, labels)
|
ToolLabelManager.update_tool_labels(provider_controller, labels)
|
||||||
|
|
||||||
# Invalidate tool providers cache
|
|
||||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
|
||||||
|
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
@ -245,18 +240,15 @@ class ApiToolManageService:
|
||||||
original_provider: str,
|
original_provider: str,
|
||||||
icon: dict,
|
icon: dict,
|
||||||
credentials: dict,
|
credentials: dict,
|
||||||
schema_type: str,
|
_schema_type: ApiProviderSchemaType,
|
||||||
schema: str,
|
schema: str,
|
||||||
privacy_policy: str,
|
privacy_policy: str | None,
|
||||||
custom_disclaimer: str,
|
custom_disclaimer: str,
|
||||||
labels: list[str],
|
labels: list[str],
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
update api tool provider
|
update api tool provider
|
||||||
"""
|
"""
|
||||||
if schema_type not in [member.value for member in ApiProviderSchemaType]:
|
|
||||||
raise ValueError(f"invalid schema type {schema}")
|
|
||||||
|
|
||||||
provider_name = provider_name.strip()
|
provider_name = provider_name.strip()
|
||||||
|
|
||||||
# check if the provider exists
|
# check if the provider exists
|
||||||
|
|
@ -281,7 +273,7 @@ class ApiToolManageService:
|
||||||
provider.icon = json.dumps(icon)
|
provider.icon = json.dumps(icon)
|
||||||
provider.schema = schema
|
provider.schema = schema
|
||||||
provider.description = extra_info.get("description", "")
|
provider.description = extra_info.get("description", "")
|
||||||
provider.schema_type_str = ApiProviderSchemaType.OPENAPI
|
provider.schema_type_str = schema_type
|
||||||
provider.tools_str = json.dumps(jsonable_encoder(tool_bundles))
|
provider.tools_str = json.dumps(jsonable_encoder(tool_bundles))
|
||||||
provider.privacy_policy = privacy_policy
|
provider.privacy_policy = privacy_policy
|
||||||
provider.custom_disclaimer = custom_disclaimer
|
provider.custom_disclaimer = custom_disclaimer
|
||||||
|
|
@ -322,9 +314,6 @@ class ApiToolManageService:
|
||||||
# update labels
|
# update labels
|
||||||
ToolLabelManager.update_tool_labels(provider_controller, labels)
|
ToolLabelManager.update_tool_labels(provider_controller, labels)
|
||||||
|
|
||||||
# Invalidate tool providers cache
|
|
||||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
|
||||||
|
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
@ -347,9 +336,6 @@ class ApiToolManageService:
|
||||||
db.session.delete(provider)
|
db.session.delete(provider)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
# Invalidate tool providers cache
|
|
||||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
|
||||||
|
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
@ -366,7 +352,7 @@ class ApiToolManageService:
|
||||||
tool_name: str,
|
tool_name: str,
|
||||||
credentials: dict,
|
credentials: dict,
|
||||||
parameters: dict,
|
parameters: dict,
|
||||||
schema_type: str,
|
schema_type: ApiProviderSchemaType,
|
||||||
schema: str,
|
schema: str,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,6 @@ from constants import HIDDEN_VALUE, UNKNOWN_VALUE
|
||||||
from core.helper.name_generator import generate_incremental_name
|
from core.helper.name_generator import generate_incremental_name
|
||||||
from core.helper.position_helper import is_filtered
|
from core.helper.position_helper import is_filtered
|
||||||
from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache
|
from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache
|
||||||
from core.helper.tool_provider_cache import ToolProviderListCache
|
|
||||||
from core.plugin.entities.plugin_daemon import CredentialType
|
from core.plugin.entities.plugin_daemon import CredentialType
|
||||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||||
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
|
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
|
||||||
|
|
@ -205,9 +204,6 @@ class BuiltinToolManageService:
|
||||||
db_provider.name = name
|
db_provider.name = name
|
||||||
|
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
# Invalidate tool providers cache
|
|
||||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
session.rollback()
|
session.rollback()
|
||||||
raise ValueError(str(e))
|
raise ValueError(str(e))
|
||||||
|
|
@ -290,8 +286,6 @@ class BuiltinToolManageService:
|
||||||
session.rollback()
|
session.rollback()
|
||||||
raise ValueError(str(e))
|
raise ValueError(str(e))
|
||||||
|
|
||||||
# Invalidate tool providers cache
|
|
||||||
ToolProviderListCache.invalidate_cache(tenant_id, "builtin")
|
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
@ -409,9 +403,6 @@ class BuiltinToolManageService:
|
||||||
)
|
)
|
||||||
cache.delete()
|
cache.delete()
|
||||||
|
|
||||||
# Invalidate tool providers cache
|
|
||||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
|
||||||
|
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
@ -434,8 +425,6 @@ class BuiltinToolManageService:
|
||||||
target_provider.is_default = True
|
target_provider.is_default = True
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
# Invalidate tool providers cache
|
|
||||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
||||||
|
|
@ -319,8 +319,14 @@ class MCPToolManageService:
|
||||||
except MCPError as e:
|
except MCPError as e:
|
||||||
raise ValueError(f"Failed to connect to MCP server: {e}")
|
raise ValueError(f"Failed to connect to MCP server: {e}")
|
||||||
|
|
||||||
# Update database with retrieved tools
|
# Update database with retrieved tools (ensure description is a non-null string)
|
||||||
db_provider.tools = json.dumps([tool.model_dump() for tool in tools])
|
tools_payload = []
|
||||||
|
for tool in tools:
|
||||||
|
data = tool.model_dump()
|
||||||
|
if data.get("description") is None:
|
||||||
|
data["description"] = ""
|
||||||
|
tools_payload.append(data)
|
||||||
|
db_provider.tools = json.dumps(tools_payload)
|
||||||
db_provider.authed = True
|
db_provider.authed = True
|
||||||
db_provider.updated_at = datetime.now()
|
db_provider.updated_at = datetime.now()
|
||||||
self._session.flush()
|
self._session.flush()
|
||||||
|
|
@ -620,6 +626,21 @@ class MCPToolManageService:
|
||||||
server_url_hash=new_server_url_hash,
|
server_url_hash=new_server_url_hash,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def reconnect_with_url(
|
||||||
|
*,
|
||||||
|
server_url: str,
|
||||||
|
headers: dict[str, str],
|
||||||
|
timeout: float | None,
|
||||||
|
sse_read_timeout: float | None,
|
||||||
|
) -> ReconnectResult:
|
||||||
|
return MCPToolManageService._reconnect_with_url(
|
||||||
|
server_url=server_url,
|
||||||
|
headers=headers,
|
||||||
|
timeout=timeout,
|
||||||
|
sse_read_timeout=sse_read_timeout,
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _reconnect_with_url(
|
def _reconnect_with_url(
|
||||||
*,
|
*,
|
||||||
|
|
@ -642,9 +663,16 @@ class MCPToolManageService:
|
||||||
sse_read_timeout=sse_read_timeout,
|
sse_read_timeout=sse_read_timeout,
|
||||||
) as mcp_client:
|
) as mcp_client:
|
||||||
tools = mcp_client.list_tools()
|
tools = mcp_client.list_tools()
|
||||||
|
# Ensure tool descriptions are non-null in payload
|
||||||
|
tools_payload = []
|
||||||
|
for t in tools:
|
||||||
|
d = t.model_dump()
|
||||||
|
if d.get("description") is None:
|
||||||
|
d["description"] = ""
|
||||||
|
tools_payload.append(d)
|
||||||
return ReconnectResult(
|
return ReconnectResult(
|
||||||
authed=True,
|
authed=True,
|
||||||
tools=json.dumps([tool.model_dump() for tool in tools]),
|
tools=json.dumps(tools_payload),
|
||||||
encrypted_credentials=EMPTY_CREDENTIALS_JSON,
|
encrypted_credentials=EMPTY_CREDENTIALS_JSON,
|
||||||
)
|
)
|
||||||
except MCPAuthError:
|
except MCPAuthError:
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from core.helper.tool_provider_cache import ToolProviderListCache
|
|
||||||
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
|
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
|
||||||
from core.tools.tool_manager import ToolManager
|
from core.tools.tool_manager import ToolManager
|
||||||
from services.tools.tools_transform_service import ToolTransformService
|
from services.tools.tools_transform_service import ToolTransformService
|
||||||
|
|
@ -16,14 +15,6 @@ class ToolCommonService:
|
||||||
|
|
||||||
:return: the list of tool providers
|
:return: the list of tool providers
|
||||||
"""
|
"""
|
||||||
# Try to get from cache first
|
|
||||||
cached_result = ToolProviderListCache.get_cached_providers(tenant_id, typ)
|
|
||||||
if cached_result is not None:
|
|
||||||
logger.debug("Returning cached tool providers for tenant %s, type %s", tenant_id, typ)
|
|
||||||
return cached_result
|
|
||||||
|
|
||||||
# Cache miss - fetch from database
|
|
||||||
logger.debug("Cache miss for tool providers, fetching from database for tenant %s, type %s", tenant_id, typ)
|
|
||||||
providers = ToolManager.list_providers_from_api(user_id, tenant_id, typ)
|
providers = ToolManager.list_providers_from_api(user_id, tenant_id, typ)
|
||||||
|
|
||||||
# add icon
|
# add icon
|
||||||
|
|
@ -32,7 +23,4 @@ class ToolCommonService:
|
||||||
|
|
||||||
result = [provider.to_dict() for provider in providers]
|
result = [provider.to_dict() for provider in providers]
|
||||||
|
|
||||||
# Cache the result
|
|
||||||
ToolProviderListCache.set_cached_providers(tenant_id, typ, result)
|
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,6 @@ from typing import Any
|
||||||
from sqlalchemy import or_, select
|
from sqlalchemy import or_, select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from core.helper.tool_provider_cache import ToolProviderListCache
|
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.tools.__base.tool_provider import ToolProviderController
|
from core.tools.__base.tool_provider import ToolProviderController
|
||||||
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
|
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
|
||||||
|
|
@ -68,34 +67,31 @@ class WorkflowToolManageService:
|
||||||
if workflow is None:
|
if workflow is None:
|
||||||
raise ValueError(f"Workflow not found for app {workflow_app_id}")
|
raise ValueError(f"Workflow not found for app {workflow_app_id}")
|
||||||
|
|
||||||
with Session(db.engine, expire_on_commit=False) as session, session.begin():
|
workflow_tool_provider = WorkflowToolProvider(
|
||||||
workflow_tool_provider = WorkflowToolProvider(
|
tenant_id=tenant_id,
|
||||||
tenant_id=tenant_id,
|
user_id=user_id,
|
||||||
user_id=user_id,
|
app_id=workflow_app_id,
|
||||||
app_id=workflow_app_id,
|
name=name,
|
||||||
name=name,
|
label=label,
|
||||||
label=label,
|
icon=json.dumps(icon),
|
||||||
icon=json.dumps(icon),
|
description=description,
|
||||||
description=description,
|
parameter_configuration=json.dumps(parameters),
|
||||||
parameter_configuration=json.dumps(parameters),
|
privacy_policy=privacy_policy,
|
||||||
privacy_policy=privacy_policy,
|
version=workflow.version,
|
||||||
version=workflow.version,
|
)
|
||||||
)
|
|
||||||
session.add(workflow_tool_provider)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
WorkflowToolProviderController.from_db(workflow_tool_provider)
|
WorkflowToolProviderController.from_db(workflow_tool_provider)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(str(e))
|
raise ValueError(str(e))
|
||||||
|
|
||||||
|
with Session(db.engine, expire_on_commit=False) as session, session.begin():
|
||||||
|
session.add(workflow_tool_provider)
|
||||||
|
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
ToolLabelManager.update_tool_labels(
|
ToolLabelManager.update_tool_labels(
|
||||||
ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels
|
ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels
|
||||||
)
|
)
|
||||||
|
|
||||||
# Invalidate tool providers cache
|
|
||||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
|
||||||
|
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -183,9 +179,6 @@ class WorkflowToolManageService:
|
||||||
ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels
|
ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels
|
||||||
)
|
)
|
||||||
|
|
||||||
# Invalidate tool providers cache
|
|
||||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
|
||||||
|
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -248,9 +241,6 @@ class WorkflowToolManageService:
|
||||||
|
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
# Invalidate tool providers cache
|
|
||||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
|
||||||
|
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
||||||
|
|
@ -868,48 +868,111 @@ class TriggerProviderService:
|
||||||
if not provider_controller:
|
if not provider_controller:
|
||||||
raise ValueError(f"Provider {provider_id} not found")
|
raise ValueError(f"Provider {provider_id} not found")
|
||||||
|
|
||||||
subscription = TriggerProviderService.get_subscription_by_id(
|
# Use distributed lock to prevent race conditions on the same subscription
|
||||||
tenant_id=tenant_id,
|
lock_key = f"trigger_subscription_rebuild_lock:{tenant_id}_{subscription_id}"
|
||||||
subscription_id=subscription_id,
|
with redis_client.lock(lock_key, timeout=20):
|
||||||
)
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
if not subscription:
|
try:
|
||||||
raise ValueError(f"Subscription {subscription_id} not found")
|
# Get subscription within the transaction
|
||||||
|
subscription: TriggerSubscription | None = (
|
||||||
|
session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
|
||||||
|
)
|
||||||
|
if not subscription:
|
||||||
|
raise ValueError(f"Subscription {subscription_id} not found")
|
||||||
|
|
||||||
credential_type = CredentialType.of(subscription.credential_type)
|
credential_type = CredentialType.of(subscription.credential_type)
|
||||||
if credential_type not in [CredentialType.OAUTH2, CredentialType.API_KEY]:
|
if credential_type not in [CredentialType.OAUTH2, CredentialType.API_KEY]:
|
||||||
raise ValueError("Credential type not supported for rebuild")
|
raise ValueError("Credential type not supported for rebuild")
|
||||||
|
|
||||||
# TODO: Trying to invoke update api of the plugin trigger provider
|
# Decrypt existing credentials for merging
|
||||||
|
credential_encrypter, _ = create_trigger_provider_encrypter_for_subscription(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
controller=provider_controller,
|
||||||
|
subscription=subscription,
|
||||||
|
)
|
||||||
|
decrypted_credentials = dict(credential_encrypter.decrypt(subscription.credentials))
|
||||||
|
|
||||||
# FALLBACK: If the update api is not implemented, delete the previous subscription and create a new one
|
# Merge credentials: if caller passed HIDDEN_VALUE, retain existing decrypted value
|
||||||
|
merged_credentials: dict[str, Any] = {
|
||||||
|
key: value if value != HIDDEN_VALUE else decrypted_credentials.get(key, UNKNOWN_VALUE)
|
||||||
|
for key, value in credentials.items()
|
||||||
|
}
|
||||||
|
|
||||||
# Delete the previous subscription
|
user_id = subscription.user_id
|
||||||
user_id = subscription.user_id
|
|
||||||
TriggerManager.unsubscribe_trigger(
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
user_id=user_id,
|
|
||||||
provider_id=provider_id,
|
|
||||||
subscription=subscription.to_entity(),
|
|
||||||
credentials=subscription.credentials,
|
|
||||||
credential_type=credential_type,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create a new subscription with the same subscription_id and endpoint_id
|
# TODO: Trying to invoke update api of the plugin trigger provider
|
||||||
new_subscription: TriggerSubscriptionEntity = TriggerManager.subscribe_trigger(
|
|
||||||
tenant_id=tenant_id,
|
# FALLBACK: If the update api is not implemented,
|
||||||
user_id=user_id,
|
# delete the previous subscription and create a new one
|
||||||
provider_id=provider_id,
|
|
||||||
endpoint=generate_plugin_trigger_endpoint_url(subscription.endpoint_id),
|
# Unsubscribe the previous subscription (external call, but we'll handle errors)
|
||||||
parameters=parameters,
|
try:
|
||||||
credentials=credentials,
|
TriggerManager.unsubscribe_trigger(
|
||||||
credential_type=credential_type,
|
tenant_id=tenant_id,
|
||||||
)
|
user_id=user_id,
|
||||||
TriggerProviderService.update_trigger_subscription(
|
provider_id=provider_id,
|
||||||
tenant_id=tenant_id,
|
subscription=subscription.to_entity(),
|
||||||
subscription_id=subscription.id,
|
credentials=decrypted_credentials,
|
||||||
name=name,
|
credential_type=credential_type,
|
||||||
parameters=parameters,
|
)
|
||||||
credentials=credentials,
|
except Exception as e:
|
||||||
properties=new_subscription.properties,
|
logger.exception("Error unsubscribing trigger during rebuild", exc_info=e)
|
||||||
expires_at=new_subscription.expires_at,
|
# Continue anyway - the subscription might already be deleted externally
|
||||||
)
|
|
||||||
|
# Create a new subscription with the same subscription_id and endpoint_id (external call)
|
||||||
|
new_subscription: TriggerSubscriptionEntity = TriggerManager.subscribe_trigger(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
user_id=user_id,
|
||||||
|
provider_id=provider_id,
|
||||||
|
endpoint=generate_plugin_trigger_endpoint_url(subscription.endpoint_id),
|
||||||
|
parameters=parameters,
|
||||||
|
credentials=merged_credentials,
|
||||||
|
credential_type=credential_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update the subscription in the same transaction
|
||||||
|
# Inline update logic to reuse the same session
|
||||||
|
if name is not None and name != subscription.name:
|
||||||
|
existing = (
|
||||||
|
session.query(TriggerSubscription)
|
||||||
|
.filter_by(tenant_id=tenant_id, provider_id=str(provider_id), name=name)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if existing and existing.id != subscription.id:
|
||||||
|
raise ValueError(f"Subscription name '{name}' already exists for this provider")
|
||||||
|
subscription.name = name
|
||||||
|
|
||||||
|
# Update parameters
|
||||||
|
subscription.parameters = dict(parameters)
|
||||||
|
|
||||||
|
# Update credentials with merged (and encrypted) values
|
||||||
|
subscription.credentials = dict(credential_encrypter.encrypt(merged_credentials))
|
||||||
|
|
||||||
|
# Update properties
|
||||||
|
if new_subscription.properties:
|
||||||
|
properties_encrypter, _ = create_provider_encrypter(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
config=provider_controller.get_properties_schema(),
|
||||||
|
cache=NoOpProviderCredentialCache(),
|
||||||
|
)
|
||||||
|
subscription.properties = dict(properties_encrypter.encrypt(dict(new_subscription.properties)))
|
||||||
|
|
||||||
|
# Update expiration timestamp
|
||||||
|
if new_subscription.expires_at is not None:
|
||||||
|
subscription.expires_at = new_subscription.expires_at
|
||||||
|
|
||||||
|
# Commit the transaction
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
# Clear subscription cache
|
||||||
|
delete_cache_for_subscription(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
provider_id=subscription.provider_id,
|
||||||
|
subscription_id=subscription.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Rollback on any error
|
||||||
|
session.rollback()
|
||||||
|
logger.exception("Failed to rebuild trigger subscription", exc_info=e)
|
||||||
|
raise
|
||||||
|
|
|
||||||
|
|
@@ -863,10 +863,18 @@ class WebhookService:
                not_found_in_cache.append(node_id)
                continue

-       with Session(db.engine) as session:
-           try:
-               # lock the concurrent webhook trigger creation
-               redis_client.lock(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:apps:{app.id}:lock", timeout=10)
+       lock_key = f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:apps:{app.id}:lock"
+       lock = redis_client.lock(lock_key, timeout=10)
+       lock_acquired = False
+
+       try:
+           # acquire the lock with blocking and timeout
+           lock_acquired = lock.acquire(blocking=True, blocking_timeout=10)
+           if not lock_acquired:
+               logger.warning("Failed to acquire lock for webhook sync, app %s", app.id)
+               raise RuntimeError("Failed to acquire lock for webhook trigger synchronization")
+
+           with Session(db.engine) as session:
                # fetch the non-cached nodes from DB
                all_records = session.scalars(
                    select(WorkflowWebhookTrigger).where(

@@ -903,11 +911,16 @@ class WebhookService:
                    session.delete(nodes_id_in_db[node_id])
                    redis_client.delete(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:{app.id}:{node_id}")
                session.commit()
        except Exception:
            logger.exception("Failed to sync webhook relationships for app %s", app.id)
            raise
        finally:
-           redis_client.delete(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:apps:{app.id}:lock")
+           # release the lock only if it was acquired
+           if lock_acquired:
+               try:
+                   lock.release()
+               except Exception:
+                   logger.exception("Failed to release lock for webhook sync, app %s", app.id)

    @classmethod
    def generate_webhook_id(cls) -> str:
|
||||||
|
|
||||||
|
|
||||||
def test_jinja2():
|
def test_jinja2():
|
||||||
|
"""Test basic Jinja2 template rendering."""
|
||||||
template = "Hello {{template}}"
|
template = "Hello {{template}}"
|
||||||
|
# Template must be base64 encoded to match the new safe embedding approach
|
||||||
|
template_b64 = base64.b64encode(template.encode("utf-8")).decode("utf-8")
|
||||||
inputs = base64.b64encode(b'{"template": "World"}').decode("utf-8")
|
inputs = base64.b64encode(b'{"template": "World"}').decode("utf-8")
|
||||||
code = (
|
code = (
|
||||||
Jinja2TemplateTransformer.get_runner_script()
|
Jinja2TemplateTransformer.get_runner_script()
|
||||||
.replace(Jinja2TemplateTransformer._code_placeholder, template)
|
.replace(Jinja2TemplateTransformer._template_b64_placeholder, template_b64)
|
||||||
.replace(Jinja2TemplateTransformer._inputs_placeholder, inputs)
|
.replace(Jinja2TemplateTransformer._inputs_placeholder, inputs)
|
||||||
)
|
)
|
||||||
result = CodeExecutor.execute_code(
|
result = CodeExecutor.execute_code(
|
||||||
|
|
@ -21,6 +24,7 @@ def test_jinja2():
|
||||||
|
|
||||||
|
|
||||||
def test_jinja2_with_code_template():
|
def test_jinja2_with_code_template():
|
||||||
|
"""Test template rendering via the high-level workflow API."""
|
||||||
result = CodeExecutor.execute_workflow_code_template(
|
result = CodeExecutor.execute_workflow_code_template(
|
||||||
language=CODE_LANGUAGE, code="Hello {{template}}", inputs={"template": "World"}
|
language=CODE_LANGUAGE, code="Hello {{template}}", inputs={"template": "World"}
|
||||||
)
|
)
|
||||||
|
|
@ -28,7 +32,64 @@ def test_jinja2_with_code_template():
|
||||||
|
|
||||||
|
|
||||||
def test_jinja2_get_runner_script():
|
def test_jinja2_get_runner_script():
|
||||||
|
"""Test that runner script contains required placeholders."""
|
||||||
runner_script = Jinja2TemplateTransformer.get_runner_script()
|
runner_script = Jinja2TemplateTransformer.get_runner_script()
|
||||||
assert runner_script.count(Jinja2TemplateTransformer._code_placeholder) == 1
|
assert runner_script.count(Jinja2TemplateTransformer._template_b64_placeholder) == 1
|
||||||
assert runner_script.count(Jinja2TemplateTransformer._inputs_placeholder) == 1
|
assert runner_script.count(Jinja2TemplateTransformer._inputs_placeholder) == 1
|
||||||
assert runner_script.count(Jinja2TemplateTransformer._result_tag) == 2
|
assert runner_script.count(Jinja2TemplateTransformer._result_tag) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_jinja2_template_with_special_characters():
|
||||||
|
"""
|
||||||
|
Test that templates with special characters (quotes, newlines) render correctly.
|
||||||
|
This is a regression test for issue #26818 where textarea pre-fill values
|
||||||
|
containing special characters would break template rendering.
|
||||||
|
"""
|
||||||
|
# Template with triple quotes, single quotes, double quotes, and newlines
|
||||||
|
template = """<html>
|
||||||
|
<body>
|
||||||
|
<input value="{{ task.get('Task ID', '') }}"/>
|
||||||
|
<textarea>{{ task.get('Issues', 'No issues reported') }}</textarea>
|
||||||
|
<p>Status: "{{ status }}"</p>
|
||||||
|
<pre>'''code block'''</pre>
|
||||||
|
</body>
|
||||||
|
</html>"""
|
||||||
|
inputs = {"task": {"Task ID": "TASK-123", "Issues": "Line 1\nLine 2\nLine 3"}, "status": "completed"}
|
||||||
|
|
||||||
|
result = CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code=template, inputs=inputs)
|
||||||
|
|
||||||
|
# Verify the template rendered correctly with all special characters
|
||||||
|
output = result["result"]
|
||||||
|
assert 'value="TASK-123"' in output
|
||||||
|
assert "<textarea>Line 1\nLine 2\nLine 3</textarea>" in output
|
||||||
|
assert 'Status: "completed"' in output
|
||||||
|
assert "'''code block'''" in output
|
||||||
|
|
||||||
|
|
||||||
|
def test_jinja2_template_with_html_textarea_prefill():
|
||||||
|
"""
|
||||||
|
Specific test for HTML textarea with Jinja2 variable pre-fill.
|
||||||
|
Verifies fix for issue #26818.
|
||||||
|
"""
|
||||||
|
template = "<textarea name='notes'>{{ notes }}</textarea>"
|
||||||
|
notes_content = "This is a multi-line note.\nWith special chars: 'single' and \"double\" quotes."
|
||||||
|
inputs = {"notes": notes_content}
|
||||||
|
|
||||||
|
result = CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code=template, inputs=inputs)
|
||||||
|
|
||||||
|
expected_output = f"<textarea name='notes'>{notes_content}</textarea>"
|
||||||
|
assert result["result"] == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test_jinja2_assemble_runner_script_encodes_template():
|
||||||
|
"""Test that assemble_runner_script properly base64 encodes the template."""
|
||||||
|
template = "Hello {{ name }}!"
|
||||||
|
inputs = {"name": "World"}
|
||||||
|
|
||||||
|
script = Jinja2TemplateTransformer.assemble_runner_script(template, inputs)
|
||||||
|
|
||||||
|
# The template should be base64 encoded in the script
|
||||||
|
template_b64 = base64.b64encode(template.encode("utf-8")).decode("utf-8")
|
||||||
|
assert template_b64 in script
|
||||||
|
# The raw template should NOT appear in the script (it's encoded)
|
||||||
|
assert "Hello {{ name }}!" not in script
|
||||||
|
|
|
||||||
|
|
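These tests cover the new base64-based template embedding. A self-contained sketch of the idea, independent of Dify's Jinja2TemplateTransformer (placeholder names, assumes jinja2 is installed): the template and inputs are shipped as base64 inside the generated runner script, so quotes and newlines in the template cannot break the script's syntax.

import base64
import json

# A generated "runner script" embeds the template and inputs as base64 text,
# so arbitrary quotes/newlines in the template cannot corrupt the script source.
RUNNER = """
import base64, json
from jinja2 import Template

template = base64.b64decode("{{TEMPLATE_B64}}").decode("utf-8")
inputs = json.loads(base64.b64decode("{{INPUTS_B64}}").decode("utf-8"))
print(Template(template).render(**inputs))
"""

template = "<textarea>{{ notes }}</textarea>"
inputs = {"notes": "line 1\nline 2 with 'quotes'"}

script = (
    RUNNER
    .replace("{{TEMPLATE_B64}}", base64.b64encode(template.encode("utf-8")).decode("utf-8"))
    .replace("{{INPUTS_B64}}", base64.b64encode(json.dumps(inputs).encode("utf-8")).decode("utf-8"))
)

# The raw template never appears verbatim in the generated script.
assert template not in script
exec(compile(script, "<runner>", "exec"))  # prints the rendered textarea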
@ -0,0 +1,365 @@
|
||||||
|
import json
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from extensions.ext_redis import redis_client
|
||||||
|
from services.billing_service import BillingService
|
||||||
|
|
||||||
|
|
||||||
|
class TestBillingServiceGetPlanBulkWithCache:
|
||||||
|
"""
|
||||||
|
Comprehensive integration tests for get_plan_bulk_with_cache using testcontainers.
|
||||||
|
|
||||||
|
This test class covers all major scenarios:
|
||||||
|
- Cache hit/miss scenarios
|
||||||
|
- Redis operation failures and fallback behavior
|
||||||
|
- Invalid cache data handling
|
||||||
|
- TTL expiration handling
|
||||||
|
- Error recovery and logging
|
||||||
|
"""
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def setup_redis_cleanup(self, flask_app_with_containers):
|
||||||
|
"""Clean up Redis cache before and after each test."""
|
||||||
|
with flask_app_with_containers.app_context():
|
||||||
|
# Clean up before test
|
||||||
|
yield
|
||||||
|
# Clean up after test
|
||||||
|
# Delete all test cache keys
|
||||||
|
pattern = f"{BillingService._PLAN_CACHE_KEY_PREFIX}*"
|
||||||
|
keys = redis_client.keys(pattern)
|
||||||
|
if keys:
|
||||||
|
redis_client.delete(*keys)
|
||||||
|
|
||||||
|
def _create_test_plan_data(self, plan: str = "sandbox", expiration_date: int = 1735689600):
|
||||||
|
"""Helper to create test SubscriptionPlan data."""
|
||||||
|
return {"plan": plan, "expiration_date": expiration_date}
|
||||||
|
|
||||||
|
def _set_cache(self, tenant_id: str, plan_data: dict, ttl: int = 600):
|
||||||
|
"""Helper to set cache data in Redis."""
|
||||||
|
cache_key = BillingService._make_plan_cache_key(tenant_id)
|
||||||
|
json_str = json.dumps(plan_data)
|
||||||
|
redis_client.setex(cache_key, ttl, json_str)
|
||||||
|
|
||||||
|
def _get_cache(self, tenant_id: str):
|
||||||
|
"""Helper to get cache data from Redis."""
|
||||||
|
cache_key = BillingService._make_plan_cache_key(tenant_id)
|
||||||
|
value = redis_client.get(cache_key)
|
||||||
|
if value:
|
||||||
|
if isinstance(value, bytes):
|
||||||
|
return value.decode("utf-8")
|
||||||
|
return value
|
||||||
|
return None
|
||||||
|
|
||||||
|
def test_get_plan_bulk_with_cache_all_cache_hit(self, flask_app_with_containers):
|
||||||
|
"""Test bulk plan retrieval when all tenants are in cache."""
|
||||||
|
with flask_app_with_containers.app_context():
|
||||||
|
# Arrange
|
||||||
|
tenant_ids = ["tenant-1", "tenant-2", "tenant-3"]
|
||||||
|
expected_plans = {
|
||||||
|
"tenant-1": self._create_test_plan_data("sandbox", 1735689600),
|
||||||
|
"tenant-2": self._create_test_plan_data("professional", 1767225600),
|
||||||
|
"tenant-3": self._create_test_plan_data("team", 1798761600),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Pre-populate cache
|
||||||
|
for tenant_id, plan_data in expected_plans.items():
|
||||||
|
self._set_cache(tenant_id, plan_data)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
with patch.object(BillingService, "get_plan_bulk") as mock_get_plan_bulk:
|
||||||
|
result = BillingService.get_plan_bulk_with_cache(tenant_ids)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert len(result) == 3
|
||||||
|
assert result["tenant-1"]["plan"] == "sandbox"
|
||||||
|
assert result["tenant-1"]["expiration_date"] == 1735689600
|
||||||
|
assert result["tenant-2"]["plan"] == "professional"
|
||||||
|
assert result["tenant-2"]["expiration_date"] == 1767225600
|
||||||
|
assert result["tenant-3"]["plan"] == "team"
|
||||||
|
assert result["tenant-3"]["expiration_date"] == 1798761600
|
||||||
|
|
||||||
|
# Verify API was not called
|
||||||
|
mock_get_plan_bulk.assert_not_called()
|
||||||
|
|
||||||
|
def test_get_plan_bulk_with_cache_all_cache_miss(self, flask_app_with_containers):
|
||||||
|
"""Test bulk plan retrieval when all tenants are not in cache."""
|
||||||
|
with flask_app_with_containers.app_context():
|
||||||
|
# Arrange
|
||||||
|
tenant_ids = ["tenant-1", "tenant-2"]
|
||||||
|
expected_plans = {
|
||||||
|
"tenant-1": self._create_test_plan_data("sandbox", 1735689600),
|
||||||
|
"tenant-2": self._create_test_plan_data("professional", 1767225600),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Act
|
||||||
|
with patch.object(BillingService, "get_plan_bulk", return_value=expected_plans) as mock_get_plan_bulk:
|
||||||
|
result = BillingService.get_plan_bulk_with_cache(tenant_ids)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert len(result) == 2
|
||||||
|
assert result["tenant-1"]["plan"] == "sandbox"
|
||||||
|
assert result["tenant-2"]["plan"] == "professional"
|
||||||
|
|
||||||
|
# Verify API was called with correct tenant_ids
|
||||||
|
mock_get_plan_bulk.assert_called_once_with(tenant_ids)
|
||||||
|
|
||||||
|
# Verify data was written to cache
|
||||||
|
cached_1 = self._get_cache("tenant-1")
|
||||||
|
cached_2 = self._get_cache("tenant-2")
|
||||||
|
assert cached_1 is not None
|
||||||
|
assert cached_2 is not None
|
||||||
|
|
||||||
|
# Verify cache content
|
||||||
|
cached_data_1 = json.loads(cached_1)
|
||||||
|
cached_data_2 = json.loads(cached_2)
|
||||||
|
assert cached_data_1 == expected_plans["tenant-1"]
|
||||||
|
assert cached_data_2 == expected_plans["tenant-2"]
|
||||||
|
|
||||||
|
# Verify TTL is set
|
||||||
|
cache_key_1 = BillingService._make_plan_cache_key("tenant-1")
|
||||||
|
ttl_1 = redis_client.ttl(cache_key_1)
|
||||||
|
assert ttl_1 > 0
|
||||||
|
assert ttl_1 <= 600 # Should be <= 600 seconds
|
||||||
|
|
||||||
|
def test_get_plan_bulk_with_cache_partial_cache_hit(self, flask_app_with_containers):
|
||||||
|
"""Test bulk plan retrieval when some tenants are in cache, some are not."""
|
||||||
|
with flask_app_with_containers.app_context():
|
||||||
|
# Arrange
|
||||||
|
tenant_ids = ["tenant-1", "tenant-2", "tenant-3"]
|
||||||
|
# Pre-populate cache for tenant-1 and tenant-2
|
||||||
|
self._set_cache("tenant-1", self._create_test_plan_data("sandbox", 1735689600))
|
||||||
|
self._set_cache("tenant-2", self._create_test_plan_data("professional", 1767225600))
|
||||||
|
|
||||||
|
# tenant-3 is not in cache
|
||||||
|
missing_plan = {"tenant-3": self._create_test_plan_data("team", 1798761600)}
|
||||||
|
|
||||||
|
# Act
|
||||||
|
with patch.object(BillingService, "get_plan_bulk", return_value=missing_plan) as mock_get_plan_bulk:
|
||||||
|
result = BillingService.get_plan_bulk_with_cache(tenant_ids)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert len(result) == 3
|
||||||
|
assert result["tenant-1"]["plan"] == "sandbox"
|
||||||
|
assert result["tenant-2"]["plan"] == "professional"
|
||||||
|
assert result["tenant-3"]["plan"] == "team"
|
||||||
|
|
||||||
|
# Verify API was called only for missing tenant
|
||||||
|
mock_get_plan_bulk.assert_called_once_with(["tenant-3"])
|
||||||
|
|
||||||
|
# Verify tenant-3 data was written to cache
|
||||||
|
cached_3 = self._get_cache("tenant-3")
|
||||||
|
assert cached_3 is not None
|
||||||
|
cached_data_3 = json.loads(cached_3)
|
||||||
|
assert cached_data_3 == missing_plan["tenant-3"]
|
||||||
|
|
||||||
|
def test_get_plan_bulk_with_cache_redis_mget_failure(self, flask_app_with_containers):
|
||||||
|
"""Test fallback to API when Redis mget fails."""
|
||||||
|
with flask_app_with_containers.app_context():
|
||||||
|
# Arrange
|
||||||
|
tenant_ids = ["tenant-1", "tenant-2"]
|
||||||
|
expected_plans = {
|
||||||
|
"tenant-1": self._create_test_plan_data("sandbox", 1735689600),
|
||||||
|
"tenant-2": self._create_test_plan_data("professional", 1767225600),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Act
|
||||||
|
with (
|
||||||
|
patch.object(redis_client, "mget", side_effect=Exception("Redis connection error")),
|
||||||
|
patch.object(BillingService, "get_plan_bulk", return_value=expected_plans) as mock_get_plan_bulk,
|
||||||
|
):
|
||||||
|
result = BillingService.get_plan_bulk_with_cache(tenant_ids)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert len(result) == 2
|
||||||
|
assert result["tenant-1"]["plan"] == "sandbox"
|
||||||
|
assert result["tenant-2"]["plan"] == "professional"
|
||||||
|
|
||||||
|
# Verify API was called for all tenants (fallback)
|
||||||
|
mock_get_plan_bulk.assert_called_once_with(tenant_ids)
|
||||||
|
|
||||||
|
# Verify data was written to cache after fallback
|
||||||
|
cached_1 = self._get_cache("tenant-1")
|
||||||
|
cached_2 = self._get_cache("tenant-2")
|
||||||
|
assert cached_1 is not None
|
||||||
|
assert cached_2 is not None
|
||||||
|
|
||||||
|
def test_get_plan_bulk_with_cache_invalid_json_in_cache(self, flask_app_with_containers):
|
||||||
|
"""Test fallback to API when cache contains invalid JSON."""
|
||||||
|
with flask_app_with_containers.app_context():
|
||||||
|
# Arrange
|
||||||
|
tenant_ids = ["tenant-1", "tenant-2", "tenant-3"]
|
||||||
|
|
||||||
|
# Set valid cache for tenant-1
|
||||||
|
self._set_cache("tenant-1", self._create_test_plan_data("sandbox", 1735689600))
|
||||||
|
|
||||||
|
# Set invalid JSON for tenant-2
|
||||||
|
cache_key_2 = BillingService._make_plan_cache_key("tenant-2")
|
||||||
|
redis_client.setex(cache_key_2, 600, "invalid json {")
|
||||||
|
|
||||||
|
# tenant-3 is not in cache
|
||||||
|
expected_plans = {
|
||||||
|
"tenant-2": self._create_test_plan_data("professional", 1767225600),
|
||||||
|
"tenant-3": self._create_test_plan_data("team", 1798761600),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Act
|
||||||
|
with patch.object(BillingService, "get_plan_bulk", return_value=expected_plans) as mock_get_plan_bulk:
|
||||||
|
result = BillingService.get_plan_bulk_with_cache(tenant_ids)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert len(result) == 3
|
||||||
|
assert result["tenant-1"]["plan"] == "sandbox" # From cache
|
||||||
|
assert result["tenant-2"]["plan"] == "professional" # From API (fallback)
|
||||||
|
assert result["tenant-3"]["plan"] == "team" # From API
|
||||||
|
|
||||||
|
# Verify API was called for tenant-2 and tenant-3
|
||||||
|
mock_get_plan_bulk.assert_called_once_with(["tenant-2", "tenant-3"])
|
||||||
|
|
||||||
|
# Verify tenant-2's invalid JSON was replaced with correct data in cache
|
||||||
|
cached_2 = self._get_cache("tenant-2")
|
||||||
|
assert cached_2 is not None
|
||||||
|
cached_data_2 = json.loads(cached_2)
|
||||||
|
assert cached_data_2 == expected_plans["tenant-2"]
|
||||||
|
assert cached_data_2["plan"] == "professional"
|
||||||
|
assert cached_data_2["expiration_date"] == 1767225600
|
||||||
|
|
||||||
|
# Verify tenant-2 cache has correct TTL
|
||||||
|
cache_key_2_new = BillingService._make_plan_cache_key("tenant-2")
|
||||||
|
ttl_2 = redis_client.ttl(cache_key_2_new)
|
||||||
|
assert ttl_2 > 0
|
||||||
|
assert ttl_2 <= 600
|
||||||
|
|
||||||
|
# Verify tenant-3 data was also written to cache
|
||||||
|
cached_3 = self._get_cache("tenant-3")
|
||||||
|
assert cached_3 is not None
|
||||||
|
cached_data_3 = json.loads(cached_3)
|
||||||
|
assert cached_data_3 == expected_plans["tenant-3"]
|
||||||
|
|
||||||
|
def test_get_plan_bulk_with_cache_invalid_plan_data_in_cache(self, flask_app_with_containers):
|
||||||
|
"""Test fallback to API when cache data doesn't match SubscriptionPlan schema."""
|
||||||
|
with flask_app_with_containers.app_context():
|
||||||
|
# Arrange
|
||||||
|
tenant_ids = ["tenant-1", "tenant-2", "tenant-3"]
|
||||||
|
|
||||||
|
# Set valid cache for tenant-1
|
||||||
|
self._set_cache("tenant-1", self._create_test_plan_data("sandbox", 1735689600))
|
||||||
|
|
||||||
|
# Set invalid plan data for tenant-2 (missing expiration_date)
|
||||||
|
cache_key_2 = BillingService._make_plan_cache_key("tenant-2")
|
||||||
|
invalid_data = json.dumps({"plan": "professional"}) # Missing expiration_date
|
||||||
|
redis_client.setex(cache_key_2, 600, invalid_data)
|
||||||
|
|
||||||
|
# tenant-3 is not in cache
|
||||||
|
expected_plans = {
|
||||||
|
"tenant-2": self._create_test_plan_data("professional", 1767225600),
|
||||||
|
"tenant-3": self._create_test_plan_data("team", 1798761600),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Act
|
||||||
|
with patch.object(BillingService, "get_plan_bulk", return_value=expected_plans) as mock_get_plan_bulk:
|
||||||
|
result = BillingService.get_plan_bulk_with_cache(tenant_ids)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert len(result) == 3
|
||||||
|
assert result["tenant-1"]["plan"] == "sandbox" # From cache
|
||||||
|
assert result["tenant-2"]["plan"] == "professional" # From API (fallback)
|
||||||
|
assert result["tenant-3"]["plan"] == "team" # From API
|
||||||
|
|
||||||
|
# Verify API was called for tenant-2 and tenant-3
|
||||||
|
mock_get_plan_bulk.assert_called_once_with(["tenant-2", "tenant-3"])
|
||||||
|
|
||||||
|
def test_get_plan_bulk_with_cache_redis_pipeline_failure(self, flask_app_with_containers):
|
||||||
|
"""Test that pipeline failure doesn't affect return value."""
|
||||||
|
with flask_app_with_containers.app_context():
|
||||||
|
# Arrange
|
||||||
|
tenant_ids = ["tenant-1", "tenant-2"]
|
||||||
|
expected_plans = {
|
||||||
|
"tenant-1": self._create_test_plan_data("sandbox", 1735689600),
|
||||||
|
"tenant-2": self._create_test_plan_data("professional", 1767225600),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Act
|
||||||
|
with (
|
||||||
|
patch.object(BillingService, "get_plan_bulk", return_value=expected_plans),
|
||||||
|
patch.object(redis_client, "pipeline") as mock_pipeline,
|
||||||
|
):
|
||||||
|
# Create a mock pipeline that fails on execute
|
||||||
|
mock_pipe = mock_pipeline.return_value
|
||||||
|
mock_pipe.execute.side_effect = Exception("Pipeline execution failed")
|
||||||
|
|
||||||
|
result = BillingService.get_plan_bulk_with_cache(tenant_ids)
|
||||||
|
|
||||||
|
# Assert - Function should still return correct result despite pipeline failure
|
||||||
|
assert len(result) == 2
|
||||||
|
assert result["tenant-1"]["plan"] == "sandbox"
|
||||||
|
assert result["tenant-2"]["plan"] == "professional"
|
||||||
|
|
||||||
|
# Verify pipeline was attempted
|
||||||
|
mock_pipeline.assert_called_once()
|
||||||
|
|
||||||
|
def test_get_plan_bulk_with_cache_empty_tenant_ids(self, flask_app_with_containers):
|
||||||
|
"""Test with empty tenant_ids list."""
|
||||||
|
with flask_app_with_containers.app_context():
|
||||||
|
# Act
|
||||||
|
with patch.object(BillingService, "get_plan_bulk") as mock_get_plan_bulk:
|
||||||
|
result = BillingService.get_plan_bulk_with_cache([])
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result == {}
|
||||||
|
assert len(result) == 0
|
||||||
|
|
||||||
|
# Verify no API calls
|
||||||
|
mock_get_plan_bulk.assert_not_called()
|
||||||
|
|
||||||
|
# Verify no Redis operations (mget with empty list would return empty list)
|
||||||
|
# But we should check that mget was not called at all
|
||||||
|
# Since we can't easily verify this without more mocking, we just verify the result
|
||||||
|
|
||||||
|
def test_get_plan_bulk_with_cache_ttl_expired(self, flask_app_with_containers):
|
||||||
|
"""Test that expired cache keys are treated as cache misses."""
|
||||||
|
with flask_app_with_containers.app_context():
|
||||||
|
# Arrange
|
||||||
|
tenant_ids = ["tenant-1", "tenant-2"]
|
||||||
|
|
||||||
|
# Set cache for tenant-1 with very short TTL (1 second) to simulate expiration
|
||||||
|
self._set_cache("tenant-1", self._create_test_plan_data("sandbox", 1735689600), ttl=1)
|
||||||
|
|
||||||
|
# Wait for TTL to expire (key will be deleted by Redis)
|
||||||
|
import time
|
||||||
|
|
||||||
|
time.sleep(2)
|
||||||
|
|
||||||
|
# Verify cache is expired (key doesn't exist)
|
||||||
|
cache_key_1 = BillingService._make_plan_cache_key("tenant-1")
|
||||||
|
exists = redis_client.exists(cache_key_1)
|
||||||
|
assert exists == 0 # Key doesn't exist (expired)
|
||||||
|
|
||||||
|
# tenant-2 is not in cache
|
||||||
|
expected_plans = {
|
||||||
|
"tenant-1": self._create_test_plan_data("sandbox", 1735689600),
|
||||||
|
"tenant-2": self._create_test_plan_data("professional", 1767225600),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Act
|
||||||
|
with patch.object(BillingService, "get_plan_bulk", return_value=expected_plans) as mock_get_plan_bulk:
|
||||||
|
result = BillingService.get_plan_bulk_with_cache(tenant_ids)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert len(result) == 2
|
||||||
|
assert result["tenant-1"]["plan"] == "sandbox"
|
||||||
|
assert result["tenant-2"]["plan"] == "professional"
|
||||||
|
|
||||||
|
# Verify API was called for both tenants (tenant-1 expired, tenant-2 missing)
|
||||||
|
mock_get_plan_bulk.assert_called_once_with(tenant_ids)
|
||||||
|
|
||||||
|
# Verify both were written to cache with correct TTL
|
||||||
|
cache_key_1_new = BillingService._make_plan_cache_key("tenant-1")
|
||||||
|
cache_key_2 = BillingService._make_plan_cache_key("tenant-2")
|
||||||
|
ttl_1_new = redis_client.ttl(cache_key_1_new)
|
||||||
|
ttl_2 = redis_client.ttl(cache_key_2)
|
||||||
|
assert ttl_1_new > 0
|
||||||
|
assert ttl_1_new <= 600
|
||||||
|
assert ttl_2 > 0
|
||||||
|
assert ttl_2 <= 600
|
||||||
|
|
@ -0,0 +1,682 @@
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from faker import Faker
|
||||||
|
|
||||||
|
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
|
||||||
|
from core.plugin.entities.plugin_daemon import CredentialType
|
||||||
|
from core.trigger.entities.entities import Subscription as TriggerSubscriptionEntity
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from models.provider_ids import TriggerProviderID
|
||||||
|
from models.trigger import TriggerSubscription
|
||||||
|
from services.trigger.trigger_provider_service import TriggerProviderService
|
||||||
|
|
||||||
|
|
||||||
|
class TestTriggerProviderService:
|
||||||
|
"""Integration tests for TriggerProviderService using testcontainers."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_external_service_dependencies(self):
|
||||||
|
"""Mock setup for external service dependencies."""
|
||||||
|
with (
|
||||||
|
patch("services.trigger.trigger_provider_service.TriggerManager") as mock_trigger_manager,
|
||||||
|
patch("services.trigger.trigger_provider_service.redis_client") as mock_redis_client,
|
||||||
|
patch("services.trigger.trigger_provider_service.delete_cache_for_subscription") as mock_delete_cache,
|
||||||
|
patch("services.account_service.FeatureService") as mock_account_feature_service,
|
||||||
|
):
|
||||||
|
# Setup default mock returns
|
||||||
|
mock_provider_controller = MagicMock()
|
||||||
|
mock_provider_controller.get_credential_schema_config.return_value = MagicMock()
|
||||||
|
mock_provider_controller.get_properties_schema.return_value = MagicMock()
|
||||||
|
mock_trigger_manager.get_trigger_provider.return_value = mock_provider_controller
|
||||||
|
|
||||||
|
# Mock redis lock
|
||||||
|
mock_lock = MagicMock()
|
||||||
|
mock_lock.__enter__ = MagicMock(return_value=None)
|
||||||
|
mock_lock.__exit__ = MagicMock(return_value=None)
|
||||||
|
mock_redis_client.lock.return_value = mock_lock
|
||||||
|
|
||||||
|
# Setup account feature service mock
|
||||||
|
mock_account_feature_service.get_system_features.return_value.is_allow_register = True
|
||||||
|
|
||||||
|
yield {
|
||||||
|
"trigger_manager": mock_trigger_manager,
|
||||||
|
"redis_client": mock_redis_client,
|
||||||
|
"delete_cache": mock_delete_cache,
|
||||||
|
"provider_controller": mock_provider_controller,
|
||||||
|
"account_feature_service": mock_account_feature_service,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
|
||||||
|
"""
|
||||||
|
Helper method to create a test account and tenant for testing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session_with_containers: Database session from testcontainers infrastructure
|
||||||
|
mock_external_service_dependencies: Mock dependencies
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (account, tenant) - Created account and tenant instances
|
||||||
|
"""
|
||||||
|
fake = Faker()
|
||||||
|
|
||||||
|
from services.account_service import AccountService, TenantService
|
||||||
|
|
||||||
|
# Setup mocks for account creation
|
||||||
|
mock_external_service_dependencies[
|
||||||
|
"account_feature_service"
|
||||||
|
].get_system_features.return_value.is_allow_register = True
|
||||||
|
mock_external_service_dependencies[
|
||||||
|
"trigger_manager"
|
||||||
|
].get_trigger_provider.return_value = mock_external_service_dependencies["provider_controller"]
|
||||||
|
|
||||||
|
# Create account and tenant
|
||||||
|
account = AccountService.create_account(
|
||||||
|
email=fake.email(),
|
||||||
|
name=fake.name(),
|
||||||
|
interface_language="en-US",
|
||||||
|
password=fake.password(length=12),
|
||||||
|
)
|
||||||
|
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||||
|
tenant = account.current_tenant
|
||||||
|
|
||||||
|
return account, tenant
|
||||||
|
|
||||||
|
def _create_test_subscription(
|
||||||
|
self,
|
||||||
|
db_session_with_containers,
|
||||||
|
tenant_id,
|
||||||
|
user_id,
|
||||||
|
provider_id,
|
||||||
|
credential_type,
|
||||||
|
credentials,
|
||||||
|
mock_external_service_dependencies,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Helper method to create a test trigger subscription.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session_with_containers: Database session
|
||||||
|
tenant_id: Tenant ID
|
||||||
|
user_id: User ID
|
||||||
|
provider_id: Provider ID
|
||||||
|
credential_type: Credential type
|
||||||
|
credentials: Credentials dict
|
||||||
|
mock_external_service_dependencies: Mock dependencies
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TriggerSubscription: Created subscription instance
|
||||||
|
"""
|
||||||
|
fake = Faker()
|
||||||
|
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||||
|
from core.helper.provider_encryption import create_provider_encrypter
|
||||||
|
|
||||||
|
# Use mock provider controller to encrypt credentials
|
||||||
|
provider_controller = mock_external_service_dependencies["provider_controller"]
|
||||||
|
|
||||||
|
# Create encrypter for credentials
|
||||||
|
credential_encrypter, _ = create_provider_encrypter(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
config=provider_controller.get_credential_schema_config(credential_type),
|
||||||
|
cache=NoOpProviderCredentialCache(),
|
||||||
|
)
|
||||||
|
|
||||||
|
subscription = TriggerSubscription(
|
||||||
|
name=fake.word(),
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
user_id=user_id,
|
||||||
|
provider_id=str(provider_id),
|
||||||
|
endpoint_id=fake.uuid4(),
|
||||||
|
parameters={"param1": "value1"},
|
||||||
|
properties={"prop1": "value1"},
|
||||||
|
credentials=dict(credential_encrypter.encrypt(credentials)),
|
||||||
|
credential_type=credential_type.value,
|
||||||
|
credential_expires_at=-1,
|
||||||
|
expires_at=-1,
|
||||||
|
)
|
||||||
|
|
||||||
|
db.session.add(subscription)
|
||||||
|
db.session.commit()
|
||||||
|
db.session.refresh(subscription)
|
||||||
|
|
||||||
|
return subscription
|
||||||
|
|
||||||
|
def test_rebuild_trigger_subscription_success_with_merged_credentials(
|
||||||
|
self, db_session_with_containers, mock_external_service_dependencies
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test successful rebuild with credential merging (HIDDEN_VALUE handling).
|
||||||
|
|
||||||
|
This test verifies:
|
||||||
|
- Credentials are properly merged (HIDDEN_VALUE replaced with existing values)
|
||||||
|
- Single transaction wraps all operations
|
||||||
|
- Merged credentials are used for subscribe and update
|
||||||
|
- Database state is correctly updated
|
||||||
|
"""
|
||||||
|
fake = Faker()
|
||||||
|
account, tenant = self._create_test_account_and_tenant(
|
||||||
|
db_session_with_containers, mock_external_service_dependencies
|
||||||
|
)
|
||||||
|
|
||||||
|
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||||
|
credential_type = CredentialType.API_KEY
|
||||||
|
|
||||||
|
# Create initial subscription with credentials
|
||||||
|
original_credentials = {"api_key": "original-secret-key", "api_secret": "original-secret"}
|
||||||
|
subscription = self._create_test_subscription(
|
||||||
|
db_session_with_containers,
|
||||||
|
tenant.id,
|
||||||
|
account.id,
|
||||||
|
provider_id,
|
||||||
|
credential_type,
|
||||||
|
original_credentials,
|
||||||
|
mock_external_service_dependencies,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare new credentials with HIDDEN_VALUE for api_key (should keep original)
|
||||||
|
# and new value for api_secret (should update)
|
||||||
|
new_credentials = {
|
||||||
|
"api_key": HIDDEN_VALUE, # Should be replaced with original
|
||||||
|
"api_secret": "new-secret-value", # Should be updated
|
||||||
|
}
|
||||||
|
|
||||||
|
# Mock subscribe_trigger to return a new subscription entity
|
||||||
|
new_subscription_entity = TriggerSubscriptionEntity(
|
||||||
|
endpoint=subscription.endpoint_id,
|
||||||
|
parameters={"param1": "value1"},
|
||||||
|
properties={"prop1": "new_prop_value"},
|
||||||
|
expires_at=1234567890,
|
||||||
|
)
|
||||||
|
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
|
||||||
|
|
||||||
|
# Mock unsubscribe_trigger
|
||||||
|
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
|
||||||
|
|
||||||
|
# Execute rebuild
|
||||||
|
TriggerProviderService.rebuild_trigger_subscription(
|
||||||
|
tenant_id=tenant.id,
|
||||||
|
provider_id=provider_id,
|
||||||
|
subscription_id=subscription.id,
|
||||||
|
credentials=new_credentials,
|
||||||
|
parameters={"param1": "updated_value"},
|
||||||
|
name="updated_name",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify unsubscribe was called with decrypted original credentials
|
||||||
|
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.assert_called_once()
|
||||||
|
unsubscribe_call_args = mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.call_args
|
||||||
|
assert unsubscribe_call_args.kwargs["tenant_id"] == tenant.id
|
||||||
|
assert unsubscribe_call_args.kwargs["provider_id"] == provider_id
|
||||||
|
assert unsubscribe_call_args.kwargs["credential_type"] == credential_type
|
||||||
|
|
||||||
|
# Verify subscribe was called with merged credentials (api_key from original, api_secret new)
|
||||||
|
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.assert_called_once()
|
||||||
|
subscribe_call_args = mock_external_service_dependencies["trigger_manager"].subscribe_trigger.call_args
|
||||||
|
subscribe_credentials = subscribe_call_args.kwargs["credentials"]
|
||||||
|
assert subscribe_credentials["api_key"] == original_credentials["api_key"] # Merged from original
|
||||||
|
assert subscribe_credentials["api_secret"] == "new-secret-value" # New value
|
||||||
|
|
||||||
|
# Verify database state was updated
|
||||||
|
db.session.refresh(subscription)
|
||||||
|
assert subscription.name == "updated_name"
|
||||||
|
assert subscription.parameters == {"param1": "updated_value"}
|
||||||
|
|
||||||
|
# Verify credentials in DB were updated with merged values (decrypt to check)
|
||||||
|
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||||
|
from core.helper.provider_encryption import create_provider_encrypter
|
||||||
|
|
||||||
|
# Use mock provider controller to decrypt credentials
|
||||||
|
provider_controller = mock_external_service_dependencies["provider_controller"]
|
||||||
|
credential_encrypter, _ = create_provider_encrypter(
|
||||||
|
tenant_id=tenant.id,
|
||||||
|
config=provider_controller.get_credential_schema_config(credential_type),
|
||||||
|
cache=NoOpProviderCredentialCache(),
|
||||||
|
)
|
||||||
|
decrypted_db_credentials = dict(credential_encrypter.decrypt(subscription.credentials))
|
||||||
|
assert decrypted_db_credentials["api_key"] == original_credentials["api_key"]
|
||||||
|
assert decrypted_db_credentials["api_secret"] == "new-secret-value"
|
||||||
|
|
||||||
|
# Verify cache was cleared
|
||||||
|
mock_external_service_dependencies["delete_cache"].assert_called_once_with(
|
||||||
|
tenant_id=tenant.id,
|
||||||
|
provider_id=subscription.provider_id,
|
||||||
|
subscription_id=subscription.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_rebuild_trigger_subscription_with_all_new_credentials(
|
||||||
|
self, db_session_with_containers, mock_external_service_dependencies
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test rebuild when all credentials are new (no HIDDEN_VALUE).
|
||||||
|
|
||||||
|
This test verifies:
|
||||||
|
- All new credentials are used when no HIDDEN_VALUE is present
|
||||||
|
- Merged credentials contain only new values
|
||||||
|
"""
|
||||||
|
fake = Faker()
|
||||||
|
account, tenant = self._create_test_account_and_tenant(
|
||||||
|
db_session_with_containers, mock_external_service_dependencies
|
||||||
|
)
|
||||||
|
|
||||||
|
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||||
|
credential_type = CredentialType.API_KEY
|
||||||
|
|
||||||
|
# Create initial subscription
|
||||||
|
original_credentials = {"api_key": "original-key", "api_secret": "original-secret"}
|
||||||
|
subscription = self._create_test_subscription(
|
||||||
|
db_session_with_containers,
|
||||||
|
tenant.id,
|
||||||
|
account.id,
|
||||||
|
provider_id,
|
||||||
|
credential_type,
|
||||||
|
original_credentials,
|
||||||
|
mock_external_service_dependencies,
|
||||||
|
)
|
||||||
|
|
||||||
|
# All new credentials (no HIDDEN_VALUE)
|
||||||
|
new_credentials = {
|
||||||
|
"api_key": "completely-new-key",
|
||||||
|
"api_secret": "completely-new-secret",
|
||||||
|
}
|
||||||
|
|
||||||
|
new_subscription_entity = TriggerSubscriptionEntity(
|
||||||
|
endpoint=subscription.endpoint_id,
|
||||||
|
parameters={},
|
||||||
|
properties={},
|
||||||
|
expires_at=-1,
|
||||||
|
)
|
||||||
|
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
|
||||||
|
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
|
||||||
|
|
||||||
|
# Execute rebuild
|
||||||
|
TriggerProviderService.rebuild_trigger_subscription(
|
||||||
|
tenant_id=tenant.id,
|
||||||
|
provider_id=provider_id,
|
||||||
|
subscription_id=subscription.id,
|
||||||
|
credentials=new_credentials,
|
||||||
|
parameters={},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify subscribe was called with all new credentials
|
||||||
|
subscribe_call_args = mock_external_service_dependencies["trigger_manager"].subscribe_trigger.call_args
|
||||||
|
subscribe_credentials = subscribe_call_args.kwargs["credentials"]
|
||||||
|
assert subscribe_credentials["api_key"] == "completely-new-key"
|
||||||
|
assert subscribe_credentials["api_secret"] == "completely-new-secret"
|
||||||
|
|
||||||
|
def test_rebuild_trigger_subscription_with_all_hidden_values(
|
||||||
|
self, db_session_with_containers, mock_external_service_dependencies
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test rebuild when all credentials are HIDDEN_VALUE (preserve all existing).
|
||||||
|
|
||||||
|
This test verifies:
|
||||||
|
- All HIDDEN_VALUE credentials are replaced with existing values
|
||||||
|
- Original credentials are preserved
|
||||||
|
"""
|
||||||
|
fake = Faker()
|
||||||
|
account, tenant = self._create_test_account_and_tenant(
|
||||||
|
db_session_with_containers, mock_external_service_dependencies
|
||||||
|
)
|
||||||
|
|
||||||
|
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||||
|
credential_type = CredentialType.API_KEY
|
||||||
|
|
||||||
|
original_credentials = {"api_key": "original-key", "api_secret": "original-secret"}
|
||||||
|
subscription = self._create_test_subscription(
|
||||||
|
db_session_with_containers,
|
||||||
|
tenant.id,
|
||||||
|
account.id,
|
||||||
|
provider_id,
|
||||||
|
credential_type,
|
||||||
|
original_credentials,
|
||||||
|
mock_external_service_dependencies,
|
||||||
|
)
|
||||||
|
|
||||||
|
# All HIDDEN_VALUE (should preserve all original)
|
||||||
|
new_credentials = {
|
||||||
|
"api_key": HIDDEN_VALUE,
|
||||||
|
"api_secret": HIDDEN_VALUE,
|
||||||
|
}
|
||||||
|
|
||||||
|
new_subscription_entity = TriggerSubscriptionEntity(
|
||||||
|
endpoint=subscription.endpoint_id,
|
||||||
|
parameters={},
|
||||||
|
properties={},
|
||||||
|
expires_at=-1,
|
||||||
|
)
|
||||||
|
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
|
||||||
|
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
|
||||||
|
|
||||||
|
# Execute rebuild
|
||||||
|
TriggerProviderService.rebuild_trigger_subscription(
|
||||||
|
tenant_id=tenant.id,
|
||||||
|
provider_id=provider_id,
|
||||||
|
subscription_id=subscription.id,
|
||||||
|
credentials=new_credentials,
|
||||||
|
parameters={},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify subscribe was called with all original credentials
|
||||||
|
subscribe_call_args = mock_external_service_dependencies["trigger_manager"].subscribe_trigger.call_args
|
||||||
|
subscribe_credentials = subscribe_call_args.kwargs["credentials"]
|
||||||
|
assert subscribe_credentials["api_key"] == original_credentials["api_key"]
|
||||||
|
assert subscribe_credentials["api_secret"] == original_credentials["api_secret"]
|
||||||
|
|
||||||
|
def test_rebuild_trigger_subscription_with_missing_key_uses_unknown_value(
|
||||||
|
self, db_session_with_containers, mock_external_service_dependencies
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test rebuild when HIDDEN_VALUE is used for a key that doesn't exist in original.
|
||||||
|
|
||||||
|
This test verifies:
|
||||||
|
- UNKNOWN_VALUE is used when HIDDEN_VALUE key doesn't exist in original credentials
|
||||||
|
"""
|
||||||
|
fake = Faker()
|
||||||
|
account, tenant = self._create_test_account_and_tenant(
|
||||||
|
db_session_with_containers, mock_external_service_dependencies
|
||||||
|
)
|
||||||
|
|
||||||
|
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||||
|
credential_type = CredentialType.API_KEY
|
||||||
|
|
||||||
|
# Original has only api_key
|
||||||
|
original_credentials = {"api_key": "original-key"}
|
||||||
|
subscription = self._create_test_subscription(
|
||||||
|
db_session_with_containers,
|
||||||
|
tenant.id,
|
||||||
|
account.id,
|
||||||
|
provider_id,
|
||||||
|
credential_type,
|
||||||
|
original_credentials,
|
||||||
|
mock_external_service_dependencies,
|
||||||
|
)
|
||||||
|
|
||||||
|
# HIDDEN_VALUE for non-existent key should use UNKNOWN_VALUE
|
||||||
|
new_credentials = {
|
||||||
|
"api_key": HIDDEN_VALUE,
|
||||||
|
"non_existent_key": HIDDEN_VALUE, # This key doesn't exist in original
|
||||||
|
}
|
||||||
|
|
||||||
|
new_subscription_entity = TriggerSubscriptionEntity(
|
||||||
|
endpoint=subscription.endpoint_id,
|
||||||
|
parameters={},
|
||||||
|
properties={},
|
||||||
|
expires_at=-1,
|
||||||
|
)
|
||||||
|
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
|
||||||
|
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
|
||||||
|
|
||||||
|
# Execute rebuild
|
||||||
|
TriggerProviderService.rebuild_trigger_subscription(
|
||||||
|
tenant_id=tenant.id,
|
||||||
|
provider_id=provider_id,
|
||||||
|
subscription_id=subscription.id,
|
||||||
|
credentials=new_credentials,
|
||||||
|
parameters={},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify subscribe was called with original api_key and UNKNOWN_VALUE for missing key
|
||||||
|
subscribe_call_args = mock_external_service_dependencies["trigger_manager"].subscribe_trigger.call_args
|
||||||
|
subscribe_credentials = subscribe_call_args.kwargs["credentials"]
|
||||||
|
assert subscribe_credentials["api_key"] == original_credentials["api_key"]
|
||||||
|
assert subscribe_credentials["non_existent_key"] == UNKNOWN_VALUE
|
||||||
|
|
||||||
|
def test_rebuild_trigger_subscription_rollback_on_error(
|
||||||
|
self, db_session_with_containers, mock_external_service_dependencies
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test that transaction is rolled back on error.
|
||||||
|
|
||||||
|
This test verifies:
|
||||||
|
- Database transaction is rolled back when an error occurs
|
||||||
|
- Original subscription state is preserved
|
||||||
|
"""
|
||||||
|
fake = Faker()
|
||||||
|
account, tenant = self._create_test_account_and_tenant(
|
||||||
|
db_session_with_containers, mock_external_service_dependencies
|
||||||
|
)
|
||||||
|
|
||||||
|
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||||
|
credential_type = CredentialType.API_KEY
|
||||||
|
|
||||||
|
original_credentials = {"api_key": "original-key"}
|
||||||
|
subscription = self._create_test_subscription(
|
||||||
|
db_session_with_containers,
|
||||||
|
tenant.id,
|
||||||
|
account.id,
|
||||||
|
provider_id,
|
||||||
|
credential_type,
|
||||||
|
original_credentials,
|
||||||
|
mock_external_service_dependencies,
|
||||||
|
)
|
||||||
|
|
||||||
|
original_name = subscription.name
|
||||||
|
original_parameters = subscription.parameters.copy()
|
||||||
|
|
||||||
|
# Make subscribe_trigger raise an error
|
||||||
|
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.side_effect = ValueError(
|
||||||
|
"Subscribe failed"
|
||||||
|
)
|
||||||
|
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
|
||||||
|
|
||||||
|
# Execute rebuild and expect error
|
||||||
|
with pytest.raises(ValueError, match="Subscribe failed"):
|
||||||
|
TriggerProviderService.rebuild_trigger_subscription(
|
||||||
|
tenant_id=tenant.id,
|
||||||
|
provider_id=provider_id,
|
||||||
|
subscription_id=subscription.id,
|
||||||
|
credentials={"api_key": "new-key"},
|
||||||
|
parameters={},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify subscription state was not changed (rolled back)
|
||||||
|
db.session.refresh(subscription)
|
||||||
|
assert subscription.name == original_name
|
||||||
|
assert subscription.parameters == original_parameters
|
||||||
|
|
||||||
|
def test_rebuild_trigger_subscription_unsubscribe_error_continues(
|
||||||
|
self, db_session_with_containers, mock_external_service_dependencies
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test that unsubscribe errors are handled gracefully and operation continues.
|
||||||
|
|
||||||
|
This test verifies:
|
||||||
|
- Unsubscribe errors are caught and logged but don't stop the rebuild
|
||||||
|
- Rebuild continues even if unsubscribe fails
|
||||||
|
"""
|
||||||
|
fake = Faker()
|
||||||
|
account, tenant = self._create_test_account_and_tenant(
|
||||||
|
db_session_with_containers, mock_external_service_dependencies
|
||||||
|
)
|
||||||
|
|
||||||
|
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||||
|
credential_type = CredentialType.API_KEY
|
||||||
|
|
||||||
|
original_credentials = {"api_key": "original-key"}
|
||||||
|
subscription = self._create_test_subscription(
|
||||||
|
db_session_with_containers,
|
||||||
|
tenant.id,
|
||||||
|
account.id,
|
||||||
|
provider_id,
|
||||||
|
credential_type,
|
||||||
|
original_credentials,
|
||||||
|
mock_external_service_dependencies,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Make unsubscribe_trigger raise an error (should be caught and continue)
|
||||||
|
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.side_effect = ValueError(
|
||||||
|
"Unsubscribe failed"
|
||||||
|
)
|
||||||
|
|
||||||
|
new_subscription_entity = TriggerSubscriptionEntity(
|
||||||
|
endpoint=subscription.endpoint_id,
|
||||||
|
parameters={},
|
||||||
|
properties={},
|
||||||
|
expires_at=-1,
|
||||||
|
)
|
||||||
|
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
|
||||||
|
|
||||||
|
# Execute rebuild - should succeed despite unsubscribe error
|
||||||
|
TriggerProviderService.rebuild_trigger_subscription(
|
||||||
|
tenant_id=tenant.id,
|
||||||
|
provider_id=provider_id,
|
||||||
|
subscription_id=subscription.id,
|
||||||
|
credentials={"api_key": "new-key"},
|
||||||
|
parameters={},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify subscribe was still called (operation continued)
|
||||||
|
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.assert_called_once()
|
||||||
|
|
||||||
|
# Verify subscription was updated
|
||||||
|
db.session.refresh(subscription)
|
||||||
|
assert subscription.parameters == {}
|
||||||
|
|
||||||
|
def test_rebuild_trigger_subscription_subscription_not_found(
|
||||||
|
self, db_session_with_containers, mock_external_service_dependencies
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test error when subscription is not found.
|
||||||
|
|
||||||
|
This test verifies:
|
||||||
|
- Proper error is raised when subscription doesn't exist
|
||||||
|
"""
|
||||||
|
fake = Faker()
|
||||||
|
account, tenant = self._create_test_account_and_tenant(
|
||||||
|
db_session_with_containers, mock_external_service_dependencies
|
||||||
|
)
|
||||||
|
|
||||||
|
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||||
|
fake_subscription_id = fake.uuid4()
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="not found"):
|
||||||
|
TriggerProviderService.rebuild_trigger_subscription(
|
||||||
|
tenant_id=tenant.id,
|
||||||
|
provider_id=provider_id,
|
||||||
|
subscription_id=fake_subscription_id,
|
||||||
|
credentials={},
|
||||||
|
parameters={},
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_rebuild_trigger_subscription_provider_not_found(
|
||||||
|
self, db_session_with_containers, mock_external_service_dependencies
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test error when provider is not found.
|
||||||
|
|
||||||
|
This test verifies:
|
||||||
|
- Proper error is raised when provider doesn't exist
|
||||||
|
"""
|
||||||
|
fake = Faker()
|
||||||
|
account, tenant = self._create_test_account_and_tenant(
|
||||||
|
db_session_with_containers, mock_external_service_dependencies
|
||||||
|
)
|
||||||
|
|
||||||
|
provider_id = TriggerProviderID("non_existent_org/non_existent_plugin/non_existent_provider")
|
||||||
|
|
||||||
|
# Make get_trigger_provider return None
|
||||||
|
mock_external_service_dependencies["trigger_manager"].get_trigger_provider.return_value = None
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Provider.*not found"):
|
||||||
|
TriggerProviderService.rebuild_trigger_subscription(
|
||||||
|
tenant_id=tenant.id,
|
||||||
|
provider_id=provider_id,
|
||||||
|
subscription_id=fake.uuid4(),
|
||||||
|
credentials={},
|
||||||
|
parameters={},
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_rebuild_trigger_subscription_unsupported_credential_type(
|
||||||
|
self, db_session_with_containers, mock_external_service_dependencies
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test error when credential type is not supported for rebuild.
|
||||||
|
|
||||||
|
This test verifies:
|
||||||
|
- Proper error is raised for unsupported credential types (not OAUTH2 or API_KEY)
|
||||||
|
"""
|
||||||
|
fake = Faker()
|
||||||
|
account, tenant = self._create_test_account_and_tenant(
|
||||||
|
db_session_with_containers, mock_external_service_dependencies
|
||||||
|
)
|
||||||
|
|
||||||
|
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||||
|
credential_type = CredentialType.UNAUTHORIZED # Not supported
|
||||||
|
|
||||||
|
subscription = self._create_test_subscription(
|
||||||
|
db_session_with_containers,
|
||||||
|
tenant.id,
|
||||||
|
account.id,
|
||||||
|
provider_id,
|
||||||
|
credential_type,
|
||||||
|
{},
|
||||||
|
mock_external_service_dependencies,
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Credential type not supported for rebuild"):
|
||||||
|
TriggerProviderService.rebuild_trigger_subscription(
|
||||||
|
tenant_id=tenant.id,
|
||||||
|
provider_id=provider_id,
|
||||||
|
subscription_id=subscription.id,
|
||||||
|
credentials={},
|
||||||
|
parameters={},
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_rebuild_trigger_subscription_name_uniqueness_check(
|
||||||
|
self, db_session_with_containers, mock_external_service_dependencies
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test that name uniqueness is checked when updating name.
|
||||||
|
|
||||||
|
This test verifies:
|
||||||
|
- Error is raised when new name conflicts with existing subscription
|
||||||
|
"""
|
||||||
|
fake = Faker()
|
||||||
|
account, tenant = self._create_test_account_and_tenant(
|
||||||
|
db_session_with_containers, mock_external_service_dependencies
|
||||||
|
)
|
||||||
|
|
||||||
|
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||||
|
credential_type = CredentialType.API_KEY
|
||||||
|
|
||||||
|
# Create first subscription
|
||||||
|
subscription1 = self._create_test_subscription(
|
||||||
|
db_session_with_containers,
|
||||||
|
tenant.id,
|
||||||
|
account.id,
|
||||||
|
provider_id,
|
||||||
|
credential_type,
|
||||||
|
{"api_key": "key1"},
|
||||||
|
mock_external_service_dependencies,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create second subscription with different name
|
||||||
|
subscription2 = self._create_test_subscription(
|
||||||
|
db_session_with_containers,
|
||||||
|
tenant.id,
|
||||||
|
account.id,
|
||||||
|
provider_id,
|
||||||
|
credential_type,
|
||||||
|
{"api_key": "key2"},
|
||||||
|
mock_external_service_dependencies,
|
||||||
|
)
|
||||||
|
|
||||||
|
new_subscription_entity = TriggerSubscriptionEntity(
|
||||||
|
endpoint=subscription2.endpoint_id,
|
||||||
|
parameters={},
|
||||||
|
properties={},
|
||||||
|
expires_at=-1,
|
||||||
|
)
|
||||||
|
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
|
||||||
|
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
|
||||||
|
|
||||||
|
# Try to rename subscription2 to subscription1's name (should fail)
|
||||||
|
with pytest.raises(ValueError, match="already exists"):
|
||||||
|
TriggerProviderService.rebuild_trigger_subscription(
|
||||||
|
tenant_id=tenant.id,
|
||||||
|
provider_id=provider_id,
|
||||||
|
subscription_id=subscription2.id,
|
||||||
|
credentials={"api_key": "new-key"},
|
||||||
|
parameters={},
|
||||||
|
name=subscription1.name, # Conflicting name
|
||||||
|
)
|
||||||
|
|
@ -705,3 +705,207 @@ class TestWorkflowToolManageService:
|
||||||
db.session.refresh(created_tool)
|
db.session.refresh(created_tool)
|
||||||
assert created_tool.name == first_tool_name
|
assert created_tool.name == first_tool_name
|
||||||
assert created_tool.updated_at is not None
|
assert created_tool.updated_at is not None
|
||||||
|
|
||||||
|
def test_create_workflow_tool_with_file_parameter_default(
|
||||||
|
self, db_session_with_containers, mock_external_service_dependencies
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test workflow tool creation with FILE parameter having a file object as default.
|
||||||
|
|
||||||
|
This test verifies:
|
||||||
|
- FILE parameters can have file object defaults
|
||||||
|
- The default value (dict with id/base64Url) is properly handled
|
||||||
|
- Tool creation succeeds without Pydantic validation errors
|
||||||
|
|
||||||
|
Related issue: Array[File] default value causes Pydantic validation errors.
|
||||||
|
"""
|
||||||
|
fake = Faker()
|
||||||
|
|
||||||
|
# Create test data
|
||||||
|
app, account, workflow = self._create_test_app_and_account(
|
||||||
|
db_session_with_containers, mock_external_service_dependencies
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create workflow graph with a FILE variable that has a default value
|
||||||
|
workflow_graph = {
|
||||||
|
"nodes": [
|
||||||
|
{
|
||||||
|
"id": "start_node",
|
||||||
|
"data": {
|
||||||
|
"type": "start",
|
||||||
|
"variables": [
|
||||||
|
{
|
||||||
|
"variable": "document",
|
||||||
|
"label": "Document",
|
||||||
|
"type": "file",
|
||||||
|
"required": False,
|
||||||
|
"default": {"id": fake.uuid4(), "base64Url": ""},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
workflow.graph = json.dumps(workflow_graph)
|
||||||
|
|
||||||
|
# Setup workflow tool parameters with FILE type
|
||||||
|
file_parameters = [
|
||||||
|
{
|
||||||
|
"name": "document",
|
||||||
|
"description": "Upload a document",
|
||||||
|
"form": "form",
|
||||||
|
"type": "file",
|
||||||
|
"required": False,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
# Execute the method under test
|
||||||
|
# Note: from_db is mocked, so this test primarily validates the parameter configuration
|
||||||
|
result = WorkflowToolManageService.create_workflow_tool(
|
||||||
|
user_id=account.id,
|
||||||
|
tenant_id=account.current_tenant.id,
|
||||||
|
workflow_app_id=app.id,
|
||||||
|
name=fake.word(),
|
||||||
|
label=fake.word(),
|
||||||
|
icon={"type": "emoji", "emoji": "📄"},
|
||||||
|
description=fake.text(max_nb_chars=200),
|
||||||
|
parameters=file_parameters,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the result
|
||||||
|
assert result == {"result": "success"}
|
||||||
|
|
||||||
|
def test_create_workflow_tool_with_files_parameter_default(
|
||||||
|
self, db_session_with_containers, mock_external_service_dependencies
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test workflow tool creation with FILES (Array[File]) parameter having file objects as default.
|
||||||
|
|
||||||
|
This test verifies:
|
||||||
|
- FILES parameters can have a list of file objects as default
|
||||||
|
- The default value (list of dicts with id/base64Url) is properly handled
|
||||||
|
- Tool creation succeeds without Pydantic validation errors
|
||||||
|
|
||||||
|
Related issue: Array[File] default value causes 4 Pydantic validation errors
|
||||||
|
because PluginParameter.default only accepts Union[float, int, str, bool] | None.
|
||||||
|
"""
|
||||||
|
fake = Faker()
|
||||||
|
|
||||||
|
# Create test data
|
||||||
|
app, account, workflow = self._create_test_app_and_account(
|
||||||
|
db_session_with_containers, mock_external_service_dependencies
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create workflow graph with a FILE_LIST variable that has a default value
|
||||||
|
workflow_graph = {
|
||||||
|
"nodes": [
|
||||||
|
{
|
||||||
|
"id": "start_node",
|
||||||
|
"data": {
|
||||||
|
"type": "start",
|
||||||
|
"variables": [
|
||||||
|
{
|
||||||
|
"variable": "documents",
|
||||||
|
"label": "Documents",
|
||||||
|
"type": "file-list",
|
||||||
|
"required": False,
|
||||||
|
"default": [
|
||||||
|
{"id": fake.uuid4(), "base64Url": ""},
|
||||||
|
{"id": fake.uuid4(), "base64Url": ""},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
workflow.graph = json.dumps(workflow_graph)
|
||||||
|
|
||||||
|
# Setup workflow tool parameters with FILES type
|
||||||
|
files_parameters = [
|
||||||
|
{
|
||||||
|
"name": "documents",
|
||||||
|
"description": "Upload multiple documents",
|
||||||
|
"form": "form",
|
||||||
|
"type": "files",
|
||||||
|
"required": False,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
# Execute the method under test
|
||||||
|
# Note: from_db is mocked, so this test primarily validates the parameter configuration
|
||||||
|
result = WorkflowToolManageService.create_workflow_tool(
|
||||||
|
user_id=account.id,
|
||||||
|
tenant_id=account.current_tenant.id,
|
||||||
|
workflow_app_id=app.id,
|
||||||
|
name=fake.word(),
|
||||||
|
label=fake.word(),
|
||||||
|
icon={"type": "emoji", "emoji": "📁"},
|
||||||
|
description=fake.text(max_nb_chars=200),
|
||||||
|
parameters=files_parameters,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the result
|
||||||
|
assert result == {"result": "success"}
|
||||||
|
|
||||||
|
def test_create_workflow_tool_db_commit_before_validation(
|
||||||
|
self, db_session_with_containers, mock_external_service_dependencies
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test that database commit happens before validation, causing DB pollution on validation failure.
|
||||||
|
|
||||||
|
This test verifies the second bug:
|
||||||
|
- WorkflowToolProvider is committed to database BEFORE from_db validation
|
||||||
|
- If validation fails, the record remains in the database
|
||||||
|
- Subsequent attempts fail with "Tool already exists" error
|
||||||
|
|
||||||
|
This demonstrates why we need to validate BEFORE database commit.
|
||||||
|
"""
|
||||||
|
|
||||||
|
fake = Faker()
|
||||||
|
|
||||||
|
# Create test data
|
||||||
|
app, account, workflow = self._create_test_app_and_account(
|
||||||
|
db_session_with_containers, mock_external_service_dependencies
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_name = fake.word()
|
||||||
|
|
||||||
|
# Mock from_db to raise validation error
|
||||||
|
mock_external_service_dependencies["workflow_tool_provider_controller"].from_db.side_effect = ValueError(
|
||||||
|
"Validation failed: default parameter type mismatch"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Attempt to create workflow tool (will fail at validation stage)
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
WorkflowToolManageService.create_workflow_tool(
|
||||||
|
user_id=account.id,
|
||||||
|
tenant_id=account.current_tenant.id,
|
||||||
|
workflow_app_id=app.id,
|
||||||
|
name=tool_name,
|
||||||
|
label=fake.word(),
|
||||||
|
icon={"type": "emoji", "emoji": "🔧"},
|
||||||
|
description=fake.text(max_nb_chars=200),
|
||||||
|
parameters=self._create_test_workflow_tool_parameters(),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "Validation failed" in str(exc_info.value)
|
||||||
|
|
||||||
|
# Verify the tool was NOT created in database
|
||||||
|
# This is the expected behavior (no pollution)
|
||||||
|
from extensions.ext_database import db
|
||||||
|
|
||||||
|
tool_count = (
|
||||||
|
db.session.query(WorkflowToolProvider)
|
||||||
|
.where(
|
||||||
|
WorkflowToolProvider.tenant_id == account.current_tenant.id,
|
||||||
|
WorkflowToolProvider.name == tool_name,
|
||||||
|
)
|
||||||
|
.count()
|
||||||
|
)
|
||||||
|
|
||||||
|
# The record should NOT exist because the transaction should be rolled back
|
||||||
|
# Currently, due to the bug, the record might exist (this test documents the bug)
|
||||||
|
# After the fix, this should always be 0
|
||||||
|
# For now, we document that the record may exist, demonstrating the bug
|
||||||
|
# assert tool_count == 0 # Expected after fix
|
||||||
|
|
|
||||||
|
|
@ -12,10 +12,12 @@ class TestJinja2CodeExecutor(CodeExecutorTestMixin):
|
||||||
_, Jinja2TemplateTransformer = self.jinja2_imports
|
_, Jinja2TemplateTransformer = self.jinja2_imports
|
||||||
|
|
||||||
template = "Hello {{template}}"
|
template = "Hello {{template}}"
|
||||||
|
# Template must be base64 encoded to match the new safe embedding approach
|
||||||
|
template_b64 = base64.b64encode(template.encode("utf-8")).decode("utf-8")
|
||||||
inputs = base64.b64encode(b'{"template": "World"}').decode("utf-8")
|
inputs = base64.b64encode(b'{"template": "World"}').decode("utf-8")
|
||||||
code = (
|
code = (
|
||||||
Jinja2TemplateTransformer.get_runner_script()
|
Jinja2TemplateTransformer.get_runner_script()
|
||||||
.replace(Jinja2TemplateTransformer._code_placeholder, template)
|
.replace(Jinja2TemplateTransformer._template_b64_placeholder, template_b64)
|
||||||
.replace(Jinja2TemplateTransformer._inputs_placeholder, inputs)
|
.replace(Jinja2TemplateTransformer._inputs_placeholder, inputs)
|
||||||
)
|
)
|
||||||
result = CodeExecutor.execute_code(
|
result = CodeExecutor.execute_code(
|
||||||
|
|
@ -37,6 +39,34 @@ class TestJinja2CodeExecutor(CodeExecutorTestMixin):
|
||||||
_, Jinja2TemplateTransformer = self.jinja2_imports
|
_, Jinja2TemplateTransformer = self.jinja2_imports
|
||||||
|
|
||||||
runner_script = Jinja2TemplateTransformer.get_runner_script()
|
runner_script = Jinja2TemplateTransformer.get_runner_script()
|
||||||
assert runner_script.count(Jinja2TemplateTransformer._code_placeholder) == 1
|
assert runner_script.count(Jinja2TemplateTransformer._template_b64_placeholder) == 1
|
||||||
assert runner_script.count(Jinja2TemplateTransformer._inputs_placeholder) == 1
|
assert runner_script.count(Jinja2TemplateTransformer._inputs_placeholder) == 1
|
||||||
assert runner_script.count(Jinja2TemplateTransformer._result_tag) == 2
|
assert runner_script.count(Jinja2TemplateTransformer._result_tag) == 2
|
||||||
|
|
||||||
|
def test_jinja2_template_with_special_characters(self, flask_app_with_containers):
|
||||||
|
"""
|
||||||
|
Test that templates with special characters (quotes, newlines) render correctly.
|
||||||
|
This is a regression test for issue #26818 where textarea pre-fill values
|
||||||
|
containing special characters would break template rendering.
|
||||||
|
"""
|
||||||
|
CodeExecutor, CodeLanguage = self.code_executor_imports
|
||||||
|
|
||||||
|
# Template with triple quotes, single quotes, double quotes, and newlines
|
||||||
|
template = """<html>
|
||||||
|
<body>
|
||||||
|
<input value="{{ task.get('Task ID', '') }}"/>
|
||||||
|
<textarea>{{ task.get('Issues', 'No issues reported') }}</textarea>
|
||||||
|
<p>Status: "{{ status }}"</p>
|
||||||
|
<pre>'''code block'''</pre>
|
||||||
|
</body>
|
||||||
|
</html>"""
|
||||||
|
inputs = {"task": {"Task ID": "TASK-123", "Issues": "Line 1\nLine 2\nLine 3"}, "status": "completed"}
|
||||||
|
|
||||||
|
result = CodeExecutor.execute_workflow_code_template(language=CodeLanguage.JINJA2, code=template, inputs=inputs)
|
||||||
|
|
||||||
|
# Verify the template rendered correctly with all special characters
|
||||||
|
output = result["result"]
|
||||||
|
assert 'value="TASK-123"' in output
|
||||||
|
assert "<textarea>Line 1\nLine 2\nLine 3</textarea>" in output
|
||||||
|
assert 'Status: "completed"' in output
|
||||||
|
assert "'''code block'''" in output
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,69 @@
|
||||||
|
import builtins
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from flask.views import MethodView as FlaskMethodView
|
||||||
|
|
||||||
|
_NEEDS_METHOD_VIEW_CLEANUP = False
|
||||||
|
if not hasattr(builtins, "MethodView"):
|
||||||
|
builtins.MethodView = FlaskMethodView
|
||||||
|
_NEEDS_METHOD_VIEW_CLEANUP = True
|
||||||
|
from controllers.common.fields import Parameters, Site
|
||||||
|
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||||
|
from models.model import IconType
|
||||||
|
|
||||||
|
|
||||||
|
def test_parameters_model_round_trip():
|
||||||
|
parameters = get_parameters_from_feature_dict(features_dict={}, user_input_form=[])
|
||||||
|
|
||||||
|
model = Parameters.model_validate(parameters)
|
||||||
|
|
||||||
|
assert model.model_dump(mode="json") == parameters
|
||||||
|
|
||||||
|
|
||||||
|
def test_site_icon_url_uses_signed_url_for_image_icon():
|
||||||
|
site = SimpleNamespace(
|
||||||
|
title="Example",
|
||||||
|
chat_color_theme=None,
|
||||||
|
chat_color_theme_inverted=False,
|
||||||
|
icon_type=IconType.IMAGE,
|
||||||
|
icon="file-id",
|
||||||
|
icon_background=None,
|
||||||
|
description=None,
|
||||||
|
copyright=None,
|
||||||
|
privacy_policy=None,
|
||||||
|
custom_disclaimer=None,
|
||||||
|
default_language="en-US",
|
||||||
|
show_workflow_steps=True,
|
||||||
|
use_icon_as_answer_icon=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("controllers.common.fields.file_helpers.get_signed_file_url", return_value="signed") as mock_helper:
|
||||||
|
model = Site.model_validate(site)
|
||||||
|
|
||||||
|
assert model.icon_url == "signed"
|
||||||
|
mock_helper.assert_called_once_with("file-id")
|
||||||
|
|
||||||
|
|
||||||
|
def test_site_icon_url_is_none_for_non_image_icon():
|
||||||
|
site = SimpleNamespace(
|
||||||
|
title="Example",
|
||||||
|
chat_color_theme=None,
|
||||||
|
chat_color_theme_inverted=False,
|
||||||
|
icon_type=IconType.EMOJI,
|
||||||
|
icon="file-id",
|
||||||
|
icon_background=None,
|
||||||
|
description=None,
|
||||||
|
copyright=None,
|
||||||
|
privacy_policy=None,
|
||||||
|
custom_disclaimer=None,
|
||||||
|
default_language="en-US",
|
||||||
|
show_workflow_steps=True,
|
||||||
|
use_icon_as_answer_icon=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("controllers.common.fields.file_helpers.get_signed_file_url") as mock_helper:
|
||||||
|
model = Site.model_validate(site)
|
||||||
|
|
||||||
|
assert model.icon_url is None
|
||||||
|
mock_helper.assert_not_called()
|
||||||
|
|
@ -0,0 +1,254 @@
|
||||||
|
"""
|
||||||
|
Unit tests for XSS prevention in App payloads.
|
||||||
|
|
||||||
|
This test module validates that HTML tags, JavaScript, and other potentially
|
||||||
|
dangerous content are rejected in App names and descriptions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from controllers.console.app.app import CopyAppPayload, CreateAppPayload, UpdateAppPayload
|
||||||
|
|
||||||
|
|
||||||
|
class TestXSSPreventionUnit:
|
||||||
|
"""Unit tests for XSS prevention in App payloads."""
|
||||||
|
|
||||||
|
def test_create_app_valid_names(self):
|
||||||
|
"""Test CreateAppPayload with valid app names."""
|
||||||
|
# Normal app names should be valid
|
||||||
|
valid_names = [
|
||||||
|
"My App",
|
||||||
|
"Test App 123",
|
||||||
|
"App with - dash",
|
||||||
|
"App with _ underscore",
|
||||||
|
"App with + plus",
|
||||||
|
"App with () parentheses",
|
||||||
|
"App with [] brackets",
|
||||||
|
"App with {} braces",
|
||||||
|
"App with ! exclamation",
|
||||||
|
"App with @ at",
|
||||||
|
"App with # hash",
|
||||||
|
"App with $ dollar",
|
||||||
|
"App with % percent",
|
||||||
|
"App with ^ caret",
|
||||||
|
"App with & ampersand",
|
||||||
|
"App with * asterisk",
|
||||||
|
"Unicode: 测试应用",
|
||||||
|
"Emoji: 🤖",
|
||||||
|
"Mixed: Test 测试 123",
|
||||||
|
]
|
||||||
|
|
||||||
|
for name in valid_names:
|
||||||
|
payload = CreateAppPayload(
|
||||||
|
name=name,
|
||||||
|
mode="chat",
|
||||||
|
)
|
||||||
|
assert payload.name == name
|
||||||
|
|
||||||
|
def test_create_app_xss_script_tags(self):
|
||||||
|
"""Test CreateAppPayload rejects script tags."""
|
||||||
|
xss_payloads = [
|
||||||
|
"<script>alert(document.cookie)</script>",
|
||||||
|
"<Script>alert(1)</Script>",
|
||||||
|
"<SCRIPT>alert('XSS')</SCRIPT>",
|
||||||
|
"<script>alert(String.fromCharCode(88,83,83))</script>",
|
||||||
|
"<script src='evil.js'></script>",
|
||||||
|
"<script>document.location='http://evil.com'</script>",
|
||||||
|
]
|
||||||
|
|
||||||
|
for name in xss_payloads:
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
CreateAppPayload(name=name, mode="chat")
|
||||||
|
assert "invalid characters or patterns" in str(exc_info.value).lower()
|
||||||
|
|
||||||
|
def test_create_app_xss_iframe_tags(self):
|
||||||
|
"""Test CreateAppPayload rejects iframe tags."""
|
||||||
|
xss_payloads = [
|
||||||
|
"<iframe src='evil.com'></iframe>",
|
||||||
|
"<Iframe srcdoc='<script>alert(1)</script>'></iframe>",
|
||||||
|
"<IFRAME src='javascript:alert(1)'></iframe>",
|
||||||
|
]
|
||||||
|
|
||||||
|
for name in xss_payloads:
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
CreateAppPayload(name=name, mode="chat")
|
||||||
|
assert "invalid characters or patterns" in str(exc_info.value).lower()
|
||||||
|
|
||||||
|
def test_create_app_xss_javascript_protocol(self):
|
||||||
|
"""Test CreateAppPayload rejects javascript: protocol."""
|
||||||
|
xss_payloads = [
|
||||||
|
"javascript:alert(1)",
|
||||||
|
"JAVASCRIPT:alert(1)",
|
||||||
|
"JavaScript:alert(document.cookie)",
|
||||||
|
"javascript:void(0)",
|
||||||
|
"javascript://comment%0Aalert(1)",
|
||||||
|
]
|
||||||
|
|
||||||
|
for name in xss_payloads:
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
CreateAppPayload(name=name, mode="chat")
|
||||||
|
assert "invalid characters or patterns" in str(exc_info.value).lower()
|
||||||
|
|
||||||
|
def test_create_app_xss_svg_onload(self):
|
||||||
|
"""Test CreateAppPayload rejects SVG with onload."""
|
||||||
|
xss_payloads = [
|
||||||
|
"<svg onload=alert(1)>",
|
||||||
|
"<SVG ONLOAD=alert(1)>",
|
||||||
|
"<svg/x/onload=alert(1)>",
|
||||||
|
]
|
||||||
|
|
||||||
|
for name in xss_payloads:
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
CreateAppPayload(name=name, mode="chat")
|
||||||
|
assert "invalid characters or patterns" in str(exc_info.value).lower()
|
||||||
|
|
||||||
|
def test_create_app_xss_event_handlers(self):
|
||||||
|
"""Test CreateAppPayload rejects HTML event handlers."""
|
||||||
|
xss_payloads = [
|
||||||
|
"<div onclick=alert(1)>",
|
||||||
|
"<img onerror=alert(1)>",
|
||||||
|
"<body onload=alert(1)>",
|
||||||
|
"<input onfocus=alert(1)>",
|
||||||
|
"<a onmouseover=alert(1)>",
|
||||||
|
"<DIV ONCLICK=alert(1)>",
|
||||||
|
"<img src=x onerror=alert(1)>",
|
||||||
|
]
|
||||||
|
|
||||||
|
for name in xss_payloads:
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
CreateAppPayload(name=name, mode="chat")
|
||||||
|
assert "invalid characters or patterns" in str(exc_info.value).lower()
|
||||||
|
|
||||||
|
def test_create_app_xss_object_embed(self):
|
||||||
|
"""Test CreateAppPayload rejects object and embed tags."""
|
||||||
|
xss_payloads = [
|
||||||
|
"<object data='evil.swf'></object>",
|
||||||
|
"<embed src='evil.swf'>",
|
||||||
|
"<OBJECT data='javascript:alert(1)'></OBJECT>",
|
||||||
|
]
|
||||||
|
|
||||||
|
for name in xss_payloads:
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
CreateAppPayload(name=name, mode="chat")
|
||||||
|
assert "invalid characters or patterns" in str(exc_info.value).lower()
|
||||||
|
|
||||||
|
def test_create_app_xss_link_javascript(self):
|
||||||
|
"""Test CreateAppPayload rejects link tags with javascript."""
|
||||||
|
xss_payloads = [
|
||||||
|
"<link href='javascript:alert(1)'>",
|
||||||
|
"<LINK HREF='javascript:alert(1)'>",
|
||||||
|
]
|
||||||
|
|
||||||
|
for name in xss_payloads:
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
CreateAppPayload(name=name, mode="chat")
|
||||||
|
assert "invalid characters or patterns" in str(exc_info.value).lower()
|
||||||
|
|
||||||
|
def test_create_app_xss_in_description(self):
|
||||||
|
"""Test CreateAppPayload rejects XSS in description."""
|
||||||
|
xss_descriptions = [
|
||||||
|
"<script>alert(1)</script>",
|
||||||
|
"javascript:alert(1)",
|
||||||
|
"<img onerror=alert(1)>",
|
||||||
|
]
|
||||||
|
|
||||||
|
for description in xss_descriptions:
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
CreateAppPayload(
|
||||||
|
name="Valid Name",
|
||||||
|
mode="chat",
|
||||||
|
description=description,
|
||||||
|
)
|
||||||
|
assert "invalid characters or patterns" in str(exc_info.value).lower()
|
||||||
|
|
||||||
|
def test_create_app_valid_descriptions(self):
|
||||||
|
"""Test CreateAppPayload with valid descriptions."""
|
||||||
|
valid_descriptions = [
|
||||||
|
"A simple description",
|
||||||
|
"Description with < and > symbols",
|
||||||
|
"Description with & ampersand",
|
||||||
|
"Description with 'quotes' and \"double quotes\"",
|
||||||
|
"Description with / slashes",
|
||||||
|
"Description with \\ backslashes",
|
||||||
|
"Description with ; semicolons",
|
||||||
|
"Unicode: 这是一个描述",
|
||||||
|
"Emoji: 🎉🚀",
|
||||||
|
]
|
||||||
|
|
||||||
|
for description in valid_descriptions:
|
||||||
|
payload = CreateAppPayload(
|
||||||
|
name="Valid App Name",
|
||||||
|
mode="chat",
|
||||||
|
description=description,
|
||||||
|
)
|
||||||
|
assert payload.description == description
|
||||||
|
|
||||||
|
def test_create_app_none_description(self):
|
||||||
|
"""Test CreateAppPayload with None description."""
|
||||||
|
payload = CreateAppPayload(
|
||||||
|
name="Valid App Name",
|
||||||
|
mode="chat",
|
||||||
|
description=None,
|
||||||
|
)
|
||||||
|
assert payload.description is None
|
||||||
|
|
||||||
|
def test_update_app_xss_prevention(self):
|
||||||
|
"""Test UpdateAppPayload also prevents XSS."""
|
||||||
|
xss_names = [
|
||||||
|
"<script>alert(1)</script>",
|
||||||
|
"javascript:alert(1)",
|
||||||
|
"<img onerror=alert(1)>",
|
||||||
|
]
|
||||||
|
|
||||||
|
for name in xss_names:
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
UpdateAppPayload(name=name)
|
||||||
|
assert "invalid characters or patterns" in str(exc_info.value).lower()
|
||||||
|
|
||||||
|
def test_update_app_valid_names(self):
|
||||||
|
"""Test UpdateAppPayload with valid names."""
|
||||||
|
payload = UpdateAppPayload(name="Valid Updated Name")
|
||||||
|
assert payload.name == "Valid Updated Name"
|
||||||
|
|
||||||
|
def test_copy_app_xss_prevention(self):
|
||||||
|
"""Test CopyAppPayload also prevents XSS."""
|
||||||
|
xss_names = [
|
||||||
|
"<script>alert(1)</script>",
|
||||||
|
"javascript:alert(1)",
|
||||||
|
"<img onerror=alert(1)>",
|
||||||
|
]
|
||||||
|
|
||||||
|
for name in xss_names:
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
CopyAppPayload(name=name)
|
||||||
|
assert "invalid characters or patterns" in str(exc_info.value).lower()
|
||||||
|
|
||||||
|
def test_copy_app_valid_names(self):
|
||||||
|
"""Test CopyAppPayload with valid names."""
|
||||||
|
payload = CopyAppPayload(name="Valid Copy Name")
|
||||||
|
assert payload.name == "Valid Copy Name"
|
||||||
|
|
||||||
|
def test_copy_app_none_name(self):
|
||||||
|
"""Test CopyAppPayload with None name (should be allowed)."""
|
||||||
|
payload = CopyAppPayload(name=None)
|
||||||
|
assert payload.name is None
|
||||||
|
|
||||||
|
def test_edge_case_angle_brackets_content(self):
|
||||||
|
"""Test that angle brackets with actual content are rejected."""
|
||||||
|
# Angle brackets without valid HTML-like patterns should be checked
|
||||||
|
# The regex pattern <.*?on\w+\s*= should catch event handlers
|
||||||
|
# But let's verify other patterns too
|
||||||
|
|
||||||
|
# Valid: angle brackets used as symbols (not matched by our patterns)
|
||||||
|
# Our patterns specifically look for dangerous constructs
|
||||||
|
|
||||||
|
# Invalid: actual HTML tags with event handlers
|
||||||
|
invalid_names = [
|
||||||
|
"<div onclick=xss>",
|
||||||
|
"<img src=x onerror=alert(1)>",
|
||||||
|
]
|
||||||
|
|
||||||
|
for name in invalid_names:
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
CreateAppPayload(name=name, mode="chat")
|
||||||
|
assert "invalid characters or patterns" in str(exc_info.value).lower()
|
||||||
|
|
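The payload models these tests exercise are not part of this excerpt, so the validator itself is not shown. As a rough sketch only — assuming a Pydantic v2 `field_validator`, and with `CreateAppPayloadSketch` and `_DANGEROUS_PATTERNS` as hypothetical names — validation consistent with the behaviour asserted above (script tags, `javascript:` URLs, inline event handlers, and object/link tags all rejected with a message containing "invalid characters or patterns") could look like this:

```python
# Hypothetical sketch, not the actual Dify implementation.
import re

from pydantic import BaseModel, field_validator

_DANGEROUS_PATTERNS = [
    re.compile(r"<\s*script", re.IGNORECASE),              # <script> tags
    re.compile(r"javascript\s*:", re.IGNORECASE),          # javascript: URLs
    re.compile(r"<.*?on\w+\s*=", re.IGNORECASE),           # inline event handlers, e.g. onerror=
    re.compile(r"<\s*(iframe|object|embed|link)", re.IGNORECASE),
]


class CreateAppPayloadSketch(BaseModel):
    name: str
    mode: str
    description: str | None = None

    @field_validator("name", "description")
    @classmethod
    def _reject_xss(cls, value: str | None) -> str | None:
        if value is None:
            return value
        for pattern in _DANGEROUS_PATTERNS:
            if pattern.search(value):
                raise ValueError("contains invalid characters or patterns")
        return value
```

Because Pydantic's `ValidationError` subclasses `ValueError`, the `pytest.raises(ValueError)` assertions above also catch failures raised this way, and the error string carries the validator's message.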
@@ -171,7 +171,7 @@ class TestOAuthCallback:
     ):
         mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
         mock_get_providers.return_value = {"github": oauth_setup["provider"]}
-        mock_generate_account.return_value = oauth_setup["account"]
+        mock_generate_account.return_value = (oauth_setup["account"], True)
         mock_account_service.login.return_value = oauth_setup["token_pair"]
 
         with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
@@ -179,7 +179,7 @@ class TestOAuthCallback:
 
         oauth_setup["provider"].get_access_token.assert_called_once_with("test_code")
         oauth_setup["provider"].get_user_info.assert_called_once_with("access_token")
-        mock_redirect.assert_called_once_with("http://localhost:3000")
+        mock_redirect.assert_called_once_with("http://localhost:3000?oauth_new_user=true")
 
     @pytest.mark.parametrize(
         ("exception", "expected_error"),
@@ -223,7 +223,7 @@ class TestOAuthCallback:
             # This documents actual behavior. See test_defensive_check_for_closed_account_status for details
             (
                 AccountStatus.CLOSED.value,
-                "http://localhost:3000",
+                "http://localhost:3000?oauth_new_user=false",
             ),
         ],
     )
@@ -260,7 +260,7 @@ class TestOAuthCallback:
         account = MagicMock()
         account.status = account_status
         account.id = "123"
-        mock_generate_account.return_value = account
+        mock_generate_account.return_value = (account, False)
 
         # Mock login for CLOSED status
         mock_token_pair = MagicMock()
@@ -296,7 +296,7 @@ class TestOAuthCallback:
 
         mock_account = MagicMock()
         mock_account.status = AccountStatus.PENDING
-        mock_generate_account.return_value = mock_account
+        mock_generate_account.return_value = (mock_account, False)
 
         mock_token_pair = MagicMock()
         mock_token_pair.access_token = "jwt_access_token"
@@ -360,7 +360,7 @@ class TestOAuthCallback:
         closed_account.status = AccountStatus.CLOSED
         closed_account.id = "123"
         closed_account.name = "Closed Account"
-        mock_generate_account.return_value = closed_account
+        mock_generate_account.return_value = (closed_account, False)
 
         # Mock successful login (current behavior)
         mock_token_pair = MagicMock()
@@ -374,7 +374,7 @@ class TestOAuthCallback:
         resource.get("github")
 
         # Verify current behavior: login succeeds (this is NOT ideal)
-        mock_redirect.assert_called_once_with("http://localhost:3000")
+        mock_redirect.assert_called_once_with("http://localhost:3000?oauth_new_user=false")
         mock_account_service.login.assert_called_once()
 
         # Document expected behavior in comments:
@@ -458,8 +458,9 @@ class TestAccountGeneration:
             with pytest.raises(AccountRegisterError):
                 _generate_account("github", user_info)
         else:
-            result = _generate_account("github", user_info)
+            result, oauth_new_user = _generate_account("github", user_info)
             assert result == mock_account
+            assert oauth_new_user == should_create
 
             if should_create:
                 mock_register_service.register.assert_called_once_with(
@@ -490,9 +491,10 @@ class TestAccountGeneration:
         mock_tenant_service.create_tenant.return_value = mock_new_tenant
 
         with app.test_request_context(headers={"Accept-Language": "en-US,en;q=0.9"}):
-            result = _generate_account("github", user_info)
+            result, oauth_new_user = _generate_account("github", user_info)
 
         assert result == mock_account
+        assert oauth_new_user is False
         mock_tenant_service.create_tenant.assert_called_once_with("Test User's Workspace")
         mock_tenant_service.create_tenant_member.assert_called_once_with(
             mock_new_tenant, mock_account, role="owner"
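The hunks above change `_generate_account` to return an `(account, is_new_user)` tuple and extend the console redirect with an `oauth_new_user` query parameter. The callback implementation itself is not shown in this excerpt; a minimal sketch of how that flag could be appended to the redirect URL (the helper name `build_oauth_redirect` is hypothetical) might look like this:

```python
# Hypothetical sketch of the redirect construction the updated assertions expect.
from urllib.parse import urlencode


def build_oauth_redirect(console_web_url: str, is_new_user: bool) -> str:
    # The tests assert lowercase "true"/"false", i.e. JSON-style booleans.
    query = urlencode({"oauth_new_user": "true" if is_new_user else "false"})
    return f"{console_web_url}?{query}"


assert build_oauth_redirect("http://localhost:3000", True) == "http://localhost:3000?oauth_new_user=true"
assert build_oauth_redirect("http://localhost:3000", False) == "http://localhost:3000?oauth_new_user=false"
```

`urlencode` is used here for clarity; for a single boolean flag, plain string formatting would produce the same URL.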
@@ -0,0 +1,145 @@
"""Unit tests for load balancing credential validation APIs."""

from __future__ import annotations

import builtins
import importlib
import sys
from types import SimpleNamespace
from unittest.mock import MagicMock

import pytest
from flask import Flask
from flask.views import MethodView
from werkzeug.exceptions import Forbidden

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError

if not hasattr(builtins, "MethodView"):
    builtins.MethodView = MethodView  # type: ignore[attr-defined]

from models.account import TenantAccountRole


@pytest.fixture
def app() -> Flask:
    app = Flask(__name__)
    app.config["TESTING"] = True
    return app


@pytest.fixture
def load_balancing_module(monkeypatch: pytest.MonkeyPatch):
    """Reload controller module with lightweight decorators for testing."""

    from controllers.console import console_ns, wraps
    from libs import login

    def _noop(func):
        return func

    monkeypatch.setattr(login, "login_required", _noop)
    monkeypatch.setattr(wraps, "setup_required", _noop)
    monkeypatch.setattr(wraps, "account_initialization_required", _noop)

    def _noop_route(*args, **kwargs):  # type: ignore[override]
        def _decorator(cls):
            return cls

        return _decorator

    monkeypatch.setattr(console_ns, "route", _noop_route)

    module_name = "controllers.console.workspace.load_balancing_config"
    sys.modules.pop(module_name, None)
    module = importlib.import_module(module_name)
    return module


def _mock_user(role: TenantAccountRole) -> SimpleNamespace:
    return SimpleNamespace(current_role=role)


def _prepare_context(module, monkeypatch: pytest.MonkeyPatch, role=TenantAccountRole.OWNER):
    user = _mock_user(role)
    monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "tenant-123"))
    mock_service = MagicMock()
    monkeypatch.setattr(module, "ModelLoadBalancingService", lambda: mock_service)
    return mock_service


def _request_payload():
    return {"model": "gpt-4o", "model_type": ModelType.LLM, "credentials": {"api_key": "sk-***"}}


def test_validate_credentials_success(app: Flask, load_balancing_module, monkeypatch: pytest.MonkeyPatch):
    service = _prepare_context(load_balancing_module, monkeypatch)

    with app.test_request_context(
        "/workspaces/current/model-providers/openai/models/load-balancing-configs/credentials-validate",
        method="POST",
        json=_request_payload(),
    ):
        response = load_balancing_module.LoadBalancingCredentialsValidateApi().post(provider="openai")

    assert response == {"result": "success"}
    service.validate_load_balancing_credentials.assert_called_once_with(
        tenant_id="tenant-123",
        provider="openai",
        model="gpt-4o",
        model_type=ModelType.LLM,
        credentials={"api_key": "sk-***"},
    )


def test_validate_credentials_returns_error_message(app: Flask, load_balancing_module, monkeypatch: pytest.MonkeyPatch):
    service = _prepare_context(load_balancing_module, monkeypatch)
    service.validate_load_balancing_credentials.side_effect = CredentialsValidateFailedError("invalid credentials")

    with app.test_request_context(
        "/workspaces/current/model-providers/openai/models/load-balancing-configs/credentials-validate",
        method="POST",
        json=_request_payload(),
    ):
        response = load_balancing_module.LoadBalancingCredentialsValidateApi().post(provider="openai")

    assert response == {"result": "error", "error": "invalid credentials"}


def test_validate_credentials_requires_privileged_role(
    app: Flask, load_balancing_module, monkeypatch: pytest.MonkeyPatch
):
    _prepare_context(load_balancing_module, monkeypatch, role=TenantAccountRole.NORMAL)

    with app.test_request_context(
        "/workspaces/current/model-providers/openai/models/load-balancing-configs/credentials-validate",
        method="POST",
        json=_request_payload(),
    ):
        api = load_balancing_module.LoadBalancingCredentialsValidateApi()
        with pytest.raises(Forbidden):
            api.post(provider="openai")


def test_validate_credentials_with_config_id(app: Flask, load_balancing_module, monkeypatch: pytest.MonkeyPatch):
    service = _prepare_context(load_balancing_module, monkeypatch)

    with app.test_request_context(
        "/workspaces/current/model-providers/openai/models/load-balancing-configs/cfg-1/credentials-validate",
        method="POST",
        json=_request_payload(),
    ):
        response = load_balancing_module.LoadBalancingConfigCredentialsValidateApi().post(
            provider="openai", config_id="cfg-1"
        )

    assert response == {"result": "success"}
    service.validate_load_balancing_credentials.assert_called_once_with(
        tenant_id="tenant-123",
        provider="openai",
        model="gpt-4o",
        model_type=ModelType.LLM,
        credentials={"api_key": "sk-***"},
        config_id="cfg-1",
    )
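These tests pin down the contract of the credentials-validate endpoints without showing the controller itself. The following is a self-contained sketch of that contract, not Dify's actual implementation: the stand-in helpers mirror the names the tests monkeypatch, the role check is an assumption (the tests only require that OWNER succeeds and NORMAL gets a 403), and only the request/response shapes and the keyword arguments passed to `validate_load_balancing_credentials` are taken from the assertions above.

```python
# Hypothetical, self-contained sketch of the endpoint behaviour exercised above.
from __future__ import annotations

from flask import request
from werkzeug.exceptions import Forbidden


class CredentialsValidateFailedError(Exception):
    """Stand-in; the real class lives in core.model_runtime.errors.validate."""


def current_account_with_tenant():
    """Stand-in; replaced via monkeypatch in the tests above."""
    raise NotImplementedError


class ModelLoadBalancingService:
    """Stand-in; replaced via monkeypatch in the tests above."""

    def validate_load_balancing_credentials(self, **kwargs):
        raise NotImplementedError


class LoadBalancingCredentialsValidateApiSketch:
    def post(self, provider: str, config_id: str | None = None):
        user, tenant_id = current_account_with_tenant()

        # Assumption: only privileged tenant roles may validate credentials.
        role = getattr(user.current_role, "value", user.current_role)
        if role not in {"owner", "admin"}:
            raise Forbidden()

        # The real controller likely coerces model_type back into ModelType.
        args = request.get_json()
        kwargs = {
            "tenant_id": tenant_id,
            "provider": provider,
            "model": args["model"],
            "model_type": args["model_type"],
            "credentials": args["credentials"],
        }
        if config_id is not None:
            kwargs["config_id"] = config_id

        try:
            ModelLoadBalancingService().validate_load_balancing_credentials(**kwargs)
        except CredentialsValidateFailedError as e:
            return {"result": "error", "error": str(e)}
        return {"result": "success"}
```

The `config_id` keyword is only forwarded when the config-scoped route supplies it, which matches the two `assert_called_once_with` signatures in the tests.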
Some files were not shown because too many files have changed in this diff.