mirror of https://github.com/langgenius/dify.git
Merge branch 'main' into feature/plugin-credential-deletion-option
This commit is contained in:
commit
29690062c6
|
|
@ -0,0 +1,483 @@
|
||||||
|
---
|
||||||
|
name: component-refactoring
|
||||||
|
description: Refactor high-complexity React components in Dify frontend. Use when `pnpm analyze-component --json` shows complexity > 50 or lineCount > 300, when the user asks for code splitting, hook extraction, or complexity reduction, or when `pnpm analyze-component` warns to refactor before testing; avoid for simple/well-structured components, third-party wrappers, or when the user explicitly wants testing without refactoring.
|
||||||
|
---
|
||||||
|
|
||||||
|
# Dify Component Refactoring Skill
|
||||||
|
|
||||||
|
Refactor high-complexity React components in the Dify frontend codebase with the patterns and workflow below.
|
||||||
|
|
||||||
|
> **Complexity Threshold**: Components with complexity > 50 (measured by `pnpm analyze-component`) should be refactored before testing.
|
||||||
|
|
||||||
|
## Quick Reference
|
||||||
|
|
||||||
|
### Commands (run from `web/`)
|
||||||
|
|
||||||
|
Use paths relative to `web/` (e.g., `app/components/...`).
|
||||||
|
Use `refactor-component` for refactoring prompts and `analyze-component` for testing prompts and metrics.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd web
|
||||||
|
|
||||||
|
# Generate refactoring prompt
|
||||||
|
pnpm refactor-component <path>
|
||||||
|
|
||||||
|
# Output refactoring analysis as JSON
|
||||||
|
pnpm refactor-component <path> --json
|
||||||
|
|
||||||
|
# Generate testing prompt (after refactoring)
|
||||||
|
pnpm analyze-component <path>
|
||||||
|
|
||||||
|
# Output testing analysis as JSON
|
||||||
|
pnpm analyze-component <path> --json
|
||||||
|
```
|
||||||
|
|
||||||
|
### Complexity Analysis
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Analyze component complexity
|
||||||
|
pnpm analyze-component <path> --json
|
||||||
|
|
||||||
|
# Key metrics to check:
|
||||||
|
# - complexity: normalized score 0-100 (target < 50)
|
||||||
|
# - maxComplexity: highest single function complexity
|
||||||
|
# - lineCount: total lines (target < 300)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Complexity Score Interpretation
|
||||||
|
|
||||||
|
| Score | Level | Action |
|
||||||
|
|-------|-------|--------|
|
||||||
|
| 0-25 | 🟢 Simple | Ready for testing |
|
||||||
|
| 26-50 | 🟡 Medium | Consider minor refactoring |
|
||||||
|
| 51-75 | 🟠 Complex | **Refactor before testing** |
|
||||||
|
| 76-100 | 🔴 Very Complex | **Must refactor** |
|
||||||
|
|
||||||
|
## Core Refactoring Patterns
|
||||||
|
|
||||||
|
### Pattern 1: Extract Custom Hooks
|
||||||
|
|
||||||
|
**When**: Component has complex state management, multiple `useState`/`useEffect`, or business logic mixed with UI.
|
||||||
|
|
||||||
|
**Dify Convention**: Place hooks in a `hooks/` subdirectory or alongside the component as `use-<feature>.ts`.
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// ❌ Before: Complex state logic in component
|
||||||
|
const Configuration: FC = () => {
|
||||||
|
const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
|
||||||
|
const [datasetConfigs, setDatasetConfigs] = useState<DatasetConfigs>(...)
|
||||||
|
const [completionParams, setCompletionParams] = useState<FormValue>({})
|
||||||
|
|
||||||
|
// 50+ lines of state management logic...
|
||||||
|
|
||||||
|
return <div>...</div>
|
||||||
|
}
|
||||||
|
|
||||||
|
// ✅ After: Extract to custom hook
|
||||||
|
// hooks/use-model-config.ts
|
||||||
|
export const useModelConfig = (appId: string) => {
|
||||||
|
const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
|
||||||
|
const [completionParams, setCompletionParams] = useState<FormValue>({})
|
||||||
|
|
||||||
|
// Related state management logic here
|
||||||
|
|
||||||
|
return { modelConfig, setModelConfig, completionParams, setCompletionParams }
|
||||||
|
}
|
||||||
|
|
||||||
|
// Component becomes cleaner
|
||||||
|
const Configuration: FC = () => {
|
||||||
|
const { modelConfig, setModelConfig } = useModelConfig(appId)
|
||||||
|
return <div>...</div>
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Dify Examples**:
|
||||||
|
- `web/app/components/app/configuration/hooks/use-advanced-prompt-config.ts`
|
||||||
|
- `web/app/components/app/configuration/debug/hooks.tsx`
|
||||||
|
- `web/app/components/workflow/hooks/use-workflow.ts`
|
||||||
|
|
||||||
|
### Pattern 2: Extract Sub-Components
|
||||||
|
|
||||||
|
**When**: Single component has multiple UI sections, conditional rendering blocks, or repeated patterns.
|
||||||
|
|
||||||
|
**Dify Convention**: Place sub-components in subdirectories or as separate files in the same directory.
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// ❌ Before: Monolithic JSX with multiple sections
|
||||||
|
const AppInfo = () => {
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
{/* 100 lines of header UI */}
|
||||||
|
{/* 100 lines of operations UI */}
|
||||||
|
{/* 100 lines of modals */}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ✅ After: Split into focused components
|
||||||
|
// app-info/
|
||||||
|
// ├── index.tsx (orchestration only)
|
||||||
|
// ├── app-header.tsx (header UI)
|
||||||
|
// ├── app-operations.tsx (operations UI)
|
||||||
|
// └── app-modals.tsx (modal management)
|
||||||
|
|
||||||
|
const AppInfo = () => {
|
||||||
|
const { showModal, setShowModal } = useAppInfoModals()
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
<AppHeader appDetail={appDetail} />
|
||||||
|
<AppOperations onAction={handleAction} />
|
||||||
|
<AppModals show={showModal} onClose={() => setShowModal(null)} />
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Dify Examples**:
|
||||||
|
- `web/app/components/app/configuration/` directory structure
|
||||||
|
- `web/app/components/workflow/nodes/` per-node organization
|
||||||
|
|
||||||
|
### Pattern 3: Simplify Conditional Logic
|
||||||
|
|
||||||
|
**When**: Deep nesting (> 3 levels), complex ternaries, or multiple `if/else` chains.
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// ❌ Before: Deeply nested conditionals
|
||||||
|
const Template = useMemo(() => {
|
||||||
|
if (appDetail?.mode === AppModeEnum.CHAT) {
|
||||||
|
switch (locale) {
|
||||||
|
case LanguagesSupported[1]:
|
||||||
|
return <TemplateChatZh />
|
||||||
|
case LanguagesSupported[7]:
|
||||||
|
return <TemplateChatJa />
|
||||||
|
default:
|
||||||
|
return <TemplateChatEn />
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (appDetail?.mode === AppModeEnum.ADVANCED_CHAT) {
|
||||||
|
// Another 15 lines...
|
||||||
|
}
|
||||||
|
// More conditions...
|
||||||
|
}, [appDetail, locale])
|
||||||
|
|
||||||
|
// ✅ After: Use lookup tables + early returns
|
||||||
|
const TEMPLATE_MAP = {
|
||||||
|
[AppModeEnum.CHAT]: {
|
||||||
|
[LanguagesSupported[1]]: TemplateChatZh,
|
||||||
|
[LanguagesSupported[7]]: TemplateChatJa,
|
||||||
|
default: TemplateChatEn,
|
||||||
|
},
|
||||||
|
[AppModeEnum.ADVANCED_CHAT]: {
|
||||||
|
[LanguagesSupported[1]]: TemplateAdvancedChatZh,
|
||||||
|
// ...
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
const Template = useMemo(() => {
|
||||||
|
const modeTemplates = TEMPLATE_MAP[appDetail?.mode]
|
||||||
|
if (!modeTemplates) return null
|
||||||
|
|
||||||
|
const TemplateComponent = modeTemplates[locale] || modeTemplates.default
|
||||||
|
return <TemplateComponent appDetail={appDetail} />
|
||||||
|
}, [appDetail, locale])
|
||||||
|
```
|
||||||
|
|
||||||
|
### Pattern 4: Extract API/Data Logic
|
||||||
|
|
||||||
|
**When**: Component directly handles API calls, data transformation, or complex async operations.
|
||||||
|
|
||||||
|
**Dify Convention**: Use `@tanstack/react-query` hooks from `web/service/use-*.ts` or create custom data hooks. Project is migrating from SWR to React Query.
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// ❌ Before: API logic in component
|
||||||
|
const MCPServiceCard = () => {
|
||||||
|
const [basicAppConfig, setBasicAppConfig] = useState({})
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (isBasicApp && appId) {
|
||||||
|
(async () => {
|
||||||
|
const res = await fetchAppDetail({ url: '/apps', id: appId })
|
||||||
|
setBasicAppConfig(res?.model_config || {})
|
||||||
|
})()
|
||||||
|
}
|
||||||
|
}, [appId, isBasicApp])
|
||||||
|
|
||||||
|
// More API-related logic...
|
||||||
|
}
|
||||||
|
|
||||||
|
// ✅ After: Extract to data hook using React Query
|
||||||
|
// use-app-config.ts
|
||||||
|
import { useQuery } from '@tanstack/react-query'
|
||||||
|
import { get } from '@/service/base'
|
||||||
|
|
||||||
|
const NAME_SPACE = 'appConfig'
|
||||||
|
|
||||||
|
export const useAppConfig = (appId: string, isBasicApp: boolean) => {
|
||||||
|
return useQuery({
|
||||||
|
enabled: isBasicApp && !!appId,
|
||||||
|
queryKey: [NAME_SPACE, 'detail', appId],
|
||||||
|
queryFn: () => get<AppDetailResponse>(`/apps/${appId}`),
|
||||||
|
select: data => data?.model_config || {},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Component becomes cleaner
|
||||||
|
const MCPServiceCard = () => {
|
||||||
|
const { data: config, isLoading } = useAppConfig(appId, isBasicApp)
|
||||||
|
// UI only
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**React Query Best Practices in Dify**:
|
||||||
|
- Define `NAME_SPACE` for query key organization
|
||||||
|
- Use `enabled` option for conditional fetching
|
||||||
|
- Use `select` for data transformation
|
||||||
|
- Export invalidation hooks: `useInvalidXxx`
|
||||||
|
|
||||||
|
**Dify Examples**:
|
||||||
|
- `web/service/use-workflow.ts`
|
||||||
|
- `web/service/use-common.ts`
|
||||||
|
- `web/service/knowledge/use-dataset.ts`
|
||||||
|
- `web/service/knowledge/use-document.ts`
|
||||||
|
|
||||||
|
### Pattern 5: Extract Modal/Dialog Management
|
||||||
|
|
||||||
|
**When**: Component manages multiple modals with complex open/close states.
|
||||||
|
|
||||||
|
**Dify Convention**: Modals should be extracted with their state management.
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// ❌ Before: Multiple modal states in component
|
||||||
|
const AppInfo = () => {
|
||||||
|
const [showEditModal, setShowEditModal] = useState(false)
|
||||||
|
const [showDuplicateModal, setShowDuplicateModal] = useState(false)
|
||||||
|
const [showConfirmDelete, setShowConfirmDelete] = useState(false)
|
||||||
|
const [showSwitchModal, setShowSwitchModal] = useState(false)
|
||||||
|
const [showImportDSLModal, setShowImportDSLModal] = useState(false)
|
||||||
|
// 5+ more modal states...
|
||||||
|
}
|
||||||
|
|
||||||
|
// ✅ After: Extract to modal management hook
|
||||||
|
type ModalType = 'edit' | 'duplicate' | 'delete' | 'switch' | 'import' | null
|
||||||
|
|
||||||
|
const useAppInfoModals = () => {
|
||||||
|
const [activeModal, setActiveModal] = useState<ModalType>(null)
|
||||||
|
|
||||||
|
const openModal = useCallback((type: ModalType) => setActiveModal(type), [])
|
||||||
|
const closeModal = useCallback(() => setActiveModal(null), [])
|
||||||
|
|
||||||
|
return {
|
||||||
|
activeModal,
|
||||||
|
openModal,
|
||||||
|
closeModal,
|
||||||
|
isOpen: (type: ModalType) => activeModal === type,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Pattern 6: Extract Form Logic
|
||||||
|
|
||||||
|
**When**: Complex form validation, submission handling, or field transformation.
|
||||||
|
|
||||||
|
**Dify Convention**: Use `@tanstack/react-form` patterns from `web/app/components/base/form/`.
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// ✅ Use existing form infrastructure
|
||||||
|
import { useAppForm } from '@/app/components/base/form'
|
||||||
|
|
||||||
|
const ConfigForm = () => {
|
||||||
|
const form = useAppForm({
|
||||||
|
defaultValues: { name: '', description: '' },
|
||||||
|
onSubmit: handleSubmit,
|
||||||
|
})
|
||||||
|
|
||||||
|
return <form.Provider>...</form.Provider>
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Dify-Specific Refactoring Guidelines
|
||||||
|
|
||||||
|
### 1. Context Provider Extraction
|
||||||
|
|
||||||
|
**When**: Component provides complex context values with multiple states.
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// ❌ Before: Large context value object
|
||||||
|
const value = {
|
||||||
|
appId, isAPIKeySet, isTrailFinished, mode, modelModeType,
|
||||||
|
promptMode, isAdvancedMode, isAgent, isOpenAI, isFunctionCall,
|
||||||
|
// 50+ more properties...
|
||||||
|
}
|
||||||
|
return <ConfigContext.Provider value={value}>...</ConfigContext.Provider>
|
||||||
|
|
||||||
|
// ✅ After: Split into domain-specific contexts
|
||||||
|
<ModelConfigProvider value={modelConfigValue}>
|
||||||
|
<DatasetConfigProvider value={datasetConfigValue}>
|
||||||
|
<UIConfigProvider value={uiConfigValue}>
|
||||||
|
{children}
|
||||||
|
</UIConfigProvider>
|
||||||
|
</DatasetConfigProvider>
|
||||||
|
</ModelConfigProvider>
|
||||||
|
```
|
||||||
|
|
||||||
|
**Dify Reference**: `web/context/` directory structure
|
||||||
|
|
||||||
|
### 2. Workflow Node Components
|
||||||
|
|
||||||
|
**When**: Refactoring workflow node components (`web/app/components/workflow/nodes/`).
|
||||||
|
|
||||||
|
**Conventions**:
|
||||||
|
- Keep node logic in `use-interactions.ts`
|
||||||
|
- Extract panel UI to separate files
|
||||||
|
- Use `_base` components for common patterns
|
||||||
|
|
||||||
|
```
|
||||||
|
nodes/<node-type>/
|
||||||
|
├── index.tsx # Node registration
|
||||||
|
├── node.tsx # Node visual component
|
||||||
|
├── panel.tsx # Configuration panel
|
||||||
|
├── use-interactions.ts # Node-specific hooks
|
||||||
|
└── types.ts # Type definitions
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Configuration Components
|
||||||
|
|
||||||
|
**When**: Refactoring app configuration components.
|
||||||
|
|
||||||
|
**Conventions**:
|
||||||
|
- Separate config sections into subdirectories
|
||||||
|
- Use existing patterns from `web/app/components/app/configuration/`
|
||||||
|
- Keep feature toggles in dedicated components
|
||||||
|
|
||||||
|
### 4. Tool/Plugin Components
|
||||||
|
|
||||||
|
**When**: Refactoring tool-related components (`web/app/components/tools/`).
|
||||||
|
|
||||||
|
**Conventions**:
|
||||||
|
- Follow existing modal patterns
|
||||||
|
- Use service hooks from `web/service/use-tools.ts`
|
||||||
|
- Keep provider-specific logic isolated
|
||||||
|
|
||||||
|
## Refactoring Workflow
|
||||||
|
|
||||||
|
### Step 1: Generate Refactoring Prompt
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pnpm refactor-component <path>
|
||||||
|
```
|
||||||
|
|
||||||
|
This command will:
|
||||||
|
- Analyze component complexity and features
|
||||||
|
- Identify specific refactoring actions needed
|
||||||
|
- Generate a prompt for AI assistant (auto-copied to clipboard on macOS)
|
||||||
|
- Provide detailed requirements based on detected patterns
|
||||||
|
|
||||||
|
### Step 2: Analyze Details
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pnpm analyze-component <path> --json
|
||||||
|
```
|
||||||
|
|
||||||
|
Identify:
|
||||||
|
- Total complexity score
|
||||||
|
- Max function complexity
|
||||||
|
- Line count
|
||||||
|
- Features detected (state, effects, API, etc.)
|
||||||
|
|
||||||
|
### Step 3: Plan
|
||||||
|
|
||||||
|
Create a refactoring plan based on detected features:
|
||||||
|
|
||||||
|
| Detected Feature | Refactoring Action |
|
||||||
|
|------------------|-------------------|
|
||||||
|
| `hasState: true` + `hasEffects: true` | Extract custom hook |
|
||||||
|
| `hasAPI: true` | Extract data/service hook |
|
||||||
|
| `hasEvents: true` (many) | Extract event handlers |
|
||||||
|
| `lineCount > 300` | Split into sub-components |
|
||||||
|
| `maxComplexity > 50` | Simplify conditional logic |
|
||||||
|
|
||||||
|
### Step 4: Execute Incrementally
|
||||||
|
|
||||||
|
1. **Extract one piece at a time**
|
||||||
|
2. **Run lint, type-check, and tests after each extraction**
|
||||||
|
3. **Verify functionality before next step**
|
||||||
|
|
||||||
|
```
|
||||||
|
For each extraction:
|
||||||
|
┌────────────────────────────────────────┐
|
||||||
|
│ 1. Extract code │
|
||||||
|
│ 2. Run: pnpm lint:fix │
|
||||||
|
│ 3. Run: pnpm type-check:tsgo │
|
||||||
|
│ 4. Run: pnpm test │
|
||||||
|
│ 5. Test functionality manually │
|
||||||
|
│ 6. PASS? → Next extraction │
|
||||||
|
│ FAIL? → Fix before continuing │
|
||||||
|
└────────────────────────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
### Step 5: Verify
|
||||||
|
|
||||||
|
After refactoring:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Re-run refactor command to verify improvements
|
||||||
|
pnpm refactor-component <path>
|
||||||
|
|
||||||
|
# If complexity < 25 and lines < 200, you'll see:
|
||||||
|
# ✅ COMPONENT IS WELL-STRUCTURED
|
||||||
|
|
||||||
|
# For detailed metrics:
|
||||||
|
pnpm analyze-component <path> --json
|
||||||
|
|
||||||
|
# Target metrics:
|
||||||
|
# - complexity < 50
|
||||||
|
# - lineCount < 300
|
||||||
|
# - maxComplexity < 30
|
||||||
|
```
|
||||||
|
|
||||||
|
## Common Mistakes to Avoid
|
||||||
|
|
||||||
|
### ❌ Over-Engineering
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// ❌ Too many tiny hooks
|
||||||
|
const useButtonText = () => useState('Click')
|
||||||
|
const useButtonDisabled = () => useState(false)
|
||||||
|
const useButtonLoading = () => useState(false)
|
||||||
|
|
||||||
|
// ✅ Cohesive hook with related state
|
||||||
|
const useButtonState = () => {
|
||||||
|
const [text, setText] = useState('Click')
|
||||||
|
const [disabled, setDisabled] = useState(false)
|
||||||
|
const [loading, setLoading] = useState(false)
|
||||||
|
return { text, setText, disabled, setDisabled, loading, setLoading }
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### ❌ Breaking Existing Patterns
|
||||||
|
|
||||||
|
- Follow existing directory structures
|
||||||
|
- Maintain naming conventions
|
||||||
|
- Preserve export patterns for compatibility
|
||||||
|
|
||||||
|
### ❌ Premature Abstraction
|
||||||
|
|
||||||
|
- Only extract when there's clear complexity benefit
|
||||||
|
- Don't create abstractions for single-use code
|
||||||
|
- Keep refactored code in the same domain area
|
||||||
|
|
||||||
|
## References
|
||||||
|
|
||||||
|
### Dify Codebase Examples
|
||||||
|
|
||||||
|
- **Hook extraction**: `web/app/components/app/configuration/hooks/`
|
||||||
|
- **Component splitting**: `web/app/components/app/configuration/`
|
||||||
|
- **Service hooks**: `web/service/use-*.ts`
|
||||||
|
- **Workflow patterns**: `web/app/components/workflow/hooks/`
|
||||||
|
- **Form patterns**: `web/app/components/base/form/`
|
||||||
|
|
||||||
|
### Related Skills
|
||||||
|
|
||||||
|
- `frontend-testing` - For testing refactored components
|
||||||
|
- `web/testing/testing.md` - Testing specification
|
||||||
|
|
@ -0,0 +1,493 @@
|
||||||
|
# Complexity Reduction Patterns
|
||||||
|
|
||||||
|
This document provides patterns for reducing cognitive complexity in Dify React components.
|
||||||
|
|
||||||
|
## Understanding Complexity
|
||||||
|
|
||||||
|
### SonarJS Cognitive Complexity
|
||||||
|
|
||||||
|
The `pnpm analyze-component` tool uses SonarJS cognitive complexity metrics:
|
||||||
|
|
||||||
|
- **Total Complexity**: Sum of all functions' complexity in the file
|
||||||
|
- **Max Complexity**: Highest single function complexity
|
||||||
|
|
||||||
|
### What Increases Complexity
|
||||||
|
|
||||||
|
| Pattern | Complexity Impact |
|
||||||
|
|---------|-------------------|
|
||||||
|
| `if/else` | +1 per branch |
|
||||||
|
| Nested conditions | +1 per nesting level |
|
||||||
|
| `switch/case` | +1 per case |
|
||||||
|
| `for/while/do` | +1 per loop |
|
||||||
|
| `&&`/`||` chains | +1 per operator |
|
||||||
|
| Nested callbacks | +1 per nesting level |
|
||||||
|
| `try/catch` | +1 per catch |
|
||||||
|
| Ternary expressions | +1 per nesting |
|
||||||
|
|
||||||
|
## Pattern 1: Replace Conditionals with Lookup Tables
|
||||||
|
|
||||||
|
**Before** (complexity: ~15):
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
const Template = useMemo(() => {
|
||||||
|
if (appDetail?.mode === AppModeEnum.CHAT) {
|
||||||
|
switch (locale) {
|
||||||
|
case LanguagesSupported[1]:
|
||||||
|
return <TemplateChatZh appDetail={appDetail} />
|
||||||
|
case LanguagesSupported[7]:
|
||||||
|
return <TemplateChatJa appDetail={appDetail} />
|
||||||
|
default:
|
||||||
|
return <TemplateChatEn appDetail={appDetail} />
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (appDetail?.mode === AppModeEnum.ADVANCED_CHAT) {
|
||||||
|
switch (locale) {
|
||||||
|
case LanguagesSupported[1]:
|
||||||
|
return <TemplateAdvancedChatZh appDetail={appDetail} />
|
||||||
|
case LanguagesSupported[7]:
|
||||||
|
return <TemplateAdvancedChatJa appDetail={appDetail} />
|
||||||
|
default:
|
||||||
|
return <TemplateAdvancedChatEn appDetail={appDetail} />
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (appDetail?.mode === AppModeEnum.WORKFLOW) {
|
||||||
|
// Similar pattern...
|
||||||
|
}
|
||||||
|
return null
|
||||||
|
}, [appDetail, locale])
|
||||||
|
```
|
||||||
|
|
||||||
|
**After** (complexity: ~3):
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// Define lookup table outside component
|
||||||
|
const TEMPLATE_MAP: Record<AppModeEnum, Record<string, FC<TemplateProps>>> = {
|
||||||
|
[AppModeEnum.CHAT]: {
|
||||||
|
[LanguagesSupported[1]]: TemplateChatZh,
|
||||||
|
[LanguagesSupported[7]]: TemplateChatJa,
|
||||||
|
default: TemplateChatEn,
|
||||||
|
},
|
||||||
|
[AppModeEnum.ADVANCED_CHAT]: {
|
||||||
|
[LanguagesSupported[1]]: TemplateAdvancedChatZh,
|
||||||
|
[LanguagesSupported[7]]: TemplateAdvancedChatJa,
|
||||||
|
default: TemplateAdvancedChatEn,
|
||||||
|
},
|
||||||
|
[AppModeEnum.WORKFLOW]: {
|
||||||
|
[LanguagesSupported[1]]: TemplateWorkflowZh,
|
||||||
|
[LanguagesSupported[7]]: TemplateWorkflowJa,
|
||||||
|
default: TemplateWorkflowEn,
|
||||||
|
},
|
||||||
|
// ...
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean component logic
|
||||||
|
const Template = useMemo(() => {
|
||||||
|
if (!appDetail?.mode) return null
|
||||||
|
|
||||||
|
const templates = TEMPLATE_MAP[appDetail.mode]
|
||||||
|
if (!templates) return null
|
||||||
|
|
||||||
|
const TemplateComponent = templates[locale] ?? templates.default
|
||||||
|
return <TemplateComponent appDetail={appDetail} />
|
||||||
|
}, [appDetail, locale])
|
||||||
|
```
|
||||||
|
|
||||||
|
## Pattern 2: Use Early Returns
|
||||||
|
|
||||||
|
**Before** (complexity: ~10):
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
const handleSubmit = () => {
|
||||||
|
if (isValid) {
|
||||||
|
if (hasChanges) {
|
||||||
|
if (isConnected) {
|
||||||
|
submitData()
|
||||||
|
} else {
|
||||||
|
showConnectionError()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
showNoChangesMessage()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
showValidationError()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**After** (complexity: ~4):
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
const handleSubmit = () => {
|
||||||
|
if (!isValid) {
|
||||||
|
showValidationError()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!hasChanges) {
|
||||||
|
showNoChangesMessage()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!isConnected) {
|
||||||
|
showConnectionError()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
submitData()
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Pattern 3: Extract Complex Conditions
|
||||||
|
|
||||||
|
**Before** (complexity: high):
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
const canPublish = (() => {
|
||||||
|
if (mode !== AppModeEnum.COMPLETION) {
|
||||||
|
if (!isAdvancedMode)
|
||||||
|
return true
|
||||||
|
|
||||||
|
if (modelModeType === ModelModeType.completion) {
|
||||||
|
if (!hasSetBlockStatus.history || !hasSetBlockStatus.query)
|
||||||
|
return false
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return !promptEmpty
|
||||||
|
})()
|
||||||
|
```
|
||||||
|
|
||||||
|
**After** (complexity: lower):
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// Extract to named functions
|
||||||
|
const canPublishInCompletionMode = () => !promptEmpty
|
||||||
|
|
||||||
|
const canPublishInChatMode = () => {
|
||||||
|
if (!isAdvancedMode) return true
|
||||||
|
if (modelModeType !== ModelModeType.completion) return true
|
||||||
|
return hasSetBlockStatus.history && hasSetBlockStatus.query
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean main logic
|
||||||
|
const canPublish = mode === AppModeEnum.COMPLETION
|
||||||
|
? canPublishInCompletionMode()
|
||||||
|
: canPublishInChatMode()
|
||||||
|
```
|
||||||
|
|
||||||
|
## Pattern 4: Replace Chained Ternaries
|
||||||
|
|
||||||
|
**Before** (complexity: ~5):
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
const statusText = serverActivated
|
||||||
|
? t('status.running')
|
||||||
|
: serverPublished
|
||||||
|
? t('status.inactive')
|
||||||
|
: appUnpublished
|
||||||
|
? t('status.unpublished')
|
||||||
|
: t('status.notConfigured')
|
||||||
|
```
|
||||||
|
|
||||||
|
**After** (complexity: ~2):
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
const getStatusText = () => {
|
||||||
|
if (serverActivated) return t('status.running')
|
||||||
|
if (serverPublished) return t('status.inactive')
|
||||||
|
if (appUnpublished) return t('status.unpublished')
|
||||||
|
return t('status.notConfigured')
|
||||||
|
}
|
||||||
|
|
||||||
|
const statusText = getStatusText()
|
||||||
|
```
|
||||||
|
|
||||||
|
Or use lookup:
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
const STATUS_TEXT_MAP = {
|
||||||
|
running: 'status.running',
|
||||||
|
inactive: 'status.inactive',
|
||||||
|
unpublished: 'status.unpublished',
|
||||||
|
notConfigured: 'status.notConfigured',
|
||||||
|
} as const
|
||||||
|
|
||||||
|
const getStatusKey = (): keyof typeof STATUS_TEXT_MAP => {
|
||||||
|
if (serverActivated) return 'running'
|
||||||
|
if (serverPublished) return 'inactive'
|
||||||
|
if (appUnpublished) return 'unpublished'
|
||||||
|
return 'notConfigured'
|
||||||
|
}
|
||||||
|
|
||||||
|
const statusText = t(STATUS_TEXT_MAP[getStatusKey()])
|
||||||
|
```
|
||||||
|
|
||||||
|
## Pattern 5: Flatten Nested Loops
|
||||||
|
|
||||||
|
**Before** (complexity: high):
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
const processData = (items: Item[]) => {
|
||||||
|
const results: ProcessedItem[] = []
|
||||||
|
|
||||||
|
for (const item of items) {
|
||||||
|
if (item.isValid) {
|
||||||
|
for (const child of item.children) {
|
||||||
|
if (child.isActive) {
|
||||||
|
for (const prop of child.properties) {
|
||||||
|
if (prop.value !== null) {
|
||||||
|
results.push({
|
||||||
|
itemId: item.id,
|
||||||
|
childId: child.id,
|
||||||
|
propValue: prop.value,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return results
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**After** (complexity: lower):
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// Use functional approach
|
||||||
|
const processData = (items: Item[]) => {
|
||||||
|
return items
|
||||||
|
.filter(item => item.isValid)
|
||||||
|
.flatMap(item =>
|
||||||
|
item.children
|
||||||
|
.filter(child => child.isActive)
|
||||||
|
.flatMap(child =>
|
||||||
|
child.properties
|
||||||
|
.filter(prop => prop.value !== null)
|
||||||
|
.map(prop => ({
|
||||||
|
itemId: item.id,
|
||||||
|
childId: child.id,
|
||||||
|
propValue: prop.value,
|
||||||
|
}))
|
||||||
|
)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Pattern 6: Extract Event Handler Logic
|
||||||
|
|
||||||
|
**Before** (complexity: high in component):
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
const Component = () => {
|
||||||
|
const handleSelect = (data: DataSet[]) => {
|
||||||
|
if (isEqual(data.map(item => item.id), dataSets.map(item => item.id))) {
|
||||||
|
hideSelectDataSet()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
formattingChangedDispatcher()
|
||||||
|
let newDatasets = data
|
||||||
|
if (data.find(item => !item.name)) {
|
||||||
|
const newSelected = produce(data, (draft) => {
|
||||||
|
data.forEach((item, index) => {
|
||||||
|
if (!item.name) {
|
||||||
|
const newItem = dataSets.find(i => i.id === item.id)
|
||||||
|
if (newItem)
|
||||||
|
draft[index] = newItem
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
setDataSets(newSelected)
|
||||||
|
newDatasets = newSelected
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
setDataSets(data)
|
||||||
|
}
|
||||||
|
hideSelectDataSet()
|
||||||
|
|
||||||
|
// 40 more lines of logic...
|
||||||
|
}
|
||||||
|
|
||||||
|
return <div>...</div>
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**After** (complexity: lower):
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// Extract to hook or utility
|
||||||
|
const useDatasetSelection = (dataSets: DataSet[], setDataSets: SetState<DataSet[]>) => {
|
||||||
|
const normalizeSelection = (data: DataSet[]) => {
|
||||||
|
const hasUnloadedItem = data.some(item => !item.name)
|
||||||
|
if (!hasUnloadedItem) return data
|
||||||
|
|
||||||
|
return produce(data, (draft) => {
|
||||||
|
data.forEach((item, index) => {
|
||||||
|
if (!item.name) {
|
||||||
|
const existing = dataSets.find(i => i.id === item.id)
|
||||||
|
if (existing) draft[index] = existing
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
const hasSelectionChanged = (newData: DataSet[]) => {
|
||||||
|
return !isEqual(
|
||||||
|
newData.map(item => item.id),
|
||||||
|
dataSets.map(item => item.id)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return { normalizeSelection, hasSelectionChanged }
|
||||||
|
}
|
||||||
|
|
||||||
|
// Component becomes cleaner
|
||||||
|
const Component = () => {
|
||||||
|
const { normalizeSelection, hasSelectionChanged } = useDatasetSelection(dataSets, setDataSets)
|
||||||
|
|
||||||
|
const handleSelect = (data: DataSet[]) => {
|
||||||
|
if (!hasSelectionChanged(data)) {
|
||||||
|
hideSelectDataSet()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
formattingChangedDispatcher()
|
||||||
|
const normalized = normalizeSelection(data)
|
||||||
|
setDataSets(normalized)
|
||||||
|
hideSelectDataSet()
|
||||||
|
}
|
||||||
|
|
||||||
|
return <div>...</div>
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Pattern 7: Reduce Boolean Logic Complexity
|
||||||
|
|
||||||
|
**Before** (complexity: ~8):
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
const toggleDisabled = hasInsufficientPermissions
|
||||||
|
|| appUnpublished
|
||||||
|
|| missingStartNode
|
||||||
|
|| triggerModeDisabled
|
||||||
|
|| (isAdvancedApp && !currentWorkflow?.graph)
|
||||||
|
|| (isBasicApp && !basicAppConfig.updated_at)
|
||||||
|
```
|
||||||
|
|
||||||
|
**After** (complexity: ~3):
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// Extract meaningful boolean functions
|
||||||
|
const isAppReady = () => {
|
||||||
|
if (isAdvancedApp) return !!currentWorkflow?.graph
|
||||||
|
return !!basicAppConfig.updated_at
|
||||||
|
}
|
||||||
|
|
||||||
|
const hasRequiredPermissions = () => {
|
||||||
|
return isCurrentWorkspaceEditor && !hasInsufficientPermissions
|
||||||
|
}
|
||||||
|
|
||||||
|
const canToggle = () => {
|
||||||
|
if (!hasRequiredPermissions()) return false
|
||||||
|
if (!isAppReady()) return false
|
||||||
|
if (missingStartNode) return false
|
||||||
|
if (triggerModeDisabled) return false
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
const toggleDisabled = !canToggle()
|
||||||
|
```
|
||||||
|
|
||||||
|
## Pattern 8: Simplify useMemo/useCallback Dependencies
|
||||||
|
|
||||||
|
**Before** (complexity: multiple recalculations):
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
const payload = useMemo(() => {
|
||||||
|
let parameters: Parameter[] = []
|
||||||
|
let outputParameters: OutputParameter[] = []
|
||||||
|
|
||||||
|
if (!published) {
|
||||||
|
parameters = (inputs || []).map((item) => ({
|
||||||
|
name: item.variable,
|
||||||
|
description: '',
|
||||||
|
form: 'llm',
|
||||||
|
required: item.required,
|
||||||
|
type: item.type,
|
||||||
|
}))
|
||||||
|
outputParameters = (outputs || []).map((item) => ({
|
||||||
|
name: item.variable,
|
||||||
|
description: '',
|
||||||
|
type: item.value_type,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
else if (detail && detail.tool) {
|
||||||
|
parameters = (inputs || []).map((item) => ({
|
||||||
|
// Complex transformation...
|
||||||
|
}))
|
||||||
|
outputParameters = (outputs || []).map((item) => ({
|
||||||
|
// Complex transformation...
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
icon: detail?.icon || icon,
|
||||||
|
label: detail?.label || name,
|
||||||
|
// ...more fields
|
||||||
|
}
|
||||||
|
}, [detail, published, workflowAppId, icon, name, description, inputs, outputs])
|
||||||
|
```
|
||||||
|
|
||||||
|
**After** (complexity: separated concerns):
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// Separate transformations
|
||||||
|
const useParameterTransform = (inputs: InputVar[], detail?: ToolDetail, published?: boolean) => {
|
||||||
|
return useMemo(() => {
|
||||||
|
if (!published) {
|
||||||
|
return inputs.map(item => ({
|
||||||
|
name: item.variable,
|
||||||
|
description: '',
|
||||||
|
form: 'llm',
|
||||||
|
required: item.required,
|
||||||
|
type: item.type,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!detail?.tool) return []
|
||||||
|
|
||||||
|
return inputs.map(item => ({
|
||||||
|
name: item.variable,
|
||||||
|
required: item.required,
|
||||||
|
type: item.type === 'paragraph' ? 'string' : item.type,
|
||||||
|
description: detail.tool.parameters.find(p => p.name === item.variable)?.llm_description || '',
|
||||||
|
form: detail.tool.parameters.find(p => p.name === item.variable)?.form || 'llm',
|
||||||
|
}))
|
||||||
|
}, [inputs, detail, published])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Component uses hook
|
||||||
|
const parameters = useParameterTransform(inputs, detail, published)
|
||||||
|
const outputParameters = useOutputTransform(outputs, detail, published)
|
||||||
|
|
||||||
|
const payload = useMemo(() => ({
|
||||||
|
icon: detail?.icon || icon,
|
||||||
|
label: detail?.label || name,
|
||||||
|
parameters,
|
||||||
|
outputParameters,
|
||||||
|
// ...
|
||||||
|
}), [detail, icon, name, parameters, outputParameters])
|
||||||
|
```
|
||||||
|
|
||||||
|
## Target Metrics After Refactoring
|
||||||
|
|
||||||
|
| Metric | Target |
|
||||||
|
|--------|--------|
|
||||||
|
| Total Complexity | < 50 |
|
||||||
|
| Max Function Complexity | < 30 |
|
||||||
|
| Function Length | < 30 lines |
|
||||||
|
| Nesting Depth | ≤ 3 levels |
|
||||||
|
| Conditional Chains | ≤ 3 conditions |
|
||||||
|
|
@ -0,0 +1,477 @@
|
||||||
|
# Component Splitting Patterns
|
||||||
|
|
||||||
|
This document provides detailed guidance on splitting large components into smaller, focused components in Dify.
|
||||||
|
|
||||||
|
## When to Split Components
|
||||||
|
|
||||||
|
Split a component when you identify:
|
||||||
|
|
||||||
|
1. **Multiple UI sections** - Distinct visual areas with minimal coupling that can be composed independently
|
||||||
|
1. **Conditional rendering blocks** - Large `{condition && <JSX />}` blocks
|
||||||
|
1. **Repeated patterns** - Similar UI structures used multiple times
|
||||||
|
1. **300+ lines** - Component exceeds manageable size
|
||||||
|
1. **Modal clusters** - Multiple modals rendered in one component
|
||||||
|
|
||||||
|
## Splitting Strategies
|
||||||
|
|
||||||
|
### Strategy 1: Section-Based Splitting
|
||||||
|
|
||||||
|
Identify visual sections and extract each as a component.
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// ❌ Before: Monolithic component (500+ lines)
|
||||||
|
const ConfigurationPage = () => {
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
{/* Header Section - 50 lines */}
|
||||||
|
<div className="header">
|
||||||
|
<h1>{t('configuration.title')}</h1>
|
||||||
|
<div className="actions">
|
||||||
|
{isAdvancedMode && <Badge>Advanced</Badge>}
|
||||||
|
<ModelParameterModal ... />
|
||||||
|
<AppPublisher ... />
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Config Section - 200 lines */}
|
||||||
|
<div className="config">
|
||||||
|
<Config />
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Debug Section - 150 lines */}
|
||||||
|
<div className="debug">
|
||||||
|
<Debug ... />
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Modals Section - 100 lines */}
|
||||||
|
{showSelectDataSet && <SelectDataSet ... />}
|
||||||
|
{showHistoryModal && <EditHistoryModal ... />}
|
||||||
|
{showUseGPT4Confirm && <Confirm ... />}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ✅ After: Split into focused components
|
||||||
|
// configuration/
|
||||||
|
// ├── index.tsx (orchestration)
|
||||||
|
// ├── configuration-header.tsx
|
||||||
|
// ├── configuration-content.tsx
|
||||||
|
// ├── configuration-debug.tsx
|
||||||
|
// └── configuration-modals.tsx
|
||||||
|
|
||||||
|
// configuration-header.tsx
|
||||||
|
interface ConfigurationHeaderProps {
|
||||||
|
isAdvancedMode: boolean
|
||||||
|
onPublish: () => void
|
||||||
|
}
|
||||||
|
|
||||||
|
const ConfigurationHeader: FC<ConfigurationHeaderProps> = ({
|
||||||
|
isAdvancedMode,
|
||||||
|
onPublish,
|
||||||
|
}) => {
|
||||||
|
const { t } = useTranslation()
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="header">
|
||||||
|
<h1>{t('configuration.title')}</h1>
|
||||||
|
<div className="actions">
|
||||||
|
{isAdvancedMode && <Badge>Advanced</Badge>}
|
||||||
|
<ModelParameterModal ... />
|
||||||
|
<AppPublisher onPublish={onPublish} />
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// index.tsx (orchestration only)
|
||||||
|
const ConfigurationPage = () => {
|
||||||
|
const { modelConfig, setModelConfig } = useModelConfig()
|
||||||
|
const { activeModal, openModal, closeModal } = useModalState()
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
<ConfigurationHeader
|
||||||
|
isAdvancedMode={isAdvancedMode}
|
||||||
|
onPublish={handlePublish}
|
||||||
|
/>
|
||||||
|
<ConfigurationContent
|
||||||
|
modelConfig={modelConfig}
|
||||||
|
onConfigChange={setModelConfig}
|
||||||
|
/>
|
||||||
|
{!isMobile && (
|
||||||
|
<ConfigurationDebug
|
||||||
|
inputs={inputs}
|
||||||
|
onSetting={handleSetting}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
<ConfigurationModals
|
||||||
|
activeModal={activeModal}
|
||||||
|
onClose={closeModal}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Strategy 2: Conditional Block Extraction
|
||||||
|
|
||||||
|
Extract large conditional rendering blocks.
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// ❌ Before: Large conditional blocks
|
||||||
|
const AppInfo = () => {
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
{expand ? (
|
||||||
|
<div className="expanded">
|
||||||
|
{/* 100 lines of expanded view */}
|
||||||
|
</div>
|
||||||
|
) : (
|
||||||
|
<div className="collapsed">
|
||||||
|
{/* 50 lines of collapsed view */}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ✅ After: Separate view components
|
||||||
|
const AppInfoExpanded: FC<AppInfoViewProps> = ({ appDetail, onAction }) => {
|
||||||
|
return (
|
||||||
|
<div className="expanded">
|
||||||
|
{/* Clean, focused expanded view */}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
const AppInfoCollapsed: FC<AppInfoViewProps> = ({ appDetail, onAction }) => {
|
||||||
|
return (
|
||||||
|
<div className="collapsed">
|
||||||
|
{/* Clean, focused collapsed view */}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
const AppInfo = () => {
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
{expand
|
||||||
|
? <AppInfoExpanded appDetail={appDetail} onAction={handleAction} />
|
||||||
|
: <AppInfoCollapsed appDetail={appDetail} onAction={handleAction} />
|
||||||
|
}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Strategy 3: Modal Extraction
|
||||||
|
|
||||||
|
Extract modals with their trigger logic.
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// ❌ Before: Multiple modals in one component
|
||||||
|
const AppInfo = () => {
|
||||||
|
const [showEdit, setShowEdit] = useState(false)
|
||||||
|
const [showDuplicate, setShowDuplicate] = useState(false)
|
||||||
|
const [showDelete, setShowDelete] = useState(false)
|
||||||
|
const [showSwitch, setShowSwitch] = useState(false)
|
||||||
|
|
||||||
|
const onEdit = async (data) => { /* 20 lines */ }
|
||||||
|
const onDuplicate = async (data) => { /* 20 lines */ }
|
||||||
|
const onDelete = async () => { /* 15 lines */ }
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
{/* Main content */}
|
||||||
|
|
||||||
|
{showEdit && <EditModal onConfirm={onEdit} onClose={() => setShowEdit(false)} />}
|
||||||
|
{showDuplicate && <DuplicateModal onConfirm={onDuplicate} onClose={() => setShowDuplicate(false)} />}
|
||||||
|
{showDelete && <DeleteConfirm onConfirm={onDelete} onClose={() => setShowDelete(false)} />}
|
||||||
|
{showSwitch && <SwitchModal ... />}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ✅ After: Modal manager component
|
||||||
|
// app-info-modals.tsx
|
||||||
|
type ModalType = 'edit' | 'duplicate' | 'delete' | 'switch' | null
|
||||||
|
|
||||||
|
interface AppInfoModalsProps {
|
||||||
|
appDetail: AppDetail
|
||||||
|
activeModal: ModalType
|
||||||
|
onClose: () => void
|
||||||
|
onSuccess: () => void
|
||||||
|
}
|
||||||
|
|
||||||
|
const AppInfoModals: FC<AppInfoModalsProps> = ({
|
||||||
|
appDetail,
|
||||||
|
activeModal,
|
||||||
|
onClose,
|
||||||
|
onSuccess,
|
||||||
|
}) => {
|
||||||
|
const handleEdit = async (data) => { /* logic */ }
|
||||||
|
const handleDuplicate = async (data) => { /* logic */ }
|
||||||
|
const handleDelete = async () => { /* logic */ }
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
{activeModal === 'edit' && (
|
||||||
|
<EditModal
|
||||||
|
appDetail={appDetail}
|
||||||
|
onConfirm={handleEdit}
|
||||||
|
onClose={onClose}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{activeModal === 'duplicate' && (
|
||||||
|
<DuplicateModal
|
||||||
|
appDetail={appDetail}
|
||||||
|
onConfirm={handleDuplicate}
|
||||||
|
onClose={onClose}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{activeModal === 'delete' && (
|
||||||
|
<DeleteConfirm
|
||||||
|
onConfirm={handleDelete}
|
||||||
|
onClose={onClose}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{activeModal === 'switch' && (
|
||||||
|
<SwitchModal
|
||||||
|
appDetail={appDetail}
|
||||||
|
onClose={onClose}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parent component
|
||||||
|
const AppInfo = () => {
|
||||||
|
const { activeModal, openModal, closeModal } = useModalState()
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
{/* Main content with openModal triggers */}
|
||||||
|
<Button onClick={() => openModal('edit')}>Edit</Button>
|
||||||
|
|
||||||
|
<AppInfoModals
|
||||||
|
appDetail={appDetail}
|
||||||
|
activeModal={activeModal}
|
||||||
|
onClose={closeModal}
|
||||||
|
onSuccess={handleSuccess}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Strategy 4: List Item Extraction
|
||||||
|
|
||||||
|
Extract repeated item rendering.
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// ❌ Before: Inline item rendering
|
||||||
|
const OperationsList = () => {
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
{operations.map(op => (
|
||||||
|
<div key={op.id} className="operation-item">
|
||||||
|
<span className="icon">{op.icon}</span>
|
||||||
|
<span className="title">{op.title}</span>
|
||||||
|
<span className="description">{op.description}</span>
|
||||||
|
<button onClick={() => op.onClick()}>
|
||||||
|
{op.actionLabel}
|
||||||
|
</button>
|
||||||
|
{op.badge && <Badge>{op.badge}</Badge>}
|
||||||
|
{/* More complex rendering... */}
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ✅ After: Extracted item component
|
||||||
|
interface OperationItemProps {
|
||||||
|
operation: Operation
|
||||||
|
onAction: (id: string) => void
|
||||||
|
}
|
||||||
|
|
||||||
|
const OperationItem: FC<OperationItemProps> = ({ operation, onAction }) => {
|
||||||
|
return (
|
||||||
|
<div className="operation-item">
|
||||||
|
<span className="icon">{operation.icon}</span>
|
||||||
|
<span className="title">{operation.title}</span>
|
||||||
|
<span className="description">{operation.description}</span>
|
||||||
|
<button onClick={() => onAction(operation.id)}>
|
||||||
|
{operation.actionLabel}
|
||||||
|
</button>
|
||||||
|
{operation.badge && <Badge>{operation.badge}</Badge>}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
const OperationsList = () => {
|
||||||
|
const handleAction = useCallback((id: string) => {
|
||||||
|
const op = operations.find(o => o.id === id)
|
||||||
|
op?.onClick()
|
||||||
|
}, [operations])
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
{operations.map(op => (
|
||||||
|
<OperationItem
|
||||||
|
key={op.id}
|
||||||
|
operation={op}
|
||||||
|
onAction={handleAction}
|
||||||
|
/>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Directory Structure Patterns
|
||||||
|
|
||||||
|
### Pattern A: Flat Structure (Simple Components)
|
||||||
|
|
||||||
|
For components with 2-3 sub-components:
|
||||||
|
|
||||||
|
```
|
||||||
|
component-name/
|
||||||
|
├── index.tsx # Main component
|
||||||
|
├── sub-component-a.tsx
|
||||||
|
├── sub-component-b.tsx
|
||||||
|
└── types.ts # Shared types
|
||||||
|
```
|
||||||
|
|
||||||
|
### Pattern B: Nested Structure (Complex Components)
|
||||||
|
|
||||||
|
For components with many sub-components:
|
||||||
|
|
||||||
|
```
|
||||||
|
component-name/
|
||||||
|
├── index.tsx # Main orchestration
|
||||||
|
├── types.ts # Shared types
|
||||||
|
├── hooks/
|
||||||
|
│ ├── use-feature-a.ts
|
||||||
|
│ └── use-feature-b.ts
|
||||||
|
├── components/
|
||||||
|
│ ├── header/
|
||||||
|
│ │ └── index.tsx
|
||||||
|
│ ├── content/
|
||||||
|
│ │ └── index.tsx
|
||||||
|
│ └── modals/
|
||||||
|
│ └── index.tsx
|
||||||
|
└── utils/
|
||||||
|
└── helpers.ts
|
||||||
|
```
|
||||||
|
|
||||||
|
### Pattern C: Feature-Based Structure (Dify Standard)
|
||||||
|
|
||||||
|
Following Dify's existing patterns:
|
||||||
|
|
||||||
|
```
|
||||||
|
configuration/
|
||||||
|
├── index.tsx # Main page component
|
||||||
|
├── base/ # Base/shared components
|
||||||
|
│ ├── feature-panel/
|
||||||
|
│ ├── group-name/
|
||||||
|
│ └── operation-btn/
|
||||||
|
├── config/ # Config section
|
||||||
|
│ ├── index.tsx
|
||||||
|
│ ├── agent/
|
||||||
|
│ └── automatic/
|
||||||
|
├── dataset-config/ # Dataset section
|
||||||
|
│ ├── index.tsx
|
||||||
|
│ ├── card-item/
|
||||||
|
│ └── params-config/
|
||||||
|
├── debug/ # Debug section
|
||||||
|
│ ├── index.tsx
|
||||||
|
│ └── hooks.tsx
|
||||||
|
└── hooks/ # Shared hooks
|
||||||
|
└── use-advanced-prompt-config.ts
|
||||||
|
```
|
||||||
|
|
||||||
|
## Props Design
|
||||||
|
|
||||||
|
### Minimal Props Principle
|
||||||
|
|
||||||
|
Pass only what's needed:
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// ❌ Bad: Passing entire objects when only some fields needed
|
||||||
|
<ConfigHeader appDetail={appDetail} modelConfig={modelConfig} />
|
||||||
|
|
||||||
|
// ✅ Good: Destructure to minimum required
|
||||||
|
<ConfigHeader
|
||||||
|
appName={appDetail.name}
|
||||||
|
isAdvancedMode={modelConfig.isAdvanced}
|
||||||
|
onPublish={handlePublish}
|
||||||
|
/>
|
||||||
|
```
|
||||||
|
|
||||||
|
### Callback Props Pattern
|
||||||
|
|
||||||
|
Use callbacks for child-to-parent communication:
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// Parent
|
||||||
|
const Parent = () => {
|
||||||
|
const [value, setValue] = useState('')
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Child
|
||||||
|
value={value}
|
||||||
|
onChange={setValue}
|
||||||
|
onSubmit={handleSubmit}
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Child
|
||||||
|
interface ChildProps {
|
||||||
|
value: string
|
||||||
|
onChange: (value: string) => void
|
||||||
|
onSubmit: () => void
|
||||||
|
}
|
||||||
|
|
||||||
|
const Child: FC<ChildProps> = ({ value, onChange, onSubmit }) => {
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
<input value={value} onChange={e => onChange(e.target.value)} />
|
||||||
|
<button onClick={onSubmit}>Submit</button>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Render Props for Flexibility
|
||||||
|
|
||||||
|
When sub-components need parent context:
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
interface ListProps<T> {
|
||||||
|
items: T[]
|
||||||
|
renderItem: (item: T, index: number) => React.ReactNode
|
||||||
|
renderEmpty?: () => React.ReactNode
|
||||||
|
}
|
||||||
|
|
||||||
|
function List<T>({ items, renderItem, renderEmpty }: ListProps<T>) {
|
||||||
|
if (items.length === 0 && renderEmpty) {
|
||||||
|
return <>{renderEmpty()}</>
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
{items.map((item, index) => renderItem(item, index))}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Usage
|
||||||
|
<List
|
||||||
|
items={operations}
|
||||||
|
renderItem={(op, i) => <OperationItem key={i} operation={op} />}
|
||||||
|
renderEmpty={() => <EmptyState message="No operations" />}
|
||||||
|
/>
|
||||||
|
```
|
||||||
|
|
@ -0,0 +1,317 @@
|
||||||
|
# Hook Extraction Patterns
|
||||||
|
|
||||||
|
This document provides detailed guidance on extracting custom hooks from complex components in Dify.
|
||||||
|
|
||||||
|
## When to Extract Hooks
|
||||||
|
|
||||||
|
Extract a custom hook when you identify:
|
||||||
|
|
||||||
|
1. **Coupled state groups** - Multiple `useState` hooks that are always used together
|
||||||
|
1. **Complex effects** - `useEffect` with multiple dependencies or cleanup logic
|
||||||
|
1. **Business logic** - Data transformations, validations, or calculations
|
||||||
|
1. **Reusable patterns** - Logic that appears in multiple components
|
||||||
|
|
||||||
|
## Extraction Process
|
||||||
|
|
||||||
|
### Step 1: Identify State Groups
|
||||||
|
|
||||||
|
Look for state variables that are logically related:
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// ❌ These belong together - extract to hook
|
||||||
|
const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
|
||||||
|
const [completionParams, setCompletionParams] = useState<FormValue>({})
|
||||||
|
const [modelModeType, setModelModeType] = useState<ModelModeType>(...)
|
||||||
|
|
||||||
|
// These are model-related state that should be in useModelConfig()
|
||||||
|
```
|
||||||
|
|
||||||
|
### Step 2: Identify Related Effects
|
||||||
|
|
||||||
|
Find effects that modify the grouped state:
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// ❌ These effects belong with the state above
|
||||||
|
useEffect(() => {
|
||||||
|
if (hasFetchedDetail && !modelModeType) {
|
||||||
|
const mode = currModel?.model_properties.mode
|
||||||
|
if (mode) {
|
||||||
|
const newModelConfig = produce(modelConfig, (draft) => {
|
||||||
|
draft.mode = mode
|
||||||
|
})
|
||||||
|
setModelConfig(newModelConfig)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}, [textGenerationModelList, hasFetchedDetail, modelModeType, currModel])
|
||||||
|
```
|
||||||
|
|
||||||
|
### Step 3: Create the Hook
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// hooks/use-model-config.ts
|
||||||
|
import type { FormValue } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||||
|
import type { ModelConfig } from '@/models/debug'
|
||||||
|
import { produce } from 'immer'
|
||||||
|
import { useEffect, useState } from 'react'
|
||||||
|
import { ModelModeType } from '@/types/app'
|
||||||
|
|
||||||
|
interface UseModelConfigParams {
|
||||||
|
initialConfig?: Partial<ModelConfig>
|
||||||
|
currModel?: { model_properties?: { mode?: ModelModeType } }
|
||||||
|
hasFetchedDetail: boolean
|
||||||
|
}
|
||||||
|
|
||||||
|
interface UseModelConfigReturn {
|
||||||
|
modelConfig: ModelConfig
|
||||||
|
setModelConfig: (config: ModelConfig) => void
|
||||||
|
completionParams: FormValue
|
||||||
|
setCompletionParams: (params: FormValue) => void
|
||||||
|
modelModeType: ModelModeType
|
||||||
|
}
|
||||||
|
|
||||||
|
export const useModelConfig = ({
|
||||||
|
initialConfig,
|
||||||
|
currModel,
|
||||||
|
hasFetchedDetail,
|
||||||
|
}: UseModelConfigParams): UseModelConfigReturn => {
|
||||||
|
const [modelConfig, setModelConfig] = useState<ModelConfig>({
|
||||||
|
provider: 'langgenius/openai/openai',
|
||||||
|
model_id: 'gpt-3.5-turbo',
|
||||||
|
mode: ModelModeType.unset,
|
||||||
|
// ... default values
|
||||||
|
...initialConfig,
|
||||||
|
})
|
||||||
|
|
||||||
|
const [completionParams, setCompletionParams] = useState<FormValue>({})
|
||||||
|
|
||||||
|
const modelModeType = modelConfig.mode
|
||||||
|
|
||||||
|
// Fill old app data missing model mode
|
||||||
|
useEffect(() => {
|
||||||
|
if (hasFetchedDetail && !modelModeType) {
|
||||||
|
const mode = currModel?.model_properties?.mode
|
||||||
|
if (mode) {
|
||||||
|
setModelConfig(produce(modelConfig, (draft) => {
|
||||||
|
draft.mode = mode
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}, [hasFetchedDetail, modelModeType, currModel])
|
||||||
|
|
||||||
|
return {
|
||||||
|
modelConfig,
|
||||||
|
setModelConfig,
|
||||||
|
completionParams,
|
||||||
|
setCompletionParams,
|
||||||
|
modelModeType,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Step 4: Update Component
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// Before: 50+ lines of state management
|
||||||
|
const Configuration: FC = () => {
|
||||||
|
const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
|
||||||
|
// ... lots of related state and effects
|
||||||
|
}
|
||||||
|
|
||||||
|
// After: Clean component
|
||||||
|
const Configuration: FC = () => {
|
||||||
|
const {
|
||||||
|
modelConfig,
|
||||||
|
setModelConfig,
|
||||||
|
completionParams,
|
||||||
|
setCompletionParams,
|
||||||
|
modelModeType,
|
||||||
|
} = useModelConfig({
|
||||||
|
currModel,
|
||||||
|
hasFetchedDetail,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Component now focuses on UI
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Naming Conventions
|
||||||
|
|
||||||
|
### Hook Names
|
||||||
|
|
||||||
|
- Use `use` prefix: `useModelConfig`, `useDatasetConfig`
|
||||||
|
- Be specific: `useAdvancedPromptConfig` not `usePrompt`
|
||||||
|
- Include domain: `useWorkflowVariables`, `useMCPServer`
|
||||||
|
|
||||||
|
### File Names
|
||||||
|
|
||||||
|
- Kebab-case: `use-model-config.ts`
|
||||||
|
- Place in `hooks/` subdirectory when multiple hooks exist
|
||||||
|
- Place alongside component for single-use hooks
|
||||||
|
|
||||||
|
### Return Type Names
|
||||||
|
|
||||||
|
- Suffix with `Return`: `UseModelConfigReturn`
|
||||||
|
- Suffix params with `Params`: `UseModelConfigParams`
|
||||||
|
|
||||||
|
## Common Hook Patterns in Dify
|
||||||
|
|
||||||
|
### 1. Data Fetching Hook (React Query)
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// Pattern: Use @tanstack/react-query for data fetching
|
||||||
|
import { useQuery, useQueryClient } from '@tanstack/react-query'
|
||||||
|
import { get } from '@/service/base'
|
||||||
|
import { useInvalid } from '@/service/use-base'
|
||||||
|
|
||||||
|
const NAME_SPACE = 'appConfig'
|
||||||
|
|
||||||
|
// Query keys for cache management
|
||||||
|
export const appConfigQueryKeys = {
|
||||||
|
detail: (appId: string) => [NAME_SPACE, 'detail', appId] as const,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Main data hook
|
||||||
|
export const useAppConfig = (appId: string) => {
|
||||||
|
return useQuery({
|
||||||
|
enabled: !!appId,
|
||||||
|
queryKey: appConfigQueryKeys.detail(appId),
|
||||||
|
queryFn: () => get<AppDetailResponse>(`/apps/${appId}`),
|
||||||
|
select: data => data?.model_config || null,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Invalidation hook for refreshing data
|
||||||
|
export const useInvalidAppConfig = () => {
|
||||||
|
return useInvalid([NAME_SPACE])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Usage in component
|
||||||
|
const Component = () => {
|
||||||
|
const { data: config, isLoading, error, refetch } = useAppConfig(appId)
|
||||||
|
const invalidAppConfig = useInvalidAppConfig()
|
||||||
|
|
||||||
|
const handleRefresh = () => {
|
||||||
|
invalidAppConfig() // Invalidates cache and triggers refetch
|
||||||
|
}
|
||||||
|
|
||||||
|
return <div>...</div>
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Form State Hook
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// Pattern: Form state + validation + submission
|
||||||
|
export const useConfigForm = (initialValues: ConfigFormValues) => {
|
||||||
|
const [values, setValues] = useState(initialValues)
|
||||||
|
const [errors, setErrors] = useState<Record<string, string>>({})
|
||||||
|
const [isSubmitting, setIsSubmitting] = useState(false)
|
||||||
|
|
||||||
|
const validate = useCallback(() => {
|
||||||
|
const newErrors: Record<string, string> = {}
|
||||||
|
if (!values.name) newErrors.name = 'Name is required'
|
||||||
|
setErrors(newErrors)
|
||||||
|
return Object.keys(newErrors).length === 0
|
||||||
|
}, [values])
|
||||||
|
|
||||||
|
const handleChange = useCallback((field: string, value: any) => {
|
||||||
|
setValues(prev => ({ ...prev, [field]: value }))
|
||||||
|
}, [])
|
||||||
|
|
||||||
|
const handleSubmit = useCallback(async (onSubmit: (values: ConfigFormValues) => Promise<void>) => {
|
||||||
|
if (!validate()) return
|
||||||
|
setIsSubmitting(true)
|
||||||
|
try {
|
||||||
|
await onSubmit(values)
|
||||||
|
} finally {
|
||||||
|
setIsSubmitting(false)
|
||||||
|
}
|
||||||
|
}, [values, validate])
|
||||||
|
|
||||||
|
return { values, errors, isSubmitting, handleChange, handleSubmit }
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Modal State Hook
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// Pattern: Multiple modal management
|
||||||
|
type ModalType = 'edit' | 'delete' | 'duplicate' | null
|
||||||
|
|
||||||
|
export const useModalState = () => {
|
||||||
|
const [activeModal, setActiveModal] = useState<ModalType>(null)
|
||||||
|
const [modalData, setModalData] = useState<any>(null)
|
||||||
|
|
||||||
|
const openModal = useCallback((type: ModalType, data?: any) => {
|
||||||
|
setActiveModal(type)
|
||||||
|
setModalData(data)
|
||||||
|
}, [])
|
||||||
|
|
||||||
|
const closeModal = useCallback(() => {
|
||||||
|
setActiveModal(null)
|
||||||
|
setModalData(null)
|
||||||
|
}, [])
|
||||||
|
|
||||||
|
return {
|
||||||
|
activeModal,
|
||||||
|
modalData,
|
||||||
|
openModal,
|
||||||
|
closeModal,
|
||||||
|
isOpen: useCallback((type: ModalType) => activeModal === type, [activeModal]),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. Toggle/Boolean Hook
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// Pattern: Boolean state with convenience methods
|
||||||
|
export const useToggle = (initialValue = false) => {
|
||||||
|
const [value, setValue] = useState(initialValue)
|
||||||
|
|
||||||
|
const toggle = useCallback(() => setValue(v => !v), [])
|
||||||
|
const setTrue = useCallback(() => setValue(true), [])
|
||||||
|
const setFalse = useCallback(() => setValue(false), [])
|
||||||
|
|
||||||
|
return [value, { toggle, setTrue, setFalse, set: setValue }] as const
|
||||||
|
}
|
||||||
|
|
||||||
|
// Usage
|
||||||
|
const [isExpanded, { toggle, setTrue: expand, setFalse: collapse }] = useToggle()
|
||||||
|
```
|
||||||
|
|
||||||
|
## Testing Extracted Hooks
|
||||||
|
|
||||||
|
After extraction, test hooks in isolation:
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// use-model-config.spec.ts
|
||||||
|
import { renderHook, act } from '@testing-library/react'
|
||||||
|
import { useModelConfig } from './use-model-config'
|
||||||
|
|
||||||
|
describe('useModelConfig', () => {
|
||||||
|
it('should initialize with default values', () => {
|
||||||
|
const { result } = renderHook(() => useModelConfig({
|
||||||
|
hasFetchedDetail: false,
|
||||||
|
}))
|
||||||
|
|
||||||
|
expect(result.current.modelConfig.provider).toBe('langgenius/openai/openai')
|
||||||
|
expect(result.current.modelModeType).toBe(ModelModeType.unset)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should update model config', () => {
|
||||||
|
const { result } = renderHook(() => useModelConfig({
|
||||||
|
hasFetchedDetail: true,
|
||||||
|
}))
|
||||||
|
|
||||||
|
act(() => {
|
||||||
|
result.current.setModelConfig({
|
||||||
|
...result.current.modelConfig,
|
||||||
|
model_id: 'gpt-4',
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result.current.modelConfig.model_id).toBe('gpt-4')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
@ -0,0 +1,322 @@
|
||||||
|
---
|
||||||
|
name: frontend-testing
|
||||||
|
description: Generate Vitest + React Testing Library tests for Dify frontend components, hooks, and utilities. Triggers on testing, spec files, coverage, Vitest, RTL, unit tests, integration tests, or write/review test requests.
|
||||||
|
---
|
||||||
|
|
||||||
|
# Dify Frontend Testing Skill
|
||||||
|
|
||||||
|
This skill enables Claude to generate high-quality, comprehensive frontend tests for the Dify project following established conventions and best practices.
|
||||||
|
|
||||||
|
> **⚠️ Authoritative Source**: This skill is derived from `web/testing/testing.md`. Use Vitest mock/timer APIs (`vi.*`).
|
||||||
|
|
||||||
|
## When to Apply This Skill
|
||||||
|
|
||||||
|
Apply this skill when the user:
|
||||||
|
|
||||||
|
- Asks to **write tests** for a component, hook, or utility
|
||||||
|
- Asks to **review existing tests** for completeness
|
||||||
|
- Mentions **Vitest**, **React Testing Library**, **RTL**, or **spec files**
|
||||||
|
- Requests **test coverage** improvement
|
||||||
|
- Uses `pnpm analyze-component` output as context
|
||||||
|
- Mentions **testing**, **unit tests**, or **integration tests** for frontend code
|
||||||
|
- Wants to understand **testing patterns** in the Dify codebase
|
||||||
|
|
||||||
|
**Do NOT apply** when:
|
||||||
|
|
||||||
|
- User is asking about backend/API tests (Python/pytest)
|
||||||
|
- User is asking about E2E tests (Playwright/Cypress)
|
||||||
|
- User is only asking conceptual questions without code context
|
||||||
|
|
||||||
|
## Quick Reference
|
||||||
|
|
||||||
|
### Tech Stack
|
||||||
|
|
||||||
|
| Tool | Version | Purpose |
|
||||||
|
|------|---------|---------|
|
||||||
|
| Vitest | 4.0.16 | Test runner |
|
||||||
|
| React Testing Library | 16.0 | Component testing |
|
||||||
|
| jsdom | - | Test environment |
|
||||||
|
| nock | 14.0 | HTTP mocking |
|
||||||
|
| TypeScript | 5.x | Type safety |
|
||||||
|
|
||||||
|
### Key Commands
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run all tests
|
||||||
|
pnpm test
|
||||||
|
|
||||||
|
# Watch mode
|
||||||
|
pnpm test:watch
|
||||||
|
|
||||||
|
# Run specific file
|
||||||
|
pnpm test path/to/file.spec.tsx
|
||||||
|
|
||||||
|
# Generate coverage report
|
||||||
|
pnpm test:coverage
|
||||||
|
|
||||||
|
# Analyze component complexity
|
||||||
|
pnpm analyze-component <path>
|
||||||
|
|
||||||
|
# Review existing test
|
||||||
|
pnpm analyze-component <path> --review
|
||||||
|
```
|
||||||
|
|
||||||
|
### File Naming
|
||||||
|
|
||||||
|
- Test files: `ComponentName.spec.tsx` (same directory as component)
|
||||||
|
- Integration tests: `web/__tests__/` directory
|
||||||
|
|
||||||
|
## Test Structure Template
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
|
||||||
|
import Component from './index'
|
||||||
|
|
||||||
|
// ✅ Import real project components (DO NOT mock these)
|
||||||
|
// import Loading from '@/app/components/base/loading'
|
||||||
|
// import { ChildComponent } from './child-component'
|
||||||
|
|
||||||
|
// ✅ Mock external dependencies only
|
||||||
|
vi.mock('@/service/api')
|
||||||
|
vi.mock('next/navigation', () => ({
|
||||||
|
useRouter: () => ({ push: vi.fn() }),
|
||||||
|
usePathname: () => '/test',
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Shared state for mocks (if needed)
|
||||||
|
let mockSharedState = false
|
||||||
|
|
||||||
|
describe('ComponentName', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
vi.clearAllMocks() // ✅ Reset mocks BEFORE each test
|
||||||
|
mockSharedState = false // ✅ Reset shared state
|
||||||
|
})
|
||||||
|
|
||||||
|
// Rendering tests (REQUIRED)
|
||||||
|
describe('Rendering', () => {
|
||||||
|
it('should render without crashing', () => {
|
||||||
|
// Arrange
|
||||||
|
const props = { title: 'Test' }
|
||||||
|
|
||||||
|
// Act
|
||||||
|
render(<Component {...props} />)
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
expect(screen.getByText('Test')).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// Props tests (REQUIRED)
|
||||||
|
describe('Props', () => {
|
||||||
|
it('should apply custom className', () => {
|
||||||
|
render(<Component className="custom" />)
|
||||||
|
expect(screen.getByRole('button')).toHaveClass('custom')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// User Interactions
|
||||||
|
describe('User Interactions', () => {
|
||||||
|
it('should handle click events', () => {
|
||||||
|
const handleClick = vi.fn()
|
||||||
|
render(<Component onClick={handleClick} />)
|
||||||
|
|
||||||
|
fireEvent.click(screen.getByRole('button'))
|
||||||
|
|
||||||
|
expect(handleClick).toHaveBeenCalledTimes(1)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// Edge Cases (REQUIRED)
|
||||||
|
describe('Edge Cases', () => {
|
||||||
|
it('should handle null data', () => {
|
||||||
|
render(<Component data={null} />)
|
||||||
|
expect(screen.getByText(/no data/i)).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle empty array', () => {
|
||||||
|
render(<Component items={[]} />)
|
||||||
|
expect(screen.getByText(/empty/i)).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
## Testing Workflow (CRITICAL)
|
||||||
|
|
||||||
|
### ⚠️ Incremental Approach Required
|
||||||
|
|
||||||
|
**NEVER generate all test files at once.** For complex components or multi-file directories:
|
||||||
|
|
||||||
|
1. **Analyze & Plan**: List all files, order by complexity (simple → complex)
|
||||||
|
1. **Process ONE at a time**: Write test → Run test → Fix if needed → Next
|
||||||
|
1. **Verify before proceeding**: Do NOT continue to next file until current passes
|
||||||
|
|
||||||
|
```
|
||||||
|
For each file:
|
||||||
|
┌────────────────────────────────────────┐
|
||||||
|
│ 1. Write test │
|
||||||
|
│ 2. Run: pnpm test <file>.spec.tsx │
|
||||||
|
│ 3. PASS? → Mark complete, next file │
|
||||||
|
│ FAIL? → Fix first, then continue │
|
||||||
|
└────────────────────────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
### Complexity-Based Order
|
||||||
|
|
||||||
|
Process in this order for multi-file testing:
|
||||||
|
|
||||||
|
1. 🟢 Utility functions (simplest)
|
||||||
|
1. 🟢 Custom hooks
|
||||||
|
1. 🟡 Simple components (presentational)
|
||||||
|
1. 🟡 Medium components (state, effects)
|
||||||
|
1. 🔴 Complex components (API, routing)
|
||||||
|
1. 🔴 Integration tests (index files - last)
|
||||||
|
|
||||||
|
### When to Refactor First
|
||||||
|
|
||||||
|
- **Complexity > 50**: Break into smaller pieces before testing
|
||||||
|
- **500+ lines**: Consider splitting before testing
|
||||||
|
- **Many dependencies**: Extract logic into hooks first
|
||||||
|
|
||||||
|
> 📖 See `references/workflow.md` for complete workflow details and todo list format.
|
||||||
|
|
||||||
|
## Testing Strategy
|
||||||
|
|
||||||
|
### Path-Level Testing (Directory Testing)
|
||||||
|
|
||||||
|
When assigned to test a directory/path, test **ALL content** within that path:
|
||||||
|
|
||||||
|
- Test all components, hooks, utilities in the directory (not just `index` file)
|
||||||
|
- Use incremental approach: one file at a time, verify each before proceeding
|
||||||
|
- Goal: 100% coverage of ALL files in the directory
|
||||||
|
|
||||||
|
### Integration Testing First
|
||||||
|
|
||||||
|
**Prefer integration testing** when writing tests for a directory:
|
||||||
|
|
||||||
|
- ✅ **Import real project components** directly (including base components and siblings)
|
||||||
|
- ✅ **Only mock**: API services (`@/service/*`), `next/navigation`, complex context providers
|
||||||
|
- ❌ **DO NOT mock** base components (`@/app/components/base/*`)
|
||||||
|
- ❌ **DO NOT mock** sibling/child components in the same directory
|
||||||
|
|
||||||
|
> See [Test Structure Template](#test-structure-template) for correct import/mock patterns.
|
||||||
|
|
||||||
|
## Core Principles
|
||||||
|
|
||||||
|
### 1. AAA Pattern (Arrange-Act-Assert)
|
||||||
|
|
||||||
|
Every test should clearly separate:
|
||||||
|
|
||||||
|
- **Arrange**: Setup test data and render component
|
||||||
|
- **Act**: Perform user actions
|
||||||
|
- **Assert**: Verify expected outcomes
|
||||||
|
|
||||||
|
### 2. Black-Box Testing
|
||||||
|
|
||||||
|
- Test observable behavior, not implementation details
|
||||||
|
- Use semantic queries (getByRole, getByLabelText)
|
||||||
|
- Avoid testing internal state directly
|
||||||
|
- **Prefer pattern matching over hardcoded strings** in assertions:
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// ❌ Avoid: hardcoded text assertions
|
||||||
|
expect(screen.getByText('Loading...')).toBeInTheDocument()
|
||||||
|
|
||||||
|
// ✅ Better: role-based queries
|
||||||
|
expect(screen.getByRole('status')).toBeInTheDocument()
|
||||||
|
|
||||||
|
// ✅ Better: pattern matching
|
||||||
|
expect(screen.getByText(/loading/i)).toBeInTheDocument()
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Single Behavior Per Test
|
||||||
|
|
||||||
|
Each test verifies ONE user-observable behavior:
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// ✅ Good: One behavior
|
||||||
|
it('should disable button when loading', () => {
|
||||||
|
render(<Button loading />)
|
||||||
|
expect(screen.getByRole('button')).toBeDisabled()
|
||||||
|
})
|
||||||
|
|
||||||
|
// ❌ Bad: Multiple behaviors
|
||||||
|
it('should handle loading state', () => {
|
||||||
|
render(<Button loading />)
|
||||||
|
expect(screen.getByRole('button')).toBeDisabled()
|
||||||
|
expect(screen.getByText('Loading...')).toBeInTheDocument()
|
||||||
|
expect(screen.getByRole('button')).toHaveClass('loading')
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. Semantic Naming
|
||||||
|
|
||||||
|
Use `should <behavior> when <condition>`:
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
it('should show error message when validation fails')
|
||||||
|
it('should call onSubmit when form is valid')
|
||||||
|
it('should disable input when isReadOnly is true')
|
||||||
|
```
|
||||||
|
|
||||||
|
## Required Test Scenarios
|
||||||
|
|
||||||
|
### Always Required (All Components)
|
||||||
|
|
||||||
|
1. **Rendering**: Component renders without crashing
|
||||||
|
1. **Props**: Required props, optional props, default values
|
||||||
|
1. **Edge Cases**: null, undefined, empty values, boundary conditions
|
||||||
|
|
||||||
|
### Conditional (When Present)
|
||||||
|
|
||||||
|
| Feature | Test Focus |
|
||||||
|
|---------|-----------|
|
||||||
|
| `useState` | Initial state, transitions, cleanup |
|
||||||
|
| `useEffect` | Execution, dependencies, cleanup |
|
||||||
|
| Event handlers | All onClick, onChange, onSubmit, keyboard |
|
||||||
|
| API calls | Loading, success, error states |
|
||||||
|
| Routing | Navigation, params, query strings |
|
||||||
|
| `useCallback`/`useMemo` | Referential equality |
|
||||||
|
| Context | Provider values, consumer behavior |
|
||||||
|
| Forms | Validation, submission, error display |
|
||||||
|
|
||||||
|
## Coverage Goals (Per File)
|
||||||
|
|
||||||
|
For each test file generated, aim for:
|
||||||
|
|
||||||
|
- ✅ **100%** function coverage
|
||||||
|
- ✅ **100%** statement coverage
|
||||||
|
- ✅ **>95%** branch coverage
|
||||||
|
- ✅ **>95%** line coverage
|
||||||
|
|
||||||
|
> **Note**: For multi-file directories, process one file at a time with full coverage each. See `references/workflow.md`.
|
||||||
|
|
||||||
|
## Detailed Guides
|
||||||
|
|
||||||
|
For more detailed information, refer to:
|
||||||
|
|
||||||
|
- `references/workflow.md` - **Incremental testing workflow** (MUST READ for multi-file testing)
|
||||||
|
- `references/mocking.md` - Mock patterns and best practices
|
||||||
|
- `references/async-testing.md` - Async operations and API calls
|
||||||
|
- `references/domain-components.md` - Workflow, Dataset, Configuration testing
|
||||||
|
- `references/common-patterns.md` - Frequently used testing patterns
|
||||||
|
- `references/checklist.md` - Test generation checklist and validation steps
|
||||||
|
|
||||||
|
## Authoritative References
|
||||||
|
|
||||||
|
### Primary Specification (MUST follow)
|
||||||
|
|
||||||
|
- **`web/testing/testing.md`** - The canonical testing specification. This skill is derived from this document.
|
||||||
|
|
||||||
|
### Reference Examples in Codebase
|
||||||
|
|
||||||
|
- `web/utils/classnames.spec.ts` - Utility function tests
|
||||||
|
- `web/app/components/base/button/index.spec.tsx` - Component tests
|
||||||
|
- `web/__mocks__/provider-context.ts` - Mock factory example
|
||||||
|
|
||||||
|
### Project Configuration
|
||||||
|
|
||||||
|
- `web/vitest.config.ts` - Vitest configuration
|
||||||
|
- `web/vitest.setup.ts` - Test environment setup
|
||||||
|
- `web/scripts/analyze-component.js` - Component analysis tool
|
||||||
|
- Modules are not mocked automatically. Global mocks live in `web/vitest.setup.ts` (for example `react-i18next`, `next/image`); mock other modules like `ky` or `mime` locally in test files.
|
||||||
|
|
@ -0,0 +1,296 @@
|
||||||
|
/**
|
||||||
|
* Test Template for React Components
|
||||||
|
*
|
||||||
|
* WHY THIS STRUCTURE?
|
||||||
|
* - Organized sections make tests easy to navigate and maintain
|
||||||
|
* - Mocks at top ensure consistent test isolation
|
||||||
|
* - Factory functions reduce duplication and improve readability
|
||||||
|
* - describe blocks group related scenarios for better debugging
|
||||||
|
*
|
||||||
|
* INSTRUCTIONS:
|
||||||
|
* 1. Replace `ComponentName` with your component name
|
||||||
|
* 2. Update import path
|
||||||
|
* 3. Add/remove test sections based on component features (use analyze-component)
|
||||||
|
* 4. Follow AAA pattern: Arrange → Act → Assert
|
||||||
|
*
|
||||||
|
* RUN FIRST: pnpm analyze-component <path> to identify required test scenarios
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
|
||||||
|
import userEvent from '@testing-library/user-event'
|
||||||
|
// import ComponentName from './index'
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Mocks
|
||||||
|
// ============================================================================
|
||||||
|
// WHY: Mocks must be hoisted to top of file (Vitest requirement).
|
||||||
|
// They run BEFORE imports, so keep them before component imports.
|
||||||
|
|
||||||
|
// i18n (automatically mocked)
|
||||||
|
// WHY: Global mock in web/vitest.setup.ts is auto-loaded by Vitest setup
|
||||||
|
// No explicit mock needed - it returns translation keys as-is
|
||||||
|
// Override only if custom translations are required:
|
||||||
|
// vi.mock('react-i18next', () => ({
|
||||||
|
// useTranslation: () => ({
|
||||||
|
// t: (key: string) => {
|
||||||
|
// const customTranslations: Record<string, string> = {
|
||||||
|
// 'my.custom.key': 'Custom Translation',
|
||||||
|
// }
|
||||||
|
// return customTranslations[key] || key
|
||||||
|
// },
|
||||||
|
// }),
|
||||||
|
// }))
|
||||||
|
|
||||||
|
// Router (if component uses useRouter, usePathname, useSearchParams)
|
||||||
|
// WHY: Isolates tests from Next.js routing, enables testing navigation behavior
|
||||||
|
// const mockPush = vi.fn()
|
||||||
|
// vi.mock('next/navigation', () => ({
|
||||||
|
// useRouter: () => ({ push: mockPush }),
|
||||||
|
// usePathname: () => '/test-path',
|
||||||
|
// }))
|
||||||
|
|
||||||
|
// API services (if component fetches data)
|
||||||
|
// WHY: Prevents real network calls, enables testing all states (loading/success/error)
|
||||||
|
// vi.mock('@/service/api')
|
||||||
|
// import * as api from '@/service/api'
|
||||||
|
// const mockedApi = vi.mocked(api)
|
||||||
|
|
||||||
|
// Shared mock state (for portal/dropdown components)
|
||||||
|
// WHY: Portal components like PortalToFollowElem need shared state between
|
||||||
|
// parent and child mocks to correctly simulate open/close behavior
|
||||||
|
// let mockOpenState = false
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Test Data Factories
|
||||||
|
// ============================================================================
|
||||||
|
// WHY FACTORIES?
|
||||||
|
// - Avoid hard-coded test data scattered across tests
|
||||||
|
// - Easy to create variations with overrides
|
||||||
|
// - Type-safe when using actual types from source
|
||||||
|
// - Single source of truth for default test values
|
||||||
|
|
||||||
|
// const createMockProps = (overrides = {}) => ({
|
||||||
|
// // Default props that make component render successfully
|
||||||
|
// ...overrides,
|
||||||
|
// })
|
||||||
|
|
||||||
|
// const createMockItem = (overrides = {}) => ({
|
||||||
|
// id: 'item-1',
|
||||||
|
// name: 'Test Item',
|
||||||
|
// ...overrides,
|
||||||
|
// })
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Test Helpers
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
// const renderComponent = (props = {}) => {
|
||||||
|
// return render(<ComponentName {...createMockProps(props)} />)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Tests
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
describe('ComponentName', () => {
|
||||||
|
// WHY beforeEach with clearAllMocks?
|
||||||
|
// - Ensures each test starts with clean slate
|
||||||
|
// - Prevents mock call history from leaking between tests
|
||||||
|
// - MUST be beforeEach (not afterEach) to reset BEFORE assertions like toHaveBeenCalledTimes
|
||||||
|
beforeEach(() => {
|
||||||
|
vi.clearAllMocks()
|
||||||
|
// Reset shared mock state if used (CRITICAL for portal/dropdown tests)
|
||||||
|
// mockOpenState = false
|
||||||
|
})
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// Rendering Tests (REQUIRED - Every component MUST have these)
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// WHY: Catches import errors, missing providers, and basic render issues
|
||||||
|
describe('Rendering', () => {
|
||||||
|
it('should render without crashing', () => {
|
||||||
|
// Arrange - Setup data and mocks
|
||||||
|
// const props = createMockProps()
|
||||||
|
|
||||||
|
// Act - Render the component
|
||||||
|
// render(<ComponentName {...props} />)
|
||||||
|
|
||||||
|
// Assert - Verify expected output
|
||||||
|
// Prefer getByRole for accessibility; it's what users "see"
|
||||||
|
// expect(screen.getByRole('...')).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should render with default props', () => {
|
||||||
|
// WHY: Verifies component works without optional props
|
||||||
|
// render(<ComponentName />)
|
||||||
|
// expect(screen.getByText('...')).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// Props Tests (REQUIRED - Every component MUST test prop behavior)
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// WHY: Props are the component's API contract. Test them thoroughly.
|
||||||
|
describe('Props', () => {
|
||||||
|
it('should apply custom className', () => {
|
||||||
|
// WHY: Common pattern in Dify - components should merge custom classes
|
||||||
|
// render(<ComponentName className="custom-class" />)
|
||||||
|
// expect(screen.getByTestId('component')).toHaveClass('custom-class')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should use default values for optional props', () => {
|
||||||
|
// WHY: Verifies TypeScript defaults work at runtime
|
||||||
|
// render(<ComponentName />)
|
||||||
|
// expect(screen.getByRole('...')).toHaveAttribute('...', 'default-value')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// User Interactions (if component has event handlers - on*, handle*)
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// WHY: Event handlers are core functionality. Test from user's perspective.
|
||||||
|
describe('User Interactions', () => {
|
||||||
|
it('should call onClick when clicked', async () => {
|
||||||
|
// WHY userEvent over fireEvent?
|
||||||
|
// - userEvent simulates real user behavior (focus, hover, then click)
|
||||||
|
// - fireEvent is lower-level, doesn't trigger all browser events
|
||||||
|
// const user = userEvent.setup()
|
||||||
|
// const handleClick = vi.fn()
|
||||||
|
// render(<ComponentName onClick={handleClick} />)
|
||||||
|
//
|
||||||
|
// await user.click(screen.getByRole('button'))
|
||||||
|
//
|
||||||
|
// expect(handleClick).toHaveBeenCalledTimes(1)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should call onChange when value changes', async () => {
|
||||||
|
// const user = userEvent.setup()
|
||||||
|
// const handleChange = vi.fn()
|
||||||
|
// render(<ComponentName onChange={handleChange} />)
|
||||||
|
//
|
||||||
|
// await user.type(screen.getByRole('textbox'), 'new value')
|
||||||
|
//
|
||||||
|
// expect(handleChange).toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// State Management (if component uses useState/useReducer)
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// WHY: Test state through observable UI changes, not internal state values
|
||||||
|
describe('State Management', () => {
|
||||||
|
it('should update state on interaction', async () => {
|
||||||
|
// WHY test via UI, not state?
|
||||||
|
// - State is implementation detail; UI is what users see
|
||||||
|
// - If UI works correctly, state must be correct
|
||||||
|
// const user = userEvent.setup()
|
||||||
|
// render(<ComponentName />)
|
||||||
|
//
|
||||||
|
// // Initial state - verify what user sees
|
||||||
|
// expect(screen.getByText('Initial')).toBeInTheDocument()
|
||||||
|
//
|
||||||
|
// // Trigger state change via user action
|
||||||
|
// await user.click(screen.getByRole('button'))
|
||||||
|
//
|
||||||
|
// // New state - verify UI updated
|
||||||
|
// expect(screen.getByText('Updated')).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// Async Operations (if component fetches data - useQuery, fetch)
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// WHY: Async operations have 3 states users experience: loading, success, error
|
||||||
|
describe('Async Operations', () => {
|
||||||
|
it('should show loading state', () => {
|
||||||
|
// WHY never-resolving promise?
|
||||||
|
// - Keeps component in loading state for assertion
|
||||||
|
// - Alternative: use fake timers
|
||||||
|
// mockedApi.fetchData.mockImplementation(() => new Promise(() => {}))
|
||||||
|
// render(<ComponentName />)
|
||||||
|
//
|
||||||
|
// expect(screen.getByText(/loading/i)).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should show data on success', async () => {
|
||||||
|
// WHY waitFor?
|
||||||
|
// - Component updates asynchronously after fetch resolves
|
||||||
|
// - waitFor retries assertion until it passes or times out
|
||||||
|
// mockedApi.fetchData.mockResolvedValue({ items: ['Item 1'] })
|
||||||
|
// render(<ComponentName />)
|
||||||
|
//
|
||||||
|
// await waitFor(() => {
|
||||||
|
// expect(screen.getByText('Item 1')).toBeInTheDocument()
|
||||||
|
// })
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should show error on failure', async () => {
|
||||||
|
// mockedApi.fetchData.mockRejectedValue(new Error('Network error'))
|
||||||
|
// render(<ComponentName />)
|
||||||
|
//
|
||||||
|
// await waitFor(() => {
|
||||||
|
// expect(screen.getByText(/error/i)).toBeInTheDocument()
|
||||||
|
// })
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// Edge Cases (REQUIRED - Every component MUST handle edge cases)
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// WHY: Real-world data is messy. Components must handle:
|
||||||
|
// - Null/undefined from API failures or optional fields
|
||||||
|
// - Empty arrays/strings from user clearing data
|
||||||
|
// - Boundary values (0, MAX_INT, special characters)
|
||||||
|
describe('Edge Cases', () => {
|
||||||
|
it('should handle null value', () => {
|
||||||
|
// WHY test null specifically?
|
||||||
|
// - API might return null for missing data
|
||||||
|
// - Prevents "Cannot read property of null" in production
|
||||||
|
// render(<ComponentName value={null} />)
|
||||||
|
// expect(screen.getByText(/no data/i)).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle undefined value', () => {
|
||||||
|
// WHY test undefined separately from null?
|
||||||
|
// - TypeScript treats them differently
|
||||||
|
// - Optional props are undefined, not null
|
||||||
|
// render(<ComponentName value={undefined} />)
|
||||||
|
// expect(screen.getByText(/no data/i)).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle empty array', () => {
|
||||||
|
// WHY: Empty state often needs special UI (e.g., "No items yet")
|
||||||
|
// render(<ComponentName items={[]} />)
|
||||||
|
// expect(screen.getByText(/empty/i)).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle empty string', () => {
|
||||||
|
// WHY: Empty strings are truthy in JS but visually empty
|
||||||
|
// render(<ComponentName text="" />)
|
||||||
|
// expect(screen.getByText(/placeholder/i)).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// Accessibility (optional but recommended for Dify's enterprise users)
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// WHY: Dify has enterprise customers who may require accessibility compliance
|
||||||
|
describe('Accessibility', () => {
|
||||||
|
it('should have accessible name', () => {
|
||||||
|
// WHY getByRole with name?
|
||||||
|
// - Tests that screen readers can identify the element
|
||||||
|
// - Enforces proper labeling practices
|
||||||
|
// render(<ComponentName label="Test Label" />)
|
||||||
|
// expect(screen.getByRole('button', { name: /test label/i })).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should support keyboard navigation', async () => {
|
||||||
|
// WHY: Some users can't use a mouse
|
||||||
|
// const user = userEvent.setup()
|
||||||
|
// render(<ComponentName />)
|
||||||
|
//
|
||||||
|
// await user.tab()
|
||||||
|
// expect(screen.getByRole('button')).toHaveFocus()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
@ -0,0 +1,207 @@
|
||||||
|
/**
|
||||||
|
* Test Template for Custom Hooks
|
||||||
|
*
|
||||||
|
* Instructions:
|
||||||
|
* 1. Replace `useHookName` with your hook name
|
||||||
|
* 2. Update import path
|
||||||
|
* 3. Add/remove test sections based on hook features
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { renderHook, act, waitFor } from '@testing-library/react'
|
||||||
|
// import { useHookName } from './use-hook-name'
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Mocks
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
// API services (if hook fetches data)
|
||||||
|
// vi.mock('@/service/api')
|
||||||
|
// import * as api from '@/service/api'
|
||||||
|
// const mockedApi = vi.mocked(api)
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Test Helpers
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
// Wrapper for hooks that need context
|
||||||
|
// const createWrapper = (contextValue = {}) => {
|
||||||
|
// return ({ children }: { children: React.ReactNode }) => (
|
||||||
|
// <SomeContext.Provider value={contextValue}>
|
||||||
|
// {children}
|
||||||
|
// </SomeContext.Provider>
|
||||||
|
// )
|
||||||
|
// }
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Tests
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
describe('useHookName', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
vi.clearAllMocks()
|
||||||
|
})
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// Initial State
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
describe('Initial State', () => {
|
||||||
|
it('should return initial state', () => {
|
||||||
|
// const { result } = renderHook(() => useHookName())
|
||||||
|
//
|
||||||
|
// expect(result.current.value).toBe(initialValue)
|
||||||
|
// expect(result.current.isLoading).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should accept initial value from props', () => {
|
||||||
|
// const { result } = renderHook(() => useHookName({ initialValue: 'custom' }))
|
||||||
|
//
|
||||||
|
// expect(result.current.value).toBe('custom')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// State Updates
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
describe('State Updates', () => {
|
||||||
|
it('should update value when setValue is called', () => {
|
||||||
|
// const { result } = renderHook(() => useHookName())
|
||||||
|
//
|
||||||
|
// act(() => {
|
||||||
|
// result.current.setValue('new value')
|
||||||
|
// })
|
||||||
|
//
|
||||||
|
// expect(result.current.value).toBe('new value')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should reset to initial value', () => {
|
||||||
|
// const { result } = renderHook(() => useHookName({ initialValue: 'initial' }))
|
||||||
|
//
|
||||||
|
// act(() => {
|
||||||
|
// result.current.setValue('changed')
|
||||||
|
// })
|
||||||
|
// expect(result.current.value).toBe('changed')
|
||||||
|
//
|
||||||
|
// act(() => {
|
||||||
|
// result.current.reset()
|
||||||
|
// })
|
||||||
|
// expect(result.current.value).toBe('initial')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// Async Operations
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
describe('Async Operations', () => {
|
||||||
|
it('should fetch data on mount', async () => {
|
||||||
|
// mockedApi.fetchData.mockResolvedValue({ data: 'test' })
|
||||||
|
//
|
||||||
|
// const { result } = renderHook(() => useHookName())
|
||||||
|
//
|
||||||
|
// // Initially loading
|
||||||
|
// expect(result.current.isLoading).toBe(true)
|
||||||
|
//
|
||||||
|
// // Wait for data
|
||||||
|
// await waitFor(() => {
|
||||||
|
// expect(result.current.isLoading).toBe(false)
|
||||||
|
// })
|
||||||
|
//
|
||||||
|
// expect(result.current.data).toEqual({ data: 'test' })
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle fetch error', async () => {
|
||||||
|
// mockedApi.fetchData.mockRejectedValue(new Error('Network error'))
|
||||||
|
//
|
||||||
|
// const { result } = renderHook(() => useHookName())
|
||||||
|
//
|
||||||
|
// await waitFor(() => {
|
||||||
|
// expect(result.current.error).toBeTruthy()
|
||||||
|
// })
|
||||||
|
//
|
||||||
|
// expect(result.current.error?.message).toBe('Network error')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should refetch when dependency changes', async () => {
|
||||||
|
// mockedApi.fetchData.mockResolvedValue({ data: 'test' })
|
||||||
|
//
|
||||||
|
// const { result, rerender } = renderHook(
|
||||||
|
// ({ id }) => useHookName(id),
|
||||||
|
// { initialProps: { id: '1' } }
|
||||||
|
// )
|
||||||
|
//
|
||||||
|
// await waitFor(() => {
|
||||||
|
// expect(mockedApi.fetchData).toHaveBeenCalledWith('1')
|
||||||
|
// })
|
||||||
|
//
|
||||||
|
// rerender({ id: '2' })
|
||||||
|
//
|
||||||
|
// await waitFor(() => {
|
||||||
|
// expect(mockedApi.fetchData).toHaveBeenCalledWith('2')
|
||||||
|
// })
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// Side Effects
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
describe('Side Effects', () => {
|
||||||
|
it('should call callback when value changes', () => {
|
||||||
|
// const callback = vi.fn()
|
||||||
|
// const { result } = renderHook(() => useHookName({ onChange: callback }))
|
||||||
|
//
|
||||||
|
// act(() => {
|
||||||
|
// result.current.setValue('new value')
|
||||||
|
// })
|
||||||
|
//
|
||||||
|
// expect(callback).toHaveBeenCalledWith('new value')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should cleanup on unmount', () => {
|
||||||
|
// const cleanup = vi.fn()
|
||||||
|
// vi.spyOn(window, 'addEventListener')
|
||||||
|
// vi.spyOn(window, 'removeEventListener')
|
||||||
|
//
|
||||||
|
// const { unmount } = renderHook(() => useHookName())
|
||||||
|
//
|
||||||
|
// expect(window.addEventListener).toHaveBeenCalled()
|
||||||
|
//
|
||||||
|
// unmount()
|
||||||
|
//
|
||||||
|
// expect(window.removeEventListener).toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// Edge Cases
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
describe('Edge Cases', () => {
|
||||||
|
it('should handle null input', () => {
|
||||||
|
// const { result } = renderHook(() => useHookName(null))
|
||||||
|
//
|
||||||
|
// expect(result.current.value).toBeNull()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle rapid updates', () => {
|
||||||
|
// const { result } = renderHook(() => useHookName())
|
||||||
|
//
|
||||||
|
// act(() => {
|
||||||
|
// result.current.setValue('1')
|
||||||
|
// result.current.setValue('2')
|
||||||
|
// result.current.setValue('3')
|
||||||
|
// })
|
||||||
|
//
|
||||||
|
// expect(result.current.value).toBe('3')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// With Context (if hook uses context)
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
describe('With Context', () => {
|
||||||
|
it('should use context value', () => {
|
||||||
|
// const wrapper = createWrapper({ someValue: 'context-value' })
|
||||||
|
// const { result } = renderHook(() => useHookName(), { wrapper })
|
||||||
|
//
|
||||||
|
// expect(result.current.contextValue).toBe('context-value')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
@ -0,0 +1,154 @@
|
||||||
|
/**
|
||||||
|
* Test Template for Utility Functions
|
||||||
|
*
|
||||||
|
* Instructions:
|
||||||
|
* 1. Replace `utilityFunction` with your function name
|
||||||
|
* 2. Update import path
|
||||||
|
* 3. Use test.each for data-driven tests
|
||||||
|
*/
|
||||||
|
|
||||||
|
// import { utilityFunction } from './utility'
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Tests
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
describe('utilityFunction', () => {
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// Basic Functionality
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
describe('Basic Functionality', () => {
|
||||||
|
it('should return expected result for valid input', () => {
|
||||||
|
// expect(utilityFunction('input')).toBe('expected-output')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle multiple arguments', () => {
|
||||||
|
// expect(utilityFunction('a', 'b', 'c')).toBe('abc')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// Data-Driven Tests
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
describe('Input/Output Mapping', () => {
|
||||||
|
test.each([
|
||||||
|
// [input, expected]
|
||||||
|
['input1', 'output1'],
|
||||||
|
['input2', 'output2'],
|
||||||
|
['input3', 'output3'],
|
||||||
|
])('should return %s for input %s', (input, expected) => {
|
||||||
|
// expect(utilityFunction(input)).toBe(expected)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// Edge Cases
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
describe('Edge Cases', () => {
|
||||||
|
it('should handle empty string', () => {
|
||||||
|
// expect(utilityFunction('')).toBe('')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle null', () => {
|
||||||
|
// expect(utilityFunction(null)).toBe(null)
|
||||||
|
// or
|
||||||
|
// expect(() => utilityFunction(null)).toThrow()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle undefined', () => {
|
||||||
|
// expect(utilityFunction(undefined)).toBe(undefined)
|
||||||
|
// or
|
||||||
|
// expect(() => utilityFunction(undefined)).toThrow()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle empty array', () => {
|
||||||
|
// expect(utilityFunction([])).toEqual([])
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle empty object', () => {
|
||||||
|
// expect(utilityFunction({})).toEqual({})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// Boundary Conditions
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
describe('Boundary Conditions', () => {
|
||||||
|
it('should handle minimum value', () => {
|
||||||
|
// expect(utilityFunction(0)).toBe(0)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle maximum value', () => {
|
||||||
|
// expect(utilityFunction(Number.MAX_SAFE_INTEGER)).toBe(...)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle negative numbers', () => {
|
||||||
|
// expect(utilityFunction(-1)).toBe(...)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// Type Coercion (if applicable)
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
describe('Type Handling', () => {
|
||||||
|
it('should handle numeric string', () => {
|
||||||
|
// expect(utilityFunction('123')).toBe(123)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle boolean', () => {
|
||||||
|
// expect(utilityFunction(true)).toBe(...)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// Error Cases
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
describe('Error Handling', () => {
|
||||||
|
it('should throw for invalid input', () => {
|
||||||
|
// expect(() => utilityFunction('invalid')).toThrow('Error message')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should throw with specific error type', () => {
|
||||||
|
// expect(() => utilityFunction('invalid')).toThrow(ValidationError)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// Complex Objects (if applicable)
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
describe('Object Handling', () => {
|
||||||
|
it('should preserve object structure', () => {
|
||||||
|
// const input = { a: 1, b: 2 }
|
||||||
|
// expect(utilityFunction(input)).toEqual({ a: 1, b: 2 })
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle nested objects', () => {
|
||||||
|
// const input = { nested: { deep: 'value' } }
|
||||||
|
// expect(utilityFunction(input)).toEqual({ nested: { deep: 'transformed' } })
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should not mutate input', () => {
|
||||||
|
// const input = { a: 1 }
|
||||||
|
// const inputCopy = { ...input }
|
||||||
|
// utilityFunction(input)
|
||||||
|
// expect(input).toEqual(inputCopy)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// Array Handling (if applicable)
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
describe('Array Handling', () => {
|
||||||
|
it('should process all elements', () => {
|
||||||
|
// expect(utilityFunction([1, 2, 3])).toEqual([2, 4, 6])
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle single element array', () => {
|
||||||
|
// expect(utilityFunction([1])).toEqual([2])
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should preserve order', () => {
|
||||||
|
// expect(utilityFunction(['c', 'a', 'b'])).toEqual(['c', 'a', 'b'])
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
@ -0,0 +1,345 @@
|
||||||
|
# Async Testing Guide
|
||||||
|
|
||||||
|
## Core Async Patterns
|
||||||
|
|
||||||
|
### 1. waitFor - Wait for Condition
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
import { render, screen, waitFor } from '@testing-library/react'
|
||||||
|
|
||||||
|
it('should load and display data', async () => {
|
||||||
|
render(<DataComponent />)
|
||||||
|
|
||||||
|
// Wait for element to appear
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByText('Loaded Data')).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should hide loading spinner after load', async () => {
|
||||||
|
render(<DataComponent />)
|
||||||
|
|
||||||
|
// Wait for element to disappear
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.queryByText('Loading...')).not.toBeInTheDocument()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. findBy\* - Async Queries
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
it('should show user name after fetch', async () => {
|
||||||
|
render(<UserProfile />)
|
||||||
|
|
||||||
|
// findBy returns a promise, auto-waits up to 1000ms
|
||||||
|
const userName = await screen.findByText('John Doe')
|
||||||
|
expect(userName).toBeInTheDocument()
|
||||||
|
|
||||||
|
// findByRole with options
|
||||||
|
const button = await screen.findByRole('button', { name: /submit/i })
|
||||||
|
expect(button).toBeEnabled()
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. userEvent for Async Interactions
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
import userEvent from '@testing-library/user-event'
|
||||||
|
|
||||||
|
it('should submit form', async () => {
|
||||||
|
const user = userEvent.setup()
|
||||||
|
const onSubmit = vi.fn()
|
||||||
|
|
||||||
|
render(<Form onSubmit={onSubmit} />)
|
||||||
|
|
||||||
|
// userEvent methods are async
|
||||||
|
await user.type(screen.getByLabelText('Email'), 'test@example.com')
|
||||||
|
await user.click(screen.getByRole('button', { name: /submit/i }))
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(onSubmit).toHaveBeenCalledWith({ email: 'test@example.com' })
|
||||||
|
})
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
## Fake Timers
|
||||||
|
|
||||||
|
### When to Use Fake Timers
|
||||||
|
|
||||||
|
- Testing components with `setTimeout`/`setInterval`
|
||||||
|
- Testing debounce/throttle behavior
|
||||||
|
- Testing animations or delayed transitions
|
||||||
|
- Testing polling or retry logic
|
||||||
|
|
||||||
|
### Basic Fake Timer Setup
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
describe('Debounced Search', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
vi.useFakeTimers()
|
||||||
|
})
|
||||||
|
|
||||||
|
afterEach(() => {
|
||||||
|
vi.useRealTimers()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should debounce search input', async () => {
|
||||||
|
const onSearch = vi.fn()
|
||||||
|
render(<SearchInput onSearch={onSearch} debounceMs={300} />)
|
||||||
|
|
||||||
|
// Type in the input
|
||||||
|
fireEvent.change(screen.getByRole('textbox'), { target: { value: 'query' } })
|
||||||
|
|
||||||
|
// Search not called immediately
|
||||||
|
expect(onSearch).not.toHaveBeenCalled()
|
||||||
|
|
||||||
|
// Advance timers
|
||||||
|
vi.advanceTimersByTime(300)
|
||||||
|
|
||||||
|
// Now search is called
|
||||||
|
expect(onSearch).toHaveBeenCalledWith('query')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Fake Timers with Async Code
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
it('should retry on failure', async () => {
|
||||||
|
vi.useFakeTimers()
|
||||||
|
const fetchData = vi.fn()
|
||||||
|
.mockRejectedValueOnce(new Error('Network error'))
|
||||||
|
.mockResolvedValueOnce({ data: 'success' })
|
||||||
|
|
||||||
|
render(<RetryComponent fetchData={fetchData} retryDelayMs={1000} />)
|
||||||
|
|
||||||
|
// First call fails
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(fetchData).toHaveBeenCalledTimes(1)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Advance timer for retry
|
||||||
|
vi.advanceTimersByTime(1000)
|
||||||
|
|
||||||
|
// Second call succeeds
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(fetchData).toHaveBeenCalledTimes(2)
|
||||||
|
expect(screen.getByText('success')).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
|
||||||
|
vi.useRealTimers()
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Common Fake Timer Utilities
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// Run all pending timers
|
||||||
|
vi.runAllTimers()
|
||||||
|
|
||||||
|
// Run only pending timers (not new ones created during execution)
|
||||||
|
vi.runOnlyPendingTimers()
|
||||||
|
|
||||||
|
// Advance by specific time
|
||||||
|
vi.advanceTimersByTime(1000)
|
||||||
|
|
||||||
|
// Get current fake time
|
||||||
|
Date.now()
|
||||||
|
|
||||||
|
// Clear all timers
|
||||||
|
vi.clearAllTimers()
|
||||||
|
```
|
||||||
|
|
||||||
|
## API Testing Patterns
|
||||||
|
|
||||||
|
### Loading → Success → Error States
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
describe('DataFetcher', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
vi.clearAllMocks()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should show loading state', () => {
|
||||||
|
mockedApi.fetchData.mockImplementation(() => new Promise(() => {})) // Never resolves
|
||||||
|
|
||||||
|
render(<DataFetcher />)
|
||||||
|
|
||||||
|
expect(screen.getByTestId('loading-spinner')).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should show data on success', async () => {
|
||||||
|
mockedApi.fetchData.mockResolvedValue({ items: ['Item 1', 'Item 2'] })
|
||||||
|
|
||||||
|
render(<DataFetcher />)
|
||||||
|
|
||||||
|
// Use findBy* for multiple async elements (better error messages than waitFor with multiple assertions)
|
||||||
|
const item1 = await screen.findByText('Item 1')
|
||||||
|
const item2 = await screen.findByText('Item 2')
|
||||||
|
expect(item1).toBeInTheDocument()
|
||||||
|
expect(item2).toBeInTheDocument()
|
||||||
|
|
||||||
|
expect(screen.queryByTestId('loading-spinner')).not.toBeInTheDocument()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should show error on failure', async () => {
|
||||||
|
mockedApi.fetchData.mockRejectedValue(new Error('Failed to fetch'))
|
||||||
|
|
||||||
|
render(<DataFetcher />)
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByText(/failed to fetch/i)).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should retry on error', async () => {
|
||||||
|
mockedApi.fetchData.mockRejectedValue(new Error('Network error'))
|
||||||
|
|
||||||
|
render(<DataFetcher />)
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByRole('button', { name: /retry/i })).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
|
||||||
|
mockedApi.fetchData.mockResolvedValue({ items: ['Item 1'] })
|
||||||
|
fireEvent.click(screen.getByRole('button', { name: /retry/i }))
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByText('Item 1')).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Testing Mutations
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
it('should submit form and show success', async () => {
|
||||||
|
const user = userEvent.setup()
|
||||||
|
mockedApi.createItem.mockResolvedValue({ id: '1', name: 'New Item' })
|
||||||
|
|
||||||
|
render(<CreateItemForm />)
|
||||||
|
|
||||||
|
await user.type(screen.getByLabelText('Name'), 'New Item')
|
||||||
|
await user.click(screen.getByRole('button', { name: /create/i }))
|
||||||
|
|
||||||
|
// Button should be disabled during submission
|
||||||
|
expect(screen.getByRole('button', { name: /creating/i })).toBeDisabled()
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByText(/created successfully/i)).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(mockedApi.createItem).toHaveBeenCalledWith({ name: 'New Item' })
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
## useEffect Testing
|
||||||
|
|
||||||
|
### Testing Effect Execution
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
it('should fetch data on mount', async () => {
|
||||||
|
const fetchData = vi.fn().mockResolvedValue({ data: 'test' })
|
||||||
|
|
||||||
|
render(<ComponentWithEffect fetchData={fetchData} />)
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(fetchData).toHaveBeenCalledTimes(1)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Testing Effect Dependencies
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
it('should refetch when id changes', async () => {
|
||||||
|
const fetchData = vi.fn().mockResolvedValue({ data: 'test' })
|
||||||
|
|
||||||
|
const { rerender } = render(<ComponentWithEffect id="1" fetchData={fetchData} />)
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(fetchData).toHaveBeenCalledWith('1')
|
||||||
|
})
|
||||||
|
|
||||||
|
rerender(<ComponentWithEffect id="2" fetchData={fetchData} />)
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(fetchData).toHaveBeenCalledWith('2')
|
||||||
|
expect(fetchData).toHaveBeenCalledTimes(2)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Testing Effect Cleanup
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
it('should cleanup subscription on unmount', () => {
|
||||||
|
const subscribe = vi.fn()
|
||||||
|
const unsubscribe = vi.fn()
|
||||||
|
subscribe.mockReturnValue(unsubscribe)
|
||||||
|
|
||||||
|
const { unmount } = render(<SubscriptionComponent subscribe={subscribe} />)
|
||||||
|
|
||||||
|
expect(subscribe).toHaveBeenCalledTimes(1)
|
||||||
|
|
||||||
|
unmount()
|
||||||
|
|
||||||
|
expect(unsubscribe).toHaveBeenCalledTimes(1)
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
## Common Async Pitfalls
|
||||||
|
|
||||||
|
### ❌ Don't: Forget to await
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// Bad - test may pass even if assertion fails
|
||||||
|
it('should load data', () => {
|
||||||
|
render(<Component />)
|
||||||
|
waitFor(() => {
|
||||||
|
expect(screen.getByText('Data')).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// Good - properly awaited
|
||||||
|
it('should load data', async () => {
|
||||||
|
render(<Component />)
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByText('Data')).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### ❌ Don't: Use multiple assertions in single waitFor
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// Bad - if first assertion fails, won't know about second
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByText('Title')).toBeInTheDocument()
|
||||||
|
expect(screen.getByText('Description')).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
|
||||||
|
// Good - separate waitFor or use findBy
|
||||||
|
const title = await screen.findByText('Title')
|
||||||
|
const description = await screen.findByText('Description')
|
||||||
|
expect(title).toBeInTheDocument()
|
||||||
|
expect(description).toBeInTheDocument()
|
||||||
|
```
|
||||||
|
|
||||||
|
### ❌ Don't: Mix fake timers with real async
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// Bad - fake timers don't work well with real Promises
|
||||||
|
vi.useFakeTimers()
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByText('Data')).toBeInTheDocument()
|
||||||
|
}) // May timeout!
|
||||||
|
|
||||||
|
// Good - use runAllTimers or advanceTimersByTime
|
||||||
|
vi.useFakeTimers()
|
||||||
|
render(<Component />)
|
||||||
|
vi.runAllTimers()
|
||||||
|
expect(screen.getByText('Data')).toBeInTheDocument()
|
||||||
|
```
|
||||||
|
|
@ -0,0 +1,205 @@
|
||||||
|
# Test Generation Checklist
|
||||||
|
|
||||||
|
Use this checklist when generating or reviewing tests for Dify frontend components.
|
||||||
|
|
||||||
|
## Pre-Generation
|
||||||
|
|
||||||
|
- [ ] Read the component source code completely
|
||||||
|
- [ ] Identify component type (component, hook, utility, page)
|
||||||
|
- [ ] Run `pnpm analyze-component <path>` if available
|
||||||
|
- [ ] Note complexity score and features detected
|
||||||
|
- [ ] Check for existing tests in the same directory
|
||||||
|
- [ ] **Identify ALL files in the directory** that need testing (not just index)
|
||||||
|
|
||||||
|
## Testing Strategy
|
||||||
|
|
||||||
|
### ⚠️ Incremental Workflow (CRITICAL for Multi-File)
|
||||||
|
|
||||||
|
- [ ] **NEVER generate all tests at once** - process one file at a time
|
||||||
|
- [ ] Order files by complexity: utilities → hooks → simple → complex → integration
|
||||||
|
- [ ] Create a todo list to track progress before starting
|
||||||
|
- [ ] For EACH file: write → run test → verify pass → then next
|
||||||
|
- [ ] **DO NOT proceed** to next file until current one passes
|
||||||
|
|
||||||
|
### Path-Level Coverage
|
||||||
|
|
||||||
|
- [ ] **Test ALL files** in the assigned directory/path
|
||||||
|
- [ ] List all components, hooks, utilities that need coverage
|
||||||
|
- [ ] Decide: single spec file (integration) or multiple spec files (unit)
|
||||||
|
|
||||||
|
### Complexity Assessment
|
||||||
|
|
||||||
|
- [ ] Run `pnpm analyze-component <path>` for complexity score
|
||||||
|
- [ ] **Complexity > 50**: Consider refactoring before testing
|
||||||
|
- [ ] **500+ lines**: Consider splitting before testing
|
||||||
|
- [ ] **30-50 complexity**: Use multiple describe blocks, organized structure
|
||||||
|
|
||||||
|
### Integration vs Mocking
|
||||||
|
|
||||||
|
- [ ] **DO NOT mock base components** (`Loading`, `Button`, `Tooltip`, etc.)
|
||||||
|
- [ ] Import real project components instead of mocking
|
||||||
|
- [ ] Only mock: API calls, complex context providers, third-party libs with side effects
|
||||||
|
- [ ] Prefer integration testing when using single spec file
|
||||||
|
|
||||||
|
## Required Test Sections
|
||||||
|
|
||||||
|
### All Components MUST Have
|
||||||
|
|
||||||
|
- [ ] **Rendering tests** - Component renders without crashing
|
||||||
|
- [ ] **Props tests** - Required props, optional props, default values
|
||||||
|
- [ ] **Edge cases** - null, undefined, empty values, boundaries
|
||||||
|
|
||||||
|
### Conditional Sections (Add When Feature Present)
|
||||||
|
|
||||||
|
| Feature | Add Tests For |
|
||||||
|
|---------|---------------|
|
||||||
|
| `useState` | Initial state, transitions, cleanup |
|
||||||
|
| `useEffect` | Execution, dependencies, cleanup |
|
||||||
|
| Event handlers | onClick, onChange, onSubmit, keyboard |
|
||||||
|
| API calls | Loading, success, error states |
|
||||||
|
| Routing | Navigation, params, query strings |
|
||||||
|
| `useCallback`/`useMemo` | Referential equality |
|
||||||
|
| Context | Provider values, consumer behavior |
|
||||||
|
| Forms | Validation, submission, error display |
|
||||||
|
|
||||||
|
## Code Quality Checklist
|
||||||
|
|
||||||
|
### Structure
|
||||||
|
|
||||||
|
- [ ] Uses `describe` blocks to group related tests
|
||||||
|
- [ ] Test names follow `should <behavior> when <condition>` pattern
|
||||||
|
- [ ] AAA pattern (Arrange-Act-Assert) is clear
|
||||||
|
- [ ] Comments explain complex test scenarios
|
||||||
|
|
||||||
|
### Mocks
|
||||||
|
|
||||||
|
- [ ] **DO NOT mock base components** (`@/app/components/base/*`)
|
||||||
|
- [ ] `vi.clearAllMocks()` in `beforeEach` (not `afterEach`)
|
||||||
|
- [ ] Shared mock state reset in `beforeEach`
|
||||||
|
- [ ] i18n uses global mock (auto-loaded in `web/vitest.setup.ts`); only override locally for custom translations
|
||||||
|
- [ ] Router mocks match actual Next.js API
|
||||||
|
- [ ] Mocks reflect actual component conditional behavior
|
||||||
|
- [ ] Only mock: API services, complex context providers, third-party libs
|
||||||
|
|
||||||
|
### Queries
|
||||||
|
|
||||||
|
- [ ] Prefer semantic queries (`getByRole`, `getByLabelText`)
|
||||||
|
- [ ] Use `queryBy*` for absence assertions
|
||||||
|
- [ ] Use `findBy*` for async elements
|
||||||
|
- [ ] `getByTestId` only as last resort
|
||||||
|
|
||||||
|
### Async
|
||||||
|
|
||||||
|
- [ ] All async tests use `async/await`
|
||||||
|
- [ ] `waitFor` wraps async assertions
|
||||||
|
- [ ] Fake timers properly setup/teardown
|
||||||
|
- [ ] No floating promises
|
||||||
|
|
||||||
|
### TypeScript
|
||||||
|
|
||||||
|
- [ ] No `any` types without justification
|
||||||
|
- [ ] Mock data uses actual types from source
|
||||||
|
- [ ] Factory functions have proper return types
|
||||||
|
|
||||||
|
## Coverage Goals (Per File)
|
||||||
|
|
||||||
|
For the current file being tested:
|
||||||
|
|
||||||
|
- [ ] 100% function coverage
|
||||||
|
- [ ] 100% statement coverage
|
||||||
|
- [ ] >95% branch coverage
|
||||||
|
- [ ] >95% line coverage
|
||||||
|
|
||||||
|
## Post-Generation (Per File)
|
||||||
|
|
||||||
|
**Run these checks after EACH test file, not just at the end:**
|
||||||
|
|
||||||
|
- [ ] Run `pnpm test path/to/file.spec.tsx` - **MUST PASS before next file**
|
||||||
|
- [ ] Fix any failures immediately
|
||||||
|
- [ ] Mark file as complete in todo list
|
||||||
|
- [ ] Only then proceed to next file
|
||||||
|
|
||||||
|
### After All Files Complete
|
||||||
|
|
||||||
|
- [ ] Run full directory test: `pnpm test path/to/directory/`
|
||||||
|
- [ ] Check coverage report: `pnpm test:coverage`
|
||||||
|
- [ ] Run `pnpm lint:fix` on all test files
|
||||||
|
- [ ] Run `pnpm type-check:tsgo`
|
||||||
|
|
||||||
|
## Common Issues to Watch
|
||||||
|
|
||||||
|
### False Positives
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// ❌ Mock doesn't match actual behavior
|
||||||
|
vi.mock('./Component', () => () => <div>Mocked</div>)
|
||||||
|
|
||||||
|
// ✅ Mock matches actual conditional logic
|
||||||
|
vi.mock('./Component', () => ({ isOpen }: any) =>
|
||||||
|
isOpen ? <div>Content</div> : null
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### State Leakage
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// ❌ Shared state not reset
|
||||||
|
let mockState = false
|
||||||
|
vi.mock('./useHook', () => () => mockState)
|
||||||
|
|
||||||
|
// ✅ Reset in beforeEach
|
||||||
|
beforeEach(() => {
|
||||||
|
mockState = false
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Async Race Conditions
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// ❌ Not awaited
|
||||||
|
it('loads data', () => {
|
||||||
|
render(<Component />)
|
||||||
|
expect(screen.getByText('Data')).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
|
||||||
|
// ✅ Properly awaited
|
||||||
|
it('loads data', async () => {
|
||||||
|
render(<Component />)
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByText('Data')).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Missing Edge Cases
|
||||||
|
|
||||||
|
Always test these scenarios:
|
||||||
|
|
||||||
|
- `null` / `undefined` inputs
|
||||||
|
- Empty strings / arrays / objects
|
||||||
|
- Boundary values (0, -1, MAX_INT)
|
||||||
|
- Error states
|
||||||
|
- Loading states
|
||||||
|
- Disabled states
|
||||||
|
|
||||||
|
## Quick Commands
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run specific test
|
||||||
|
pnpm test path/to/file.spec.tsx
|
||||||
|
|
||||||
|
# Run with coverage
|
||||||
|
pnpm test:coverage path/to/file.spec.tsx
|
||||||
|
|
||||||
|
# Watch mode
|
||||||
|
pnpm test:watch path/to/file.spec.tsx
|
||||||
|
|
||||||
|
# Update snapshots (use sparingly)
|
||||||
|
pnpm test -u path/to/file.spec.tsx
|
||||||
|
|
||||||
|
# Analyze component
|
||||||
|
pnpm analyze-component path/to/component.tsx
|
||||||
|
|
||||||
|
# Review existing test
|
||||||
|
pnpm analyze-component path/to/component.tsx --review
|
||||||
|
```
|
||||||
|
|
@ -0,0 +1,449 @@
|
||||||
|
# Common Testing Patterns
|
||||||
|
|
||||||
|
## Query Priority
|
||||||
|
|
||||||
|
Use queries in this order (most to least preferred):
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// 1. getByRole - Most recommended (accessibility)
|
||||||
|
screen.getByRole('button', { name: /submit/i })
|
||||||
|
screen.getByRole('textbox', { name: /email/i })
|
||||||
|
screen.getByRole('heading', { level: 1 })
|
||||||
|
|
||||||
|
// 2. getByLabelText - Form fields
|
||||||
|
screen.getByLabelText('Email address')
|
||||||
|
screen.getByLabelText(/password/i)
|
||||||
|
|
||||||
|
// 3. getByPlaceholderText - When no label
|
||||||
|
screen.getByPlaceholderText('Search...')
|
||||||
|
|
||||||
|
// 4. getByText - Non-interactive elements
|
||||||
|
screen.getByText('Welcome to Dify')
|
||||||
|
screen.getByText(/loading/i)
|
||||||
|
|
||||||
|
// 5. getByDisplayValue - Current input value
|
||||||
|
screen.getByDisplayValue('current value')
|
||||||
|
|
||||||
|
// 6. getByAltText - Images
|
||||||
|
screen.getByAltText('Company logo')
|
||||||
|
|
||||||
|
// 7. getByTitle - Tooltip elements
|
||||||
|
screen.getByTitle('Close')
|
||||||
|
|
||||||
|
// 8. getByTestId - Last resort only!
|
||||||
|
screen.getByTestId('custom-element')
|
||||||
|
```
|
||||||
|
|
||||||
|
## Event Handling Patterns
|
||||||
|
|
||||||
|
### Click Events
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// Basic click
|
||||||
|
fireEvent.click(screen.getByRole('button'))
|
||||||
|
|
||||||
|
// With userEvent (preferred for realistic interaction)
|
||||||
|
const user = userEvent.setup()
|
||||||
|
await user.click(screen.getByRole('button'))
|
||||||
|
|
||||||
|
// Double click
|
||||||
|
await user.dblClick(screen.getByRole('button'))
|
||||||
|
|
||||||
|
// Right click
|
||||||
|
await user.pointer({ keys: '[MouseRight]', target: screen.getByRole('button') })
|
||||||
|
```
|
||||||
|
|
||||||
|
### Form Input
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
const user = userEvent.setup()
|
||||||
|
|
||||||
|
// Type in input
|
||||||
|
await user.type(screen.getByRole('textbox'), 'Hello World')
|
||||||
|
|
||||||
|
// Clear and type
|
||||||
|
await user.clear(screen.getByRole('textbox'))
|
||||||
|
await user.type(screen.getByRole('textbox'), 'New value')
|
||||||
|
|
||||||
|
// Select option
|
||||||
|
await user.selectOptions(screen.getByRole('combobox'), 'option-value')
|
||||||
|
|
||||||
|
// Check checkbox
|
||||||
|
await user.click(screen.getByRole('checkbox'))
|
||||||
|
|
||||||
|
// Upload file
|
||||||
|
const file = new File(['content'], 'test.pdf', { type: 'application/pdf' })
|
||||||
|
await user.upload(screen.getByLabelText(/upload/i), file)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Keyboard Events
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
const user = userEvent.setup()
|
||||||
|
|
||||||
|
// Press Enter
|
||||||
|
await user.keyboard('{Enter}')
|
||||||
|
|
||||||
|
// Press Escape
|
||||||
|
await user.keyboard('{Escape}')
|
||||||
|
|
||||||
|
// Keyboard shortcut
|
||||||
|
await user.keyboard('{Control>}a{/Control}') // Ctrl+A
|
||||||
|
|
||||||
|
// Tab navigation
|
||||||
|
await user.tab()
|
||||||
|
|
||||||
|
// Arrow keys
|
||||||
|
await user.keyboard('{ArrowDown}')
|
||||||
|
await user.keyboard('{ArrowUp}')
|
||||||
|
```
|
||||||
|
|
||||||
|
## Component State Testing
|
||||||
|
|
||||||
|
### Testing State Transitions
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
describe('Counter', () => {
|
||||||
|
it('should increment count', async () => {
|
||||||
|
const user = userEvent.setup()
|
||||||
|
render(<Counter initialCount={0} />)
|
||||||
|
|
||||||
|
// Initial state
|
||||||
|
expect(screen.getByText('Count: 0')).toBeInTheDocument()
|
||||||
|
|
||||||
|
// Trigger transition
|
||||||
|
await user.click(screen.getByRole('button', { name: /increment/i }))
|
||||||
|
|
||||||
|
// New state
|
||||||
|
expect(screen.getByText('Count: 1')).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Testing Controlled Components
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
describe('ControlledInput', () => {
|
||||||
|
it('should call onChange with new value', async () => {
|
||||||
|
const user = userEvent.setup()
|
||||||
|
const handleChange = vi.fn()
|
||||||
|
|
||||||
|
render(<ControlledInput value="" onChange={handleChange} />)
|
||||||
|
|
||||||
|
await user.type(screen.getByRole('textbox'), 'a')
|
||||||
|
|
||||||
|
expect(handleChange).toHaveBeenCalledWith('a')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should display controlled value', () => {
|
||||||
|
render(<ControlledInput value="controlled" onChange={vi.fn()} />)
|
||||||
|
|
||||||
|
expect(screen.getByRole('textbox')).toHaveValue('controlled')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
## Conditional Rendering Testing
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
describe('ConditionalComponent', () => {
|
||||||
|
it('should show loading state', () => {
|
||||||
|
render(<DataDisplay isLoading={true} data={null} />)
|
||||||
|
|
||||||
|
expect(screen.getByText(/loading/i)).toBeInTheDocument()
|
||||||
|
expect(screen.queryByTestId('data-content')).not.toBeInTheDocument()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should show error state', () => {
|
||||||
|
render(<DataDisplay isLoading={false} data={null} error="Failed to load" />)
|
||||||
|
|
||||||
|
expect(screen.getByText(/failed to load/i)).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should show data when loaded', () => {
|
||||||
|
render(<DataDisplay isLoading={false} data={{ name: 'Test' }} />)
|
||||||
|
|
||||||
|
expect(screen.getByText('Test')).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should show empty state when no data', () => {
|
||||||
|
render(<DataDisplay isLoading={false} data={[]} />)
|
||||||
|
|
||||||
|
expect(screen.getByText(/no data/i)).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
## List Rendering Testing
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
describe('ItemList', () => {
|
||||||
|
const items = [
|
||||||
|
{ id: '1', name: 'Item 1' },
|
||||||
|
{ id: '2', name: 'Item 2' },
|
||||||
|
{ id: '3', name: 'Item 3' },
|
||||||
|
]
|
||||||
|
|
||||||
|
it('should render all items', () => {
|
||||||
|
render(<ItemList items={items} />)
|
||||||
|
|
||||||
|
expect(screen.getAllByRole('listitem')).toHaveLength(3)
|
||||||
|
items.forEach(item => {
|
||||||
|
expect(screen.getByText(item.name)).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle item selection', async () => {
|
||||||
|
const user = userEvent.setup()
|
||||||
|
const onSelect = vi.fn()
|
||||||
|
|
||||||
|
render(<ItemList items={items} onSelect={onSelect} />)
|
||||||
|
|
||||||
|
await user.click(screen.getByText('Item 2'))
|
||||||
|
|
||||||
|
expect(onSelect).toHaveBeenCalledWith(items[1])
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle empty list', () => {
|
||||||
|
render(<ItemList items={[]} />)
|
||||||
|
|
||||||
|
expect(screen.getByText(/no items/i)).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
## Modal/Dialog Testing
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
describe('Modal', () => {
|
||||||
|
it('should not render when closed', () => {
|
||||||
|
render(<Modal isOpen={false} onClose={vi.fn()} />)
|
||||||
|
|
||||||
|
expect(screen.queryByRole('dialog')).not.toBeInTheDocument()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should render when open', () => {
|
||||||
|
render(<Modal isOpen={true} onClose={vi.fn()} />)
|
||||||
|
|
||||||
|
expect(screen.getByRole('dialog')).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should call onClose when clicking overlay', async () => {
|
||||||
|
const user = userEvent.setup()
|
||||||
|
const handleClose = vi.fn()
|
||||||
|
|
||||||
|
render(<Modal isOpen={true} onClose={handleClose} />)
|
||||||
|
|
||||||
|
await user.click(screen.getByTestId('modal-overlay'))
|
||||||
|
|
||||||
|
expect(handleClose).toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should call onClose when pressing Escape', async () => {
|
||||||
|
const user = userEvent.setup()
|
||||||
|
const handleClose = vi.fn()
|
||||||
|
|
||||||
|
render(<Modal isOpen={true} onClose={handleClose} />)
|
||||||
|
|
||||||
|
await user.keyboard('{Escape}')
|
||||||
|
|
||||||
|
expect(handleClose).toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should trap focus inside modal', async () => {
|
||||||
|
const user = userEvent.setup()
|
||||||
|
|
||||||
|
render(
|
||||||
|
<Modal isOpen={true} onClose={vi.fn()}>
|
||||||
|
<button>First</button>
|
||||||
|
<button>Second</button>
|
||||||
|
</Modal>
|
||||||
|
)
|
||||||
|
|
||||||
|
// Focus should cycle within modal
|
||||||
|
await user.tab()
|
||||||
|
expect(screen.getByText('First')).toHaveFocus()
|
||||||
|
|
||||||
|
await user.tab()
|
||||||
|
expect(screen.getByText('Second')).toHaveFocus()
|
||||||
|
|
||||||
|
await user.tab()
|
||||||
|
expect(screen.getByText('First')).toHaveFocus() // Cycles back
|
||||||
|
})
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
## Form Testing
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
describe('LoginForm', () => {
|
||||||
|
it('should submit valid form', async () => {
|
||||||
|
const user = userEvent.setup()
|
||||||
|
const onSubmit = vi.fn()
|
||||||
|
|
||||||
|
render(<LoginForm onSubmit={onSubmit} />)
|
||||||
|
|
||||||
|
await user.type(screen.getByLabelText(/email/i), 'test@example.com')
|
||||||
|
await user.type(screen.getByLabelText(/password/i), 'password123')
|
||||||
|
await user.click(screen.getByRole('button', { name: /sign in/i }))
|
||||||
|
|
||||||
|
expect(onSubmit).toHaveBeenCalledWith({
|
||||||
|
email: 'test@example.com',
|
||||||
|
password: 'password123',
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should show validation errors', async () => {
|
||||||
|
const user = userEvent.setup()
|
||||||
|
|
||||||
|
render(<LoginForm onSubmit={vi.fn()} />)
|
||||||
|
|
||||||
|
// Submit empty form
|
||||||
|
await user.click(screen.getByRole('button', { name: /sign in/i }))
|
||||||
|
|
||||||
|
expect(screen.getByText(/email is required/i)).toBeInTheDocument()
|
||||||
|
expect(screen.getByText(/password is required/i)).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should validate email format', async () => {
|
||||||
|
const user = userEvent.setup()
|
||||||
|
|
||||||
|
render(<LoginForm onSubmit={vi.fn()} />)
|
||||||
|
|
||||||
|
await user.type(screen.getByLabelText(/email/i), 'invalid-email')
|
||||||
|
await user.click(screen.getByRole('button', { name: /sign in/i }))
|
||||||
|
|
||||||
|
expect(screen.getByText(/invalid email/i)).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should disable submit button while submitting', async () => {
|
||||||
|
const user = userEvent.setup()
|
||||||
|
const onSubmit = vi.fn(() => new Promise(resolve => setTimeout(resolve, 100)))
|
||||||
|
|
||||||
|
render(<LoginForm onSubmit={onSubmit} />)
|
||||||
|
|
||||||
|
await user.type(screen.getByLabelText(/email/i), 'test@example.com')
|
||||||
|
await user.type(screen.getByLabelText(/password/i), 'password123')
|
||||||
|
await user.click(screen.getByRole('button', { name: /sign in/i }))
|
||||||
|
|
||||||
|
expect(screen.getByRole('button', { name: /signing in/i })).toBeDisabled()
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByRole('button', { name: /sign in/i })).toBeEnabled()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
## Data-Driven Tests with test.each
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
describe('StatusBadge', () => {
|
||||||
|
test.each([
|
||||||
|
['success', 'bg-green-500'],
|
||||||
|
['warning', 'bg-yellow-500'],
|
||||||
|
['error', 'bg-red-500'],
|
||||||
|
['info', 'bg-blue-500'],
|
||||||
|
])('should apply correct class for %s status', (status, expectedClass) => {
|
||||||
|
render(<StatusBadge status={status} />)
|
||||||
|
|
||||||
|
expect(screen.getByTestId('status-badge')).toHaveClass(expectedClass)
|
||||||
|
})
|
||||||
|
|
||||||
|
test.each([
|
||||||
|
{ input: null, expected: 'Unknown' },
|
||||||
|
{ input: undefined, expected: 'Unknown' },
|
||||||
|
{ input: '', expected: 'Unknown' },
|
||||||
|
{ input: 'invalid', expected: 'Unknown' },
|
||||||
|
])('should show "Unknown" for invalid input: $input', ({ input, expected }) => {
|
||||||
|
render(<StatusBadge status={input} />)
|
||||||
|
|
||||||
|
expect(screen.getByText(expected)).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
## Debugging Tips
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// Print entire DOM
|
||||||
|
screen.debug()
|
||||||
|
|
||||||
|
// Print specific element
|
||||||
|
screen.debug(screen.getByRole('button'))
|
||||||
|
|
||||||
|
// Log testing playground URL
|
||||||
|
screen.logTestingPlaygroundURL()
|
||||||
|
|
||||||
|
// Pretty print DOM
|
||||||
|
import { prettyDOM } from '@testing-library/react'
|
||||||
|
console.log(prettyDOM(screen.getByRole('dialog')))
|
||||||
|
|
||||||
|
// Check available roles
|
||||||
|
import { getRoles } from '@testing-library/react'
|
||||||
|
console.log(getRoles(container))
|
||||||
|
```
|
||||||
|
|
||||||
|
## Common Mistakes to Avoid
|
||||||
|
|
||||||
|
### ❌ Don't Use Implementation Details
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// Bad - testing implementation
|
||||||
|
expect(component.state.isOpen).toBe(true)
|
||||||
|
expect(wrapper.find('.internal-class').length).toBe(1)
|
||||||
|
|
||||||
|
// Good - testing behavior
|
||||||
|
expect(screen.getByRole('dialog')).toBeInTheDocument()
|
||||||
|
```
|
||||||
|
|
||||||
|
### ❌ Don't Forget Cleanup
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// Bad - may leak state between tests
|
||||||
|
it('test 1', () => {
|
||||||
|
render(<Component />)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Good - cleanup is automatic with RTL, but reset mocks
|
||||||
|
beforeEach(() => {
|
||||||
|
vi.clearAllMocks()
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### ❌ Don't Use Exact String Matching (Prefer Black-Box Assertions)
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// ❌ Bad - hardcoded strings are brittle
|
||||||
|
expect(screen.getByText('Submit Form')).toBeInTheDocument()
|
||||||
|
expect(screen.getByText('Loading...')).toBeInTheDocument()
|
||||||
|
|
||||||
|
// ✅ Good - role-based queries (most semantic)
|
||||||
|
expect(screen.getByRole('button', { name: /submit/i })).toBeInTheDocument()
|
||||||
|
expect(screen.getByRole('status')).toBeInTheDocument()
|
||||||
|
|
||||||
|
// ✅ Good - pattern matching (flexible)
|
||||||
|
expect(screen.getByText(/submit/i)).toBeInTheDocument()
|
||||||
|
expect(screen.getByText(/loading/i)).toBeInTheDocument()
|
||||||
|
|
||||||
|
// ✅ Good - test behavior, not exact UI text
|
||||||
|
expect(screen.getByRole('button')).toBeDisabled()
|
||||||
|
expect(screen.getByRole('alert')).toBeInTheDocument()
|
||||||
|
```
|
||||||
|
|
||||||
|
**Why prefer black-box assertions?**
|
||||||
|
|
||||||
|
- Text content may change (i18n, copy updates)
|
||||||
|
- Role-based queries test accessibility
|
||||||
|
- Pattern matching is resilient to minor changes
|
||||||
|
- Tests focus on behavior, not implementation details
|
||||||
|
|
||||||
|
### ❌ Don't Assert on Absence Without Query
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// Bad - throws if not found
|
||||||
|
expect(screen.getByText('Error')).not.toBeInTheDocument() // Error!
|
||||||
|
|
||||||
|
// Good - use queryBy for absence assertions
|
||||||
|
expect(screen.queryByText('Error')).not.toBeInTheDocument()
|
||||||
|
```
|
||||||
|
|
@ -0,0 +1,523 @@
|
||||||
|
# Domain-Specific Component Testing
|
||||||
|
|
||||||
|
This guide covers testing patterns for Dify's domain-specific components.
|
||||||
|
|
||||||
|
## Workflow Components (`workflow/`)
|
||||||
|
|
||||||
|
Workflow components handle node configuration, data flow, and graph operations.
|
||||||
|
|
||||||
|
### Key Test Areas
|
||||||
|
|
||||||
|
1. **Node Configuration**
|
||||||
|
1. **Data Validation**
|
||||||
|
1. **Variable Passing**
|
||||||
|
1. **Edge Connections**
|
||||||
|
1. **Error Handling**
|
||||||
|
|
||||||
|
### Example: Node Configuration Panel
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
|
||||||
|
import userEvent from '@testing-library/user-event'
|
||||||
|
import NodeConfigPanel from './node-config-panel'
|
||||||
|
import { createMockNode, createMockWorkflowContext } from '@/__mocks__/workflow'
|
||||||
|
|
||||||
|
// Mock workflow context
|
||||||
|
vi.mock('@/app/components/workflow/hooks', () => ({
|
||||||
|
useWorkflowStore: () => mockWorkflowStore,
|
||||||
|
useNodesInteractions: () => mockNodesInteractions,
|
||||||
|
}))
|
||||||
|
|
||||||
|
let mockWorkflowStore = {
|
||||||
|
nodes: [],
|
||||||
|
edges: [],
|
||||||
|
updateNode: vi.fn(),
|
||||||
|
}
|
||||||
|
|
||||||
|
let mockNodesInteractions = {
|
||||||
|
handleNodeSelect: vi.fn(),
|
||||||
|
handleNodeDelete: vi.fn(),
|
||||||
|
}
|
||||||
|
|
||||||
|
describe('NodeConfigPanel', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
vi.clearAllMocks()
|
||||||
|
mockWorkflowStore = {
|
||||||
|
nodes: [],
|
||||||
|
edges: [],
|
||||||
|
updateNode: vi.fn(),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Node Configuration', () => {
|
||||||
|
it('should render node type selector', () => {
|
||||||
|
const node = createMockNode({ type: 'llm' })
|
||||||
|
render(<NodeConfigPanel node={node} />)
|
||||||
|
|
||||||
|
expect(screen.getByLabelText(/model/i)).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should update node config on change', async () => {
|
||||||
|
const user = userEvent.setup()
|
||||||
|
const node = createMockNode({ type: 'llm' })
|
||||||
|
|
||||||
|
render(<NodeConfigPanel node={node} />)
|
||||||
|
|
||||||
|
await user.selectOptions(screen.getByLabelText(/model/i), 'gpt-4')
|
||||||
|
|
||||||
|
expect(mockWorkflowStore.updateNode).toHaveBeenCalledWith(
|
||||||
|
node.id,
|
||||||
|
expect.objectContaining({ model: 'gpt-4' })
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Data Validation', () => {
|
||||||
|
it('should show error for invalid input', async () => {
|
||||||
|
const user = userEvent.setup()
|
||||||
|
const node = createMockNode({ type: 'code' })
|
||||||
|
|
||||||
|
render(<NodeConfigPanel node={node} />)
|
||||||
|
|
||||||
|
// Enter invalid code
|
||||||
|
const codeInput = screen.getByLabelText(/code/i)
|
||||||
|
await user.clear(codeInput)
|
||||||
|
await user.type(codeInput, 'invalid syntax {{{')
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByText(/syntax error/i)).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should validate required fields', async () => {
|
||||||
|
const node = createMockNode({ type: 'http', data: { url: '' } })
|
||||||
|
|
||||||
|
render(<NodeConfigPanel node={node} />)
|
||||||
|
|
||||||
|
fireEvent.click(screen.getByRole('button', { name: /save/i }))
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByText(/url is required/i)).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Variable Passing', () => {
|
||||||
|
it('should display available variables from upstream nodes', () => {
|
||||||
|
const upstreamNode = createMockNode({
|
||||||
|
id: 'node-1',
|
||||||
|
type: 'start',
|
||||||
|
data: { outputs: [{ name: 'user_input', type: 'string' }] },
|
||||||
|
})
|
||||||
|
const currentNode = createMockNode({
|
||||||
|
id: 'node-2',
|
||||||
|
type: 'llm',
|
||||||
|
})
|
||||||
|
|
||||||
|
mockWorkflowStore.nodes = [upstreamNode, currentNode]
|
||||||
|
mockWorkflowStore.edges = [{ source: 'node-1', target: 'node-2' }]
|
||||||
|
|
||||||
|
render(<NodeConfigPanel node={currentNode} />)
|
||||||
|
|
||||||
|
// Variable selector should show upstream variables
|
||||||
|
fireEvent.click(screen.getByRole('button', { name: /add variable/i }))
|
||||||
|
|
||||||
|
expect(screen.getByText('user_input')).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should insert variable into prompt template', async () => {
|
||||||
|
const user = userEvent.setup()
|
||||||
|
const node = createMockNode({ type: 'llm' })
|
||||||
|
|
||||||
|
render(<NodeConfigPanel node={node} />)
|
||||||
|
|
||||||
|
// Click variable button
|
||||||
|
await user.click(screen.getByRole('button', { name: /insert variable/i }))
|
||||||
|
await user.click(screen.getByText('user_input'))
|
||||||
|
|
||||||
|
const promptInput = screen.getByLabelText(/prompt/i)
|
||||||
|
expect(promptInput).toHaveValue(expect.stringContaining('{{user_input}}'))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
## Dataset Components (`dataset/`)
|
||||||
|
|
||||||
|
Dataset components handle file uploads, data display, and search/filter operations.
|
||||||
|
|
||||||
|
### Key Test Areas
|
||||||
|
|
||||||
|
1. **File Upload**
|
||||||
|
1. **File Type Validation**
|
||||||
|
1. **Pagination**
|
||||||
|
1. **Search & Filtering**
|
||||||
|
1. **Data Format Handling**
|
||||||
|
|
||||||
|
### Example: Document Uploader
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
|
||||||
|
import userEvent from '@testing-library/user-event'
|
||||||
|
import DocumentUploader from './document-uploader'
|
||||||
|
|
||||||
|
vi.mock('@/service/datasets', () => ({
|
||||||
|
uploadDocument: vi.fn(),
|
||||||
|
parseDocument: vi.fn(),
|
||||||
|
}))
|
||||||
|
|
||||||
|
import * as datasetService from '@/service/datasets'
|
||||||
|
const mockedService = vi.mocked(datasetService)
|
||||||
|
|
||||||
|
describe('DocumentUploader', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
vi.clearAllMocks()
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('File Upload', () => {
|
||||||
|
it('should accept valid file types', async () => {
|
||||||
|
const user = userEvent.setup()
|
||||||
|
const onUpload = vi.fn()
|
||||||
|
mockedService.uploadDocument.mockResolvedValue({ id: 'doc-1' })
|
||||||
|
|
||||||
|
render(<DocumentUploader onUpload={onUpload} />)
|
||||||
|
|
||||||
|
const file = new File(['content'], 'test.pdf', { type: 'application/pdf' })
|
||||||
|
const input = screen.getByLabelText(/upload/i)
|
||||||
|
|
||||||
|
await user.upload(input, file)
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(mockedService.uploadDocument).toHaveBeenCalledWith(
|
||||||
|
expect.any(FormData)
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should reject invalid file types', async () => {
|
||||||
|
const user = userEvent.setup()
|
||||||
|
|
||||||
|
render(<DocumentUploader />)
|
||||||
|
|
||||||
|
const file = new File(['content'], 'test.exe', { type: 'application/x-msdownload' })
|
||||||
|
const input = screen.getByLabelText(/upload/i)
|
||||||
|
|
||||||
|
await user.upload(input, file)
|
||||||
|
|
||||||
|
expect(screen.getByText(/unsupported file type/i)).toBeInTheDocument()
|
||||||
|
expect(mockedService.uploadDocument).not.toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should show upload progress', async () => {
|
||||||
|
const user = userEvent.setup()
|
||||||
|
|
||||||
|
// Mock upload with progress
|
||||||
|
mockedService.uploadDocument.mockImplementation(() => {
|
||||||
|
return new Promise((resolve) => {
|
||||||
|
setTimeout(() => resolve({ id: 'doc-1' }), 100)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
render(<DocumentUploader />)
|
||||||
|
|
||||||
|
const file = new File(['content'], 'test.pdf', { type: 'application/pdf' })
|
||||||
|
await user.upload(screen.getByLabelText(/upload/i), file)
|
||||||
|
|
||||||
|
expect(screen.getByRole('progressbar')).toBeInTheDocument()
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.queryByRole('progressbar')).not.toBeInTheDocument()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Error Handling', () => {
|
||||||
|
it('should handle upload failure', async () => {
|
||||||
|
const user = userEvent.setup()
|
||||||
|
mockedService.uploadDocument.mockRejectedValue(new Error('Upload failed'))
|
||||||
|
|
||||||
|
render(<DocumentUploader />)
|
||||||
|
|
||||||
|
const file = new File(['content'], 'test.pdf', { type: 'application/pdf' })
|
||||||
|
await user.upload(screen.getByLabelText(/upload/i), file)
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByText(/upload failed/i)).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should allow retry after failure', async () => {
|
||||||
|
const user = userEvent.setup()
|
||||||
|
mockedService.uploadDocument
|
||||||
|
.mockRejectedValueOnce(new Error('Network error'))
|
||||||
|
.mockResolvedValueOnce({ id: 'doc-1' })
|
||||||
|
|
||||||
|
render(<DocumentUploader />)
|
||||||
|
|
||||||
|
const file = new File(['content'], 'test.pdf', { type: 'application/pdf' })
|
||||||
|
await user.upload(screen.getByLabelText(/upload/i), file)
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByRole('button', { name: /retry/i })).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
|
||||||
|
await user.click(screen.getByRole('button', { name: /retry/i }))
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByText(/uploaded successfully/i)).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Example: Document List with Pagination
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
describe('DocumentList', () => {
|
||||||
|
describe('Pagination', () => {
|
||||||
|
it('should load first page on mount', async () => {
|
||||||
|
mockedService.getDocuments.mockResolvedValue({
|
||||||
|
data: [{ id: '1', name: 'Doc 1' }],
|
||||||
|
total: 50,
|
||||||
|
page: 1,
|
||||||
|
pageSize: 10,
|
||||||
|
})
|
||||||
|
|
||||||
|
render(<DocumentList datasetId="ds-1" />)
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByText('Doc 1')).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(mockedService.getDocuments).toHaveBeenCalledWith('ds-1', { page: 1 })
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should navigate to next page', async () => {
|
||||||
|
const user = userEvent.setup()
|
||||||
|
mockedService.getDocuments.mockResolvedValue({
|
||||||
|
data: [{ id: '1', name: 'Doc 1' }],
|
||||||
|
total: 50,
|
||||||
|
page: 1,
|
||||||
|
pageSize: 10,
|
||||||
|
})
|
||||||
|
|
||||||
|
render(<DocumentList datasetId="ds-1" />)
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByText('Doc 1')).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
|
||||||
|
mockedService.getDocuments.mockResolvedValue({
|
||||||
|
data: [{ id: '11', name: 'Doc 11' }],
|
||||||
|
total: 50,
|
||||||
|
page: 2,
|
||||||
|
pageSize: 10,
|
||||||
|
})
|
||||||
|
|
||||||
|
await user.click(screen.getByRole('button', { name: /next/i }))
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByText('Doc 11')).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Search & Filtering', () => {
|
||||||
|
it('should filter by search query', async () => {
|
||||||
|
const user = userEvent.setup()
|
||||||
|
vi.useFakeTimers()
|
||||||
|
|
||||||
|
render(<DocumentList datasetId="ds-1" />)
|
||||||
|
|
||||||
|
await user.type(screen.getByPlaceholderText(/search/i), 'test query')
|
||||||
|
|
||||||
|
// Debounce
|
||||||
|
vi.advanceTimersByTime(300)
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(mockedService.getDocuments).toHaveBeenCalledWith(
|
||||||
|
'ds-1',
|
||||||
|
expect.objectContaining({ search: 'test query' })
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
vi.useRealTimers()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configuration Components (`app/configuration/`, `config/`)
|
||||||
|
|
||||||
|
Configuration components handle forms, validation, and data persistence.
|
||||||
|
|
||||||
|
### Key Test Areas
|
||||||
|
|
||||||
|
1. **Form Validation**
|
||||||
|
1. **Save/Reset**
|
||||||
|
1. **Required vs Optional Fields**
|
||||||
|
1. **Configuration Persistence**
|
||||||
|
1. **Error Feedback**
|
||||||
|
|
||||||
|
### Example: App Configuration Form
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
|
||||||
|
import userEvent from '@testing-library/user-event'
|
||||||
|
import AppConfigForm from './app-config-form'
|
||||||
|
|
||||||
|
vi.mock('@/service/apps', () => ({
|
||||||
|
updateAppConfig: vi.fn(),
|
||||||
|
getAppConfig: vi.fn(),
|
||||||
|
}))
|
||||||
|
|
||||||
|
import * as appService from '@/service/apps'
|
||||||
|
const mockedService = vi.mocked(appService)
|
||||||
|
|
||||||
|
describe('AppConfigForm', () => {
|
||||||
|
const defaultConfig = {
|
||||||
|
name: 'My App',
|
||||||
|
description: '',
|
||||||
|
icon: 'default',
|
||||||
|
openingStatement: '',
|
||||||
|
}
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
vi.clearAllMocks()
|
||||||
|
mockedService.getAppConfig.mockResolvedValue(defaultConfig)
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Form Validation', () => {
|
||||||
|
it('should require app name', async () => {
|
||||||
|
const user = userEvent.setup()
|
||||||
|
|
||||||
|
render(<AppConfigForm appId="app-1" />)
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
|
||||||
|
})
|
||||||
|
|
||||||
|
// Clear name field
|
||||||
|
await user.clear(screen.getByLabelText(/name/i))
|
||||||
|
await user.click(screen.getByRole('button', { name: /save/i }))
|
||||||
|
|
||||||
|
expect(screen.getByText(/name is required/i)).toBeInTheDocument()
|
||||||
|
expect(mockedService.updateAppConfig).not.toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should validate name length', async () => {
|
||||||
|
const user = userEvent.setup()
|
||||||
|
|
||||||
|
render(<AppConfigForm appId="app-1" />)
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByLabelText(/name/i)).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
|
||||||
|
// Enter very long name
|
||||||
|
await user.clear(screen.getByLabelText(/name/i))
|
||||||
|
await user.type(screen.getByLabelText(/name/i), 'a'.repeat(101))
|
||||||
|
|
||||||
|
expect(screen.getByText(/name must be less than 100 characters/i)).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should allow empty optional fields', async () => {
|
||||||
|
const user = userEvent.setup()
|
||||||
|
mockedService.updateAppConfig.mockResolvedValue({ success: true })
|
||||||
|
|
||||||
|
render(<AppConfigForm appId="app-1" />)
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
|
||||||
|
})
|
||||||
|
|
||||||
|
// Leave description empty (optional)
|
||||||
|
await user.click(screen.getByRole('button', { name: /save/i }))
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(mockedService.updateAppConfig).toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Save/Reset Functionality', () => {
|
||||||
|
it('should save configuration', async () => {
|
||||||
|
const user = userEvent.setup()
|
||||||
|
mockedService.updateAppConfig.mockResolvedValue({ success: true })
|
||||||
|
|
||||||
|
render(<AppConfigForm appId="app-1" />)
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
|
||||||
|
})
|
||||||
|
|
||||||
|
await user.clear(screen.getByLabelText(/name/i))
|
||||||
|
await user.type(screen.getByLabelText(/name/i), 'Updated App')
|
||||||
|
await user.click(screen.getByRole('button', { name: /save/i }))
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(mockedService.updateAppConfig).toHaveBeenCalledWith(
|
||||||
|
'app-1',
|
||||||
|
expect.objectContaining({ name: 'Updated App' })
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(screen.getByText(/saved successfully/i)).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should reset to default values', async () => {
|
||||||
|
const user = userEvent.setup()
|
||||||
|
|
||||||
|
render(<AppConfigForm appId="app-1" />)
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
|
||||||
|
})
|
||||||
|
|
||||||
|
// Make changes
|
||||||
|
await user.clear(screen.getByLabelText(/name/i))
|
||||||
|
await user.type(screen.getByLabelText(/name/i), 'Changed Name')
|
||||||
|
|
||||||
|
// Reset
|
||||||
|
await user.click(screen.getByRole('button', { name: /reset/i }))
|
||||||
|
|
||||||
|
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should show unsaved changes warning', async () => {
|
||||||
|
const user = userEvent.setup()
|
||||||
|
|
||||||
|
render(<AppConfigForm appId="app-1" />)
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
|
||||||
|
})
|
||||||
|
|
||||||
|
// Make changes
|
||||||
|
await user.type(screen.getByLabelText(/name/i), ' Updated')
|
||||||
|
|
||||||
|
expect(screen.getByText(/unsaved changes/i)).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Error Handling', () => {
|
||||||
|
it('should show error on save failure', async () => {
|
||||||
|
const user = userEvent.setup()
|
||||||
|
mockedService.updateAppConfig.mockRejectedValue(new Error('Server error'))
|
||||||
|
|
||||||
|
render(<AppConfigForm appId="app-1" />)
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
|
||||||
|
})
|
||||||
|
|
||||||
|
await user.click(screen.getByRole('button', { name: /save/i }))
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByText(/failed to save/i)).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
@ -0,0 +1,343 @@
|
||||||
|
# Mocking Guide for Dify Frontend Tests
|
||||||
|
|
||||||
|
## ⚠️ Important: What NOT to Mock
|
||||||
|
|
||||||
|
### DO NOT Mock Base Components
|
||||||
|
|
||||||
|
**Never mock components from `@/app/components/base/`** such as:
|
||||||
|
|
||||||
|
- `Loading`, `Spinner`
|
||||||
|
- `Button`, `Input`, `Select`
|
||||||
|
- `Tooltip`, `Modal`, `Dropdown`
|
||||||
|
- `Icon`, `Badge`, `Tag`
|
||||||
|
|
||||||
|
**Why?**
|
||||||
|
|
||||||
|
- Base components will have their own dedicated tests
|
||||||
|
- Mocking them creates false positives (tests pass but real integration fails)
|
||||||
|
- Using real components tests actual integration behavior
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// ❌ WRONG: Don't mock base components
|
||||||
|
vi.mock('@/app/components/base/loading', () => () => <div>Loading</div>)
|
||||||
|
vi.mock('@/app/components/base/button', () => ({ children }: any) => <button>{children}</button>)
|
||||||
|
|
||||||
|
// ✅ CORRECT: Import and use real base components
|
||||||
|
import Loading from '@/app/components/base/loading'
|
||||||
|
import Button from '@/app/components/base/button'
|
||||||
|
// They will render normally in tests
|
||||||
|
```
|
||||||
|
|
||||||
|
### What TO Mock
|
||||||
|
|
||||||
|
Only mock these categories:
|
||||||
|
|
||||||
|
1. **API services** (`@/service/*`) - Network calls
|
||||||
|
1. **Complex context providers** - When setup is too difficult
|
||||||
|
1. **Third-party libraries with side effects** - `next/navigation`, external SDKs
|
||||||
|
1. **i18n** - Always mock to return keys
|
||||||
|
|
||||||
|
## Mock Placement
|
||||||
|
|
||||||
|
| Location | Purpose |
|
||||||
|
|----------|---------|
|
||||||
|
| `web/vitest.setup.ts` | Global mocks shared by all tests (for example `react-i18next`, `next/image`) |
|
||||||
|
| `web/__mocks__/` | Reusable mock factories shared across multiple test files |
|
||||||
|
| Test file | Test-specific mocks, inline with `vi.mock()` |
|
||||||
|
|
||||||
|
Modules are not mocked automatically. Use `vi.mock` in test files, or add global mocks in `web/vitest.setup.ts`.
|
||||||
|
|
||||||
|
## Essential Mocks
|
||||||
|
|
||||||
|
### 1. i18n (Auto-loaded via Global Mock)
|
||||||
|
|
||||||
|
A global mock is defined in `web/vitest.setup.ts` and is auto-loaded by Vitest setup.
|
||||||
|
**No explicit mock needed** for most tests - it returns translation keys as-is.
|
||||||
|
|
||||||
|
For tests requiring custom translations, override the mock:
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
vi.mock('react-i18next', () => ({
|
||||||
|
useTranslation: () => ({
|
||||||
|
t: (key: string) => {
|
||||||
|
const translations: Record<string, string> = {
|
||||||
|
'my.custom.key': 'Custom translation',
|
||||||
|
}
|
||||||
|
return translations[key] || key
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
}))
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Next.js Router
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
const mockPush = vi.fn()
|
||||||
|
const mockReplace = vi.fn()
|
||||||
|
|
||||||
|
vi.mock('next/navigation', () => ({
|
||||||
|
useRouter: () => ({
|
||||||
|
push: mockPush,
|
||||||
|
replace: mockReplace,
|
||||||
|
back: vi.fn(),
|
||||||
|
prefetch: vi.fn(),
|
||||||
|
}),
|
||||||
|
usePathname: () => '/current-path',
|
||||||
|
useSearchParams: () => new URLSearchParams('?key=value'),
|
||||||
|
}))
|
||||||
|
|
||||||
|
describe('Component', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
vi.clearAllMocks()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should navigate on click', () => {
|
||||||
|
render(<Component />)
|
||||||
|
fireEvent.click(screen.getByRole('button'))
|
||||||
|
expect(mockPush).toHaveBeenCalledWith('/expected-path')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Portal Components (with Shared State)
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// ⚠️ Important: Use shared state for components that depend on each other
|
||||||
|
let mockPortalOpenState = false
|
||||||
|
|
||||||
|
vi.mock('@/app/components/base/portal-to-follow-elem', () => ({
|
||||||
|
PortalToFollowElem: ({ children, open, ...props }: any) => {
|
||||||
|
mockPortalOpenState = open || false // Update shared state
|
||||||
|
return <div data-testid="portal" data-open={open}>{children}</div>
|
||||||
|
},
|
||||||
|
PortalToFollowElemContent: ({ children }: any) => {
|
||||||
|
// ✅ Matches actual: returns null when portal is closed
|
||||||
|
if (!mockPortalOpenState) return null
|
||||||
|
return <div data-testid="portal-content">{children}</div>
|
||||||
|
},
|
||||||
|
PortalToFollowElemTrigger: ({ children }: any) => (
|
||||||
|
<div data-testid="portal-trigger">{children}</div>
|
||||||
|
),
|
||||||
|
}))
|
||||||
|
|
||||||
|
describe('Component', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
vi.clearAllMocks()
|
||||||
|
mockPortalOpenState = false // ✅ Reset shared state
|
||||||
|
})
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. API Service Mocks
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
import * as api from '@/service/api'
|
||||||
|
|
||||||
|
vi.mock('@/service/api')
|
||||||
|
|
||||||
|
const mockedApi = vi.mocked(api)
|
||||||
|
|
||||||
|
describe('Component', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
vi.clearAllMocks()
|
||||||
|
|
||||||
|
// Setup default mock implementation
|
||||||
|
mockedApi.fetchData.mockResolvedValue({ data: [] })
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should show data on success', async () => {
|
||||||
|
mockedApi.fetchData.mockResolvedValue({ data: [{ id: 1 }] })
|
||||||
|
|
||||||
|
render(<Component />)
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByText('1')).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should show error on failure', async () => {
|
||||||
|
mockedApi.fetchData.mockRejectedValue(new Error('Network error'))
|
||||||
|
|
||||||
|
render(<Component />)
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByText(/error/i)).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5. HTTP Mocking with Nock
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
import nock from 'nock'
|
||||||
|
|
||||||
|
const GITHUB_HOST = 'https://api.github.com'
|
||||||
|
const GITHUB_PATH = '/repos/owner/repo'
|
||||||
|
|
||||||
|
const mockGithubApi = (status: number, body: Record<string, unknown>, delayMs = 0) => {
|
||||||
|
return nock(GITHUB_HOST)
|
||||||
|
.get(GITHUB_PATH)
|
||||||
|
.delay(delayMs)
|
||||||
|
.reply(status, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
describe('GithubComponent', () => {
|
||||||
|
afterEach(() => {
|
||||||
|
nock.cleanAll()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should display repo info', async () => {
|
||||||
|
mockGithubApi(200, { name: 'dify', stars: 1000 })
|
||||||
|
|
||||||
|
render(<GithubComponent />)
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByText('dify')).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle API error', async () => {
|
||||||
|
mockGithubApi(500, { message: 'Server error' })
|
||||||
|
|
||||||
|
render(<GithubComponent />)
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByText(/error/i)).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6. Context Providers
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
import { ProviderContext } from '@/context/provider-context'
|
||||||
|
import { createMockProviderContextValue, createMockPlan } from '@/__mocks__/provider-context'
|
||||||
|
|
||||||
|
describe('Component with Context', () => {
|
||||||
|
it('should render for free plan', () => {
|
||||||
|
const mockContext = createMockPlan('sandbox')
|
||||||
|
|
||||||
|
render(
|
||||||
|
<ProviderContext.Provider value={mockContext}>
|
||||||
|
<Component />
|
||||||
|
</ProviderContext.Provider>
|
||||||
|
)
|
||||||
|
|
||||||
|
expect(screen.getByText('Upgrade')).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should render for pro plan', () => {
|
||||||
|
const mockContext = createMockPlan('professional')
|
||||||
|
|
||||||
|
render(
|
||||||
|
<ProviderContext.Provider value={mockContext}>
|
||||||
|
<Component />
|
||||||
|
</ProviderContext.Provider>
|
||||||
|
)
|
||||||
|
|
||||||
|
expect(screen.queryByText('Upgrade')).not.toBeInTheDocument()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### 7. React Query
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
|
||||||
|
|
||||||
|
const createTestQueryClient = () => new QueryClient({
|
||||||
|
defaultOptions: {
|
||||||
|
queries: { retry: false },
|
||||||
|
mutations: { retry: false },
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
const renderWithQueryClient = (ui: React.ReactElement) => {
|
||||||
|
const queryClient = createTestQueryClient()
|
||||||
|
return render(
|
||||||
|
<QueryClientProvider client={queryClient}>
|
||||||
|
{ui}
|
||||||
|
</QueryClientProvider>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Mock Best Practices
|
||||||
|
|
||||||
|
### ✅ DO
|
||||||
|
|
||||||
|
1. **Use real base components** - Import from `@/app/components/base/` directly
|
||||||
|
1. **Use real project components** - Prefer importing over mocking
|
||||||
|
1. **Reset mocks in `beforeEach`**, not `afterEach`
|
||||||
|
1. **Match actual component behavior** in mocks (when mocking is necessary)
|
||||||
|
1. **Use factory functions** for complex mock data
|
||||||
|
1. **Import actual types** for type safety
|
||||||
|
1. **Reset shared mock state** in `beforeEach`
|
||||||
|
|
||||||
|
### ❌ DON'T
|
||||||
|
|
||||||
|
1. **Don't mock base components** (`Loading`, `Button`, `Tooltip`, etc.)
|
||||||
|
1. Don't mock components you can import directly
|
||||||
|
1. Don't create overly simplified mocks that miss conditional logic
|
||||||
|
1. Don't forget to clean up nock after each test
|
||||||
|
1. Don't use `any` types in mocks without necessity
|
||||||
|
|
||||||
|
### Mock Decision Tree
|
||||||
|
|
||||||
|
```
|
||||||
|
Need to use a component in test?
|
||||||
|
│
|
||||||
|
├─ Is it from @/app/components/base/*?
|
||||||
|
│ └─ YES → Import real component, DO NOT mock
|
||||||
|
│
|
||||||
|
├─ Is it a project component?
|
||||||
|
│ └─ YES → Prefer importing real component
|
||||||
|
│ Only mock if setup is extremely complex
|
||||||
|
│
|
||||||
|
├─ Is it an API service (@/service/*)?
|
||||||
|
│ └─ YES → Mock it
|
||||||
|
│
|
||||||
|
├─ Is it a third-party lib with side effects?
|
||||||
|
│ └─ YES → Mock it (next/navigation, external SDKs)
|
||||||
|
│
|
||||||
|
└─ Is it i18n?
|
||||||
|
└─ YES → Uses shared mock (auto-loaded). Override only for custom translations
|
||||||
|
```
|
||||||
|
|
||||||
|
## Factory Function Pattern
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// __mocks__/data-factories.ts
|
||||||
|
import type { User, Project } from '@/types'
|
||||||
|
|
||||||
|
export const createMockUser = (overrides: Partial<User> = {}): User => ({
|
||||||
|
id: 'user-1',
|
||||||
|
name: 'Test User',
|
||||||
|
email: 'test@example.com',
|
||||||
|
role: 'member',
|
||||||
|
createdAt: new Date().toISOString(),
|
||||||
|
...overrides,
|
||||||
|
})
|
||||||
|
|
||||||
|
export const createMockProject = (overrides: Partial<Project> = {}): Project => ({
|
||||||
|
id: 'project-1',
|
||||||
|
name: 'Test Project',
|
||||||
|
description: 'A test project',
|
||||||
|
owner: createMockUser(),
|
||||||
|
members: [],
|
||||||
|
createdAt: new Date().toISOString(),
|
||||||
|
...overrides,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Usage in tests
|
||||||
|
it('should display project owner', () => {
|
||||||
|
const project = createMockProject({
|
||||||
|
owner: createMockUser({ name: 'John Doe' }),
|
||||||
|
})
|
||||||
|
|
||||||
|
render(<ProjectCard project={project} />)
|
||||||
|
expect(screen.getByText('John Doe')).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
@ -0,0 +1,269 @@
|
||||||
|
# Testing Workflow Guide
|
||||||
|
|
||||||
|
This guide defines the workflow for generating tests, especially for complex components or directories with multiple files.
|
||||||
|
|
||||||
|
## Scope Clarification
|
||||||
|
|
||||||
|
This guide addresses **multi-file workflow** (how to process multiple test files). For coverage requirements within a single test file, see `web/testing/testing.md` § Coverage Goals.
|
||||||
|
|
||||||
|
| Scope | Rule |
|
||||||
|
|-------|------|
|
||||||
|
| **Single file** | Complete coverage in one generation (100% function, >95% branch) |
|
||||||
|
| **Multi-file directory** | Process one file at a time, verify each before proceeding |
|
||||||
|
|
||||||
|
## ⚠️ Critical Rule: Incremental Approach for Multi-File Testing
|
||||||
|
|
||||||
|
When testing a **directory with multiple files**, **NEVER generate all test files at once.** Use an incremental, verify-as-you-go approach.
|
||||||
|
|
||||||
|
### Why Incremental?
|
||||||
|
|
||||||
|
| Batch Approach (❌) | Incremental Approach (✅) |
|
||||||
|
|---------------------|---------------------------|
|
||||||
|
| Generate 5+ tests at once | Generate 1 test at a time |
|
||||||
|
| Run tests only at the end | Run test immediately after each file |
|
||||||
|
| Multiple failures compound | Single point of failure, easy to debug |
|
||||||
|
| Hard to identify root cause | Clear cause-effect relationship |
|
||||||
|
| Mock issues affect many files | Mock issues caught early |
|
||||||
|
| Messy git history | Clean, atomic commits possible |
|
||||||
|
|
||||||
|
## Single File Workflow
|
||||||
|
|
||||||
|
When testing a **single component, hook, or utility**:
|
||||||
|
|
||||||
|
```
|
||||||
|
1. Read source code completely
|
||||||
|
2. Run `pnpm analyze-component <path>` (if available)
|
||||||
|
3. Check complexity score and features detected
|
||||||
|
4. Write the test file
|
||||||
|
5. Run test: `pnpm test <file>.spec.tsx`
|
||||||
|
6. Fix any failures
|
||||||
|
7. Verify coverage meets goals (100% function, >95% branch)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Directory/Multi-File Workflow (MUST FOLLOW)
|
||||||
|
|
||||||
|
When testing a **directory or multiple files**, follow this strict workflow:
|
||||||
|
|
||||||
|
### Step 1: Analyze and Plan
|
||||||
|
|
||||||
|
1. **List all files** that need tests in the directory
|
||||||
|
1. **Categorize by complexity**:
|
||||||
|
- 🟢 **Simple**: Utility functions, simple hooks, presentational components
|
||||||
|
- 🟡 **Medium**: Components with state, effects, or event handlers
|
||||||
|
- 🔴 **Complex**: Components with API calls, routing, or many dependencies
|
||||||
|
1. **Order by dependency**: Test dependencies before dependents
|
||||||
|
1. **Create a todo list** to track progress
|
||||||
|
|
||||||
|
### Step 2: Determine Processing Order
|
||||||
|
|
||||||
|
Process files in this recommended order:
|
||||||
|
|
||||||
|
```
|
||||||
|
1. Utility functions (simplest, no React)
|
||||||
|
2. Custom hooks (isolated logic)
|
||||||
|
3. Simple presentational components (few/no props)
|
||||||
|
4. Medium complexity components (state, effects)
|
||||||
|
5. Complex components (API, routing, many deps)
|
||||||
|
6. Container/index components (integration tests - last)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Rationale**:
|
||||||
|
|
||||||
|
- Simpler files help establish mock patterns
|
||||||
|
- Hooks used by components should be tested first
|
||||||
|
- Integration tests (index files) depend on child components working
|
||||||
|
|
||||||
|
### Step 3: Process Each File Incrementally
|
||||||
|
|
||||||
|
**For EACH file in the ordered list:**
|
||||||
|
|
||||||
|
```
|
||||||
|
┌─────────────────────────────────────────────┐
|
||||||
|
│ 1. Write test file │
|
||||||
|
│ 2. Run: pnpm test <file>.spec.tsx │
|
||||||
|
│ 3. If FAIL → Fix immediately, re-run │
|
||||||
|
│ 4. If PASS → Mark complete in todo list │
|
||||||
|
│ 5. ONLY THEN proceed to next file │
|
||||||
|
└─────────────────────────────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
**DO NOT proceed to the next file until the current one passes.**
|
||||||
|
|
||||||
|
### Step 4: Final Verification
|
||||||
|
|
||||||
|
After all individual tests pass:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run all tests in the directory together
|
||||||
|
pnpm test path/to/directory/
|
||||||
|
|
||||||
|
# Check coverage
|
||||||
|
pnpm test:coverage path/to/directory/
|
||||||
|
```
|
||||||
|
|
||||||
|
## Component Complexity Guidelines
|
||||||
|
|
||||||
|
Use `pnpm analyze-component <path>` to assess complexity before testing.
|
||||||
|
|
||||||
|
### 🔴 Very Complex Components (Complexity > 50)
|
||||||
|
|
||||||
|
**Consider refactoring BEFORE testing:**
|
||||||
|
|
||||||
|
- Break component into smaller, testable pieces
|
||||||
|
- Extract complex logic into custom hooks
|
||||||
|
- Separate container and presentational layers
|
||||||
|
|
||||||
|
**If testing as-is:**
|
||||||
|
|
||||||
|
- Use integration tests for complex workflows
|
||||||
|
- Use `test.each()` for data-driven testing
|
||||||
|
- Multiple `describe` blocks for organization
|
||||||
|
- Consider testing major sections separately
|
||||||
|
|
||||||
|
### 🟡 Medium Complexity (Complexity 30-50)
|
||||||
|
|
||||||
|
- Group related tests in `describe` blocks
|
||||||
|
- Test integration scenarios between internal parts
|
||||||
|
- Focus on state transitions and side effects
|
||||||
|
- Use helper functions to reduce test complexity
|
||||||
|
|
||||||
|
### 🟢 Simple Components (Complexity < 30)
|
||||||
|
|
||||||
|
- Standard test structure
|
||||||
|
- Focus on props, rendering, and edge cases
|
||||||
|
- Usually straightforward to test
|
||||||
|
|
||||||
|
### 📏 Large Files (500+ lines)
|
||||||
|
|
||||||
|
Regardless of complexity score:
|
||||||
|
|
||||||
|
- **Strongly consider refactoring** before testing
|
||||||
|
- If testing as-is, test major sections separately
|
||||||
|
- Create helper functions for test setup
|
||||||
|
- May need multiple test files
|
||||||
|
|
||||||
|
## Todo List Format
|
||||||
|
|
||||||
|
When testing multiple files, use a todo list like this:
|
||||||
|
|
||||||
|
```
|
||||||
|
Testing: path/to/directory/
|
||||||
|
|
||||||
|
Ordered by complexity (simple → complex):
|
||||||
|
|
||||||
|
☐ utils/helper.ts [utility, simple]
|
||||||
|
☐ hooks/use-custom-hook.ts [hook, simple]
|
||||||
|
☐ empty-state.tsx [component, simple]
|
||||||
|
☐ item-card.tsx [component, medium]
|
||||||
|
☐ list.tsx [component, complex]
|
||||||
|
☐ index.tsx [integration]
|
||||||
|
|
||||||
|
Progress: 0/6 complete
|
||||||
|
```
|
||||||
|
|
||||||
|
Update status as you complete each:
|
||||||
|
|
||||||
|
- ☐ → ⏳ (in progress)
|
||||||
|
- ⏳ → ✅ (complete and verified)
|
||||||
|
- ⏳ → ❌ (blocked, needs attention)
|
||||||
|
|
||||||
|
## When to Stop and Verify
|
||||||
|
|
||||||
|
**Always run tests after:**
|
||||||
|
|
||||||
|
- Completing a test file
|
||||||
|
- Making changes to fix a failure
|
||||||
|
- Modifying shared mocks
|
||||||
|
- Updating test utilities or helpers
|
||||||
|
|
||||||
|
**Signs you should pause:**
|
||||||
|
|
||||||
|
- More than 2 consecutive test failures
|
||||||
|
- Mock-related errors appearing
|
||||||
|
- Unclear why a test is failing
|
||||||
|
- Test passing but coverage unexpectedly low
|
||||||
|
|
||||||
|
## Common Pitfalls to Avoid
|
||||||
|
|
||||||
|
### ❌ Don't: Generate Everything First
|
||||||
|
|
||||||
|
```
|
||||||
|
# BAD: Writing all files then testing
|
||||||
|
Write component-a.spec.tsx
|
||||||
|
Write component-b.spec.tsx
|
||||||
|
Write component-c.spec.tsx
|
||||||
|
Write component-d.spec.tsx
|
||||||
|
Run pnpm test ← Multiple failures, hard to debug
|
||||||
|
```
|
||||||
|
|
||||||
|
### ✅ Do: Verify Each Step
|
||||||
|
|
||||||
|
```
|
||||||
|
# GOOD: Incremental with verification
|
||||||
|
Write component-a.spec.tsx
|
||||||
|
Run pnpm test component-a.spec.tsx ✅
|
||||||
|
Write component-b.spec.tsx
|
||||||
|
Run pnpm test component-b.spec.tsx ✅
|
||||||
|
...continue...
|
||||||
|
```
|
||||||
|
|
||||||
|
### ❌ Don't: Skip Verification for "Simple" Components
|
||||||
|
|
||||||
|
Even simple components can have:
|
||||||
|
|
||||||
|
- Import errors
|
||||||
|
- Missing mock setup
|
||||||
|
- Incorrect assumptions about props
|
||||||
|
|
||||||
|
**Always verify, regardless of perceived simplicity.**
|
||||||
|
|
||||||
|
### ❌ Don't: Continue When Tests Fail
|
||||||
|
|
||||||
|
Failing tests compound:
|
||||||
|
|
||||||
|
- A mock issue in file A affects files B, C, D
|
||||||
|
- Fixing A later requires revisiting all dependent tests
|
||||||
|
- Time wasted on debugging cascading failures
|
||||||
|
|
||||||
|
**Fix failures immediately before proceeding.**
|
||||||
|
|
||||||
|
## Integration with Claude's Todo Feature
|
||||||
|
|
||||||
|
When using Claude for multi-file testing:
|
||||||
|
|
||||||
|
1. **Ask Claude to create a todo list** before starting
|
||||||
|
1. **Request one file at a time** or ensure Claude processes incrementally
|
||||||
|
1. **Verify each test passes** before asking for the next
|
||||||
|
1. **Mark todos complete** as you progress
|
||||||
|
|
||||||
|
Example prompt:
|
||||||
|
|
||||||
|
```
|
||||||
|
Test all components in `path/to/directory/`.
|
||||||
|
First, analyze the directory and create a todo list ordered by complexity.
|
||||||
|
Then, process ONE file at a time, waiting for my confirmation that tests pass
|
||||||
|
before proceeding to the next.
|
||||||
|
```
|
||||||
|
|
||||||
|
## Summary Checklist
|
||||||
|
|
||||||
|
Before starting multi-file testing:
|
||||||
|
|
||||||
|
- [ ] Listed all files needing tests
|
||||||
|
- [ ] Ordered by complexity (simple → complex)
|
||||||
|
- [ ] Created todo list for tracking
|
||||||
|
- [ ] Understand dependencies between files
|
||||||
|
|
||||||
|
During testing:
|
||||||
|
|
||||||
|
- [ ] Processing ONE file at a time
|
||||||
|
- [ ] Running tests after EACH file
|
||||||
|
- [ ] Fixing failures BEFORE proceeding
|
||||||
|
- [ ] Updating todo list progress
|
||||||
|
|
||||||
|
After completion:
|
||||||
|
|
||||||
|
- [ ] All individual tests pass
|
||||||
|
- [ ] Full directory test run passes
|
||||||
|
- [ ] Coverage goals met
|
||||||
|
- [ ] Todo list shows all complete
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
../.claude/skills
|
||||||
|
|
@ -0,0 +1,5 @@
|
||||||
|
[run]
|
||||||
|
omit =
|
||||||
|
api/tests/*
|
||||||
|
api/migrations/*
|
||||||
|
api/core/rag/datasource/vdb/*
|
||||||
|
|
@ -6,6 +6,9 @@
|
||||||
"context": "..",
|
"context": "..",
|
||||||
"dockerfile": "Dockerfile"
|
"dockerfile": "Dockerfile"
|
||||||
},
|
},
|
||||||
|
"mounts": [
|
||||||
|
"source=dify-dev-tmp,target=/tmp,type=volume"
|
||||||
|
],
|
||||||
"features": {
|
"features": {
|
||||||
"ghcr.io/devcontainers/features/node:1": {
|
"ghcr.io/devcontainers/features/node:1": {
|
||||||
"nodeGypDependencies": true,
|
"nodeGypDependencies": true,
|
||||||
|
|
@ -34,19 +37,13 @@
|
||||||
},
|
},
|
||||||
"postStartCommand": "./.devcontainer/post_start_command.sh",
|
"postStartCommand": "./.devcontainer/post_start_command.sh",
|
||||||
"postCreateCommand": "./.devcontainer/post_create_command.sh"
|
"postCreateCommand": "./.devcontainer/post_create_command.sh"
|
||||||
|
|
||||||
// Features to add to the dev container. More info: https://containers.dev/features.
|
// Features to add to the dev container. More info: https://containers.dev/features.
|
||||||
// "features": {},
|
// "features": {},
|
||||||
|
|
||||||
// Use 'forwardPorts' to make a list of ports inside the container available locally.
|
// Use 'forwardPorts' to make a list of ports inside the container available locally.
|
||||||
// "forwardPorts": [],
|
// "forwardPorts": [],
|
||||||
|
|
||||||
// Use 'postCreateCommand' to run commands after the container is created.
|
// Use 'postCreateCommand' to run commands after the container is created.
|
||||||
// "postCreateCommand": "python --version",
|
// "postCreateCommand": "python --version",
|
||||||
|
|
||||||
// Configure tool-specific properties.
|
// Configure tool-specific properties.
|
||||||
// "customizations": {},
|
// "customizations": {},
|
||||||
|
|
||||||
// Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root.
|
// Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root.
|
||||||
// "remoteUser": "root"
|
}
|
||||||
}
|
|
||||||
|
|
@ -1,12 +1,13 @@
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
WORKSPACE_ROOT=$(pwd)
|
WORKSPACE_ROOT=$(pwd)
|
||||||
|
|
||||||
|
export COREPACK_ENABLE_DOWNLOAD_PROMPT=0
|
||||||
corepack enable
|
corepack enable
|
||||||
cd web && pnpm install
|
cd web && pnpm install
|
||||||
pipx install uv
|
pipx install uv
|
||||||
|
|
||||||
echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
|
echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
|
||||||
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor\"" >> ~/.bashrc
|
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention\"" >> ~/.bashrc
|
||||||
echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev\"" >> ~/.bashrc
|
echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev\"" >> ~/.bashrc
|
||||||
echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
|
echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
|
||||||
echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc
|
echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc
|
||||||
|
|
|
||||||
|
|
@ -6,221 +6,244 @@
|
||||||
|
|
||||||
* @crazywoola @laipz8200 @Yeuoly
|
* @crazywoola @laipz8200 @Yeuoly
|
||||||
|
|
||||||
|
# CODEOWNERS file
|
||||||
|
/.github/CODEOWNERS @laipz8200 @crazywoola
|
||||||
|
|
||||||
|
# Docs
|
||||||
|
/docs/ @crazywoola
|
||||||
|
|
||||||
# Backend (default owner, more specific rules below will override)
|
# Backend (default owner, more specific rules below will override)
|
||||||
api/ @QuantumGhost
|
/api/ @QuantumGhost
|
||||||
|
|
||||||
|
# Backend - MCP
|
||||||
|
/api/core/mcp/ @Nov1c444
|
||||||
|
/api/core/entities/mcp_provider.py @Nov1c444
|
||||||
|
/api/services/tools/mcp_tools_manage_service.py @Nov1c444
|
||||||
|
/api/controllers/mcp/ @Nov1c444
|
||||||
|
/api/controllers/console/app/mcp_server.py @Nov1c444
|
||||||
|
/api/tests/**/*mcp* @Nov1c444
|
||||||
|
|
||||||
# Backend - Workflow - Engine (Core graph execution engine)
|
# Backend - Workflow - Engine (Core graph execution engine)
|
||||||
api/core/workflow/graph_engine/ @laipz8200 @QuantumGhost
|
/api/core/workflow/graph_engine/ @laipz8200 @QuantumGhost
|
||||||
api/core/workflow/runtime/ @laipz8200 @QuantumGhost
|
/api/core/workflow/runtime/ @laipz8200 @QuantumGhost
|
||||||
api/core/workflow/graph/ @laipz8200 @QuantumGhost
|
/api/core/workflow/graph/ @laipz8200 @QuantumGhost
|
||||||
api/core/workflow/graph_events/ @laipz8200 @QuantumGhost
|
/api/core/workflow/graph_events/ @laipz8200 @QuantumGhost
|
||||||
api/core/workflow/node_events/ @laipz8200 @QuantumGhost
|
/api/core/workflow/node_events/ @laipz8200 @QuantumGhost
|
||||||
api/core/model_runtime/ @laipz8200 @QuantumGhost
|
/api/core/model_runtime/ @laipz8200 @QuantumGhost
|
||||||
|
|
||||||
# Backend - Workflow - Nodes (Agent, Iteration, Loop, LLM)
|
# Backend - Workflow - Nodes (Agent, Iteration, Loop, LLM)
|
||||||
api/core/workflow/nodes/agent/ @Nov1c444
|
/api/core/workflow/nodes/agent/ @Nov1c444
|
||||||
api/core/workflow/nodes/iteration/ @Nov1c444
|
/api/core/workflow/nodes/iteration/ @Nov1c444
|
||||||
api/core/workflow/nodes/loop/ @Nov1c444
|
/api/core/workflow/nodes/loop/ @Nov1c444
|
||||||
api/core/workflow/nodes/llm/ @Nov1c444
|
/api/core/workflow/nodes/llm/ @Nov1c444
|
||||||
|
|
||||||
# Backend - RAG (Retrieval Augmented Generation)
|
# Backend - RAG (Retrieval Augmented Generation)
|
||||||
api/core/rag/ @JohnJyong
|
/api/core/rag/ @JohnJyong
|
||||||
api/services/rag_pipeline/ @JohnJyong
|
/api/services/rag_pipeline/ @JohnJyong
|
||||||
api/services/dataset_service.py @JohnJyong
|
/api/services/dataset_service.py @JohnJyong
|
||||||
api/services/knowledge_service.py @JohnJyong
|
/api/services/knowledge_service.py @JohnJyong
|
||||||
api/services/external_knowledge_service.py @JohnJyong
|
/api/services/external_knowledge_service.py @JohnJyong
|
||||||
api/services/hit_testing_service.py @JohnJyong
|
/api/services/hit_testing_service.py @JohnJyong
|
||||||
api/services/metadata_service.py @JohnJyong
|
/api/services/metadata_service.py @JohnJyong
|
||||||
api/services/vector_service.py @JohnJyong
|
/api/services/vector_service.py @JohnJyong
|
||||||
api/services/entities/knowledge_entities/ @JohnJyong
|
/api/services/entities/knowledge_entities/ @JohnJyong
|
||||||
api/services/entities/external_knowledge_entities/ @JohnJyong
|
/api/services/entities/external_knowledge_entities/ @JohnJyong
|
||||||
api/controllers/console/datasets/ @JohnJyong
|
/api/controllers/console/datasets/ @JohnJyong
|
||||||
api/controllers/service_api/dataset/ @JohnJyong
|
/api/controllers/service_api/dataset/ @JohnJyong
|
||||||
api/models/dataset.py @JohnJyong
|
/api/models/dataset.py @JohnJyong
|
||||||
api/tasks/rag_pipeline/ @JohnJyong
|
/api/tasks/rag_pipeline/ @JohnJyong
|
||||||
api/tasks/add_document_to_index_task.py @JohnJyong
|
/api/tasks/add_document_to_index_task.py @JohnJyong
|
||||||
api/tasks/batch_clean_document_task.py @JohnJyong
|
/api/tasks/batch_clean_document_task.py @JohnJyong
|
||||||
api/tasks/clean_document_task.py @JohnJyong
|
/api/tasks/clean_document_task.py @JohnJyong
|
||||||
api/tasks/clean_notion_document_task.py @JohnJyong
|
/api/tasks/clean_notion_document_task.py @JohnJyong
|
||||||
api/tasks/document_indexing_task.py @JohnJyong
|
/api/tasks/document_indexing_task.py @JohnJyong
|
||||||
api/tasks/document_indexing_sync_task.py @JohnJyong
|
/api/tasks/document_indexing_sync_task.py @JohnJyong
|
||||||
api/tasks/document_indexing_update_task.py @JohnJyong
|
/api/tasks/document_indexing_update_task.py @JohnJyong
|
||||||
api/tasks/duplicate_document_indexing_task.py @JohnJyong
|
/api/tasks/duplicate_document_indexing_task.py @JohnJyong
|
||||||
api/tasks/recover_document_indexing_task.py @JohnJyong
|
/api/tasks/recover_document_indexing_task.py @JohnJyong
|
||||||
api/tasks/remove_document_from_index_task.py @JohnJyong
|
/api/tasks/remove_document_from_index_task.py @JohnJyong
|
||||||
api/tasks/retry_document_indexing_task.py @JohnJyong
|
/api/tasks/retry_document_indexing_task.py @JohnJyong
|
||||||
api/tasks/sync_website_document_indexing_task.py @JohnJyong
|
/api/tasks/sync_website_document_indexing_task.py @JohnJyong
|
||||||
api/tasks/batch_create_segment_to_index_task.py @JohnJyong
|
/api/tasks/batch_create_segment_to_index_task.py @JohnJyong
|
||||||
api/tasks/create_segment_to_index_task.py @JohnJyong
|
/api/tasks/create_segment_to_index_task.py @JohnJyong
|
||||||
api/tasks/delete_segment_from_index_task.py @JohnJyong
|
/api/tasks/delete_segment_from_index_task.py @JohnJyong
|
||||||
api/tasks/disable_segment_from_index_task.py @JohnJyong
|
/api/tasks/disable_segment_from_index_task.py @JohnJyong
|
||||||
api/tasks/disable_segments_from_index_task.py @JohnJyong
|
/api/tasks/disable_segments_from_index_task.py @JohnJyong
|
||||||
api/tasks/enable_segment_to_index_task.py @JohnJyong
|
/api/tasks/enable_segment_to_index_task.py @JohnJyong
|
||||||
api/tasks/enable_segments_to_index_task.py @JohnJyong
|
/api/tasks/enable_segments_to_index_task.py @JohnJyong
|
||||||
api/tasks/clean_dataset_task.py @JohnJyong
|
/api/tasks/clean_dataset_task.py @JohnJyong
|
||||||
api/tasks/deal_dataset_index_update_task.py @JohnJyong
|
/api/tasks/deal_dataset_index_update_task.py @JohnJyong
|
||||||
api/tasks/deal_dataset_vector_index_task.py @JohnJyong
|
/api/tasks/deal_dataset_vector_index_task.py @JohnJyong
|
||||||
|
|
||||||
# Backend - Plugins
|
# Backend - Plugins
|
||||||
api/core/plugin/ @Mairuis @Yeuoly @Stream29
|
/api/core/plugin/ @Mairuis @Yeuoly @Stream29
|
||||||
api/services/plugin/ @Mairuis @Yeuoly @Stream29
|
/api/services/plugin/ @Mairuis @Yeuoly @Stream29
|
||||||
api/controllers/console/workspace/plugin.py @Mairuis @Yeuoly @Stream29
|
/api/controllers/console/workspace/plugin.py @Mairuis @Yeuoly @Stream29
|
||||||
api/controllers/inner_api/plugin/ @Mairuis @Yeuoly @Stream29
|
/api/controllers/inner_api/plugin/ @Mairuis @Yeuoly @Stream29
|
||||||
api/tasks/process_tenant_plugin_autoupgrade_check_task.py @Mairuis @Yeuoly @Stream29
|
/api/tasks/process_tenant_plugin_autoupgrade_check_task.py @Mairuis @Yeuoly @Stream29
|
||||||
|
|
||||||
# Backend - Trigger/Schedule/Webhook
|
# Backend - Trigger/Schedule/Webhook
|
||||||
api/controllers/trigger/ @Mairuis @Yeuoly
|
/api/controllers/trigger/ @Mairuis @Yeuoly
|
||||||
api/controllers/console/app/workflow_trigger.py @Mairuis @Yeuoly
|
/api/controllers/console/app/workflow_trigger.py @Mairuis @Yeuoly
|
||||||
api/controllers/console/workspace/trigger_providers.py @Mairuis @Yeuoly
|
/api/controllers/console/workspace/trigger_providers.py @Mairuis @Yeuoly
|
||||||
api/core/trigger/ @Mairuis @Yeuoly
|
/api/core/trigger/ @Mairuis @Yeuoly
|
||||||
api/core/app/layers/trigger_post_layer.py @Mairuis @Yeuoly
|
/api/core/app/layers/trigger_post_layer.py @Mairuis @Yeuoly
|
||||||
api/services/trigger/ @Mairuis @Yeuoly
|
/api/services/trigger/ @Mairuis @Yeuoly
|
||||||
api/models/trigger.py @Mairuis @Yeuoly
|
/api/models/trigger.py @Mairuis @Yeuoly
|
||||||
api/fields/workflow_trigger_fields.py @Mairuis @Yeuoly
|
/api/fields/workflow_trigger_fields.py @Mairuis @Yeuoly
|
||||||
api/repositories/workflow_trigger_log_repository.py @Mairuis @Yeuoly
|
/api/repositories/workflow_trigger_log_repository.py @Mairuis @Yeuoly
|
||||||
api/repositories/sqlalchemy_workflow_trigger_log_repository.py @Mairuis @Yeuoly
|
/api/repositories/sqlalchemy_workflow_trigger_log_repository.py @Mairuis @Yeuoly
|
||||||
api/libs/schedule_utils.py @Mairuis @Yeuoly
|
/api/libs/schedule_utils.py @Mairuis @Yeuoly
|
||||||
api/services/workflow/scheduler.py @Mairuis @Yeuoly
|
/api/services/workflow/scheduler.py @Mairuis @Yeuoly
|
||||||
api/schedule/trigger_provider_refresh_task.py @Mairuis @Yeuoly
|
/api/schedule/trigger_provider_refresh_task.py @Mairuis @Yeuoly
|
||||||
api/schedule/workflow_schedule_task.py @Mairuis @Yeuoly
|
/api/schedule/workflow_schedule_task.py @Mairuis @Yeuoly
|
||||||
api/tasks/trigger_processing_tasks.py @Mairuis @Yeuoly
|
/api/tasks/trigger_processing_tasks.py @Mairuis @Yeuoly
|
||||||
api/tasks/trigger_subscription_refresh_tasks.py @Mairuis @Yeuoly
|
/api/tasks/trigger_subscription_refresh_tasks.py @Mairuis @Yeuoly
|
||||||
api/tasks/workflow_schedule_tasks.py @Mairuis @Yeuoly
|
/api/tasks/workflow_schedule_tasks.py @Mairuis @Yeuoly
|
||||||
api/tasks/workflow_cfs_scheduler/ @Mairuis @Yeuoly
|
/api/tasks/workflow_cfs_scheduler/ @Mairuis @Yeuoly
|
||||||
api/events/event_handlers/sync_plugin_trigger_when_app_created.py @Mairuis @Yeuoly
|
/api/events/event_handlers/sync_plugin_trigger_when_app_created.py @Mairuis @Yeuoly
|
||||||
api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py @Mairuis @Yeuoly
|
/api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py @Mairuis @Yeuoly
|
||||||
api/events/event_handlers/sync_workflow_schedule_when_app_published.py @Mairuis @Yeuoly
|
/api/events/event_handlers/sync_workflow_schedule_when_app_published.py @Mairuis @Yeuoly
|
||||||
api/events/event_handlers/sync_webhook_when_app_created.py @Mairuis @Yeuoly
|
/api/events/event_handlers/sync_webhook_when_app_created.py @Mairuis @Yeuoly
|
||||||
|
|
||||||
# Backend - Async Workflow
|
# Backend - Async Workflow
|
||||||
api/services/async_workflow_service.py @Mairuis @Yeuoly
|
/api/services/async_workflow_service.py @Mairuis @Yeuoly
|
||||||
api/tasks/async_workflow_tasks.py @Mairuis @Yeuoly
|
/api/tasks/async_workflow_tasks.py @Mairuis @Yeuoly
|
||||||
|
|
||||||
# Backend - Billing
|
# Backend - Billing
|
||||||
api/services/billing_service.py @hj24 @zyssyz123
|
/api/services/billing_service.py @hj24 @zyssyz123
|
||||||
api/controllers/console/billing/ @hj24 @zyssyz123
|
/api/controllers/console/billing/ @hj24 @zyssyz123
|
||||||
|
|
||||||
# Backend - Enterprise
|
# Backend - Enterprise
|
||||||
api/configs/enterprise/ @GarfieldDai @GareArc
|
/api/configs/enterprise/ @GarfieldDai @GareArc
|
||||||
api/services/enterprise/ @GarfieldDai @GareArc
|
/api/services/enterprise/ @GarfieldDai @GareArc
|
||||||
api/services/feature_service.py @GarfieldDai @GareArc
|
/api/services/feature_service.py @GarfieldDai @GareArc
|
||||||
api/controllers/console/feature.py @GarfieldDai @GareArc
|
/api/controllers/console/feature.py @GarfieldDai @GareArc
|
||||||
api/controllers/web/feature.py @GarfieldDai @GareArc
|
/api/controllers/web/feature.py @GarfieldDai @GareArc
|
||||||
|
|
||||||
# Backend - Database Migrations
|
# Backend - Database Migrations
|
||||||
api/migrations/ @snakevash @laipz8200
|
/api/migrations/ @snakevash @laipz8200 @MRZHUH
|
||||||
|
|
||||||
|
# Backend - Vector DB Middleware
|
||||||
|
/api/configs/middleware/vdb/* @JohnJyong
|
||||||
|
|
||||||
# Frontend
|
# Frontend
|
||||||
web/ @iamjoel
|
/web/ @iamjoel
|
||||||
|
|
||||||
|
# Frontend - Web Tests
|
||||||
|
/.github/workflows/web-tests.yml @iamjoel
|
||||||
|
|
||||||
# Frontend - App - Orchestration
|
# Frontend - App - Orchestration
|
||||||
web/app/components/workflow/ @iamjoel @zxhlyh
|
/web/app/components/workflow/ @iamjoel @zxhlyh
|
||||||
web/app/components/workflow-app/ @iamjoel @zxhlyh
|
/web/app/components/workflow-app/ @iamjoel @zxhlyh
|
||||||
web/app/components/app/configuration/ @iamjoel @zxhlyh
|
/web/app/components/app/configuration/ @iamjoel @zxhlyh
|
||||||
web/app/components/app/app-publisher/ @iamjoel @zxhlyh
|
/web/app/components/app/app-publisher/ @iamjoel @zxhlyh
|
||||||
|
|
||||||
# Frontend - WebApp - Chat
|
# Frontend - WebApp - Chat
|
||||||
web/app/components/base/chat/ @iamjoel @zxhlyh
|
/web/app/components/base/chat/ @iamjoel @zxhlyh
|
||||||
|
|
||||||
# Frontend - WebApp - Completion
|
# Frontend - WebApp - Completion
|
||||||
web/app/components/share/text-generation/ @iamjoel @zxhlyh
|
/web/app/components/share/text-generation/ @iamjoel @zxhlyh
|
||||||
|
|
||||||
# Frontend - App - List and Creation
|
# Frontend - App - List and Creation
|
||||||
web/app/components/apps/ @JzoNgKVO @iamjoel
|
/web/app/components/apps/ @JzoNgKVO @iamjoel
|
||||||
web/app/components/app/create-app-dialog/ @JzoNgKVO @iamjoel
|
/web/app/components/app/create-app-dialog/ @JzoNgKVO @iamjoel
|
||||||
web/app/components/app/create-app-modal/ @JzoNgKVO @iamjoel
|
/web/app/components/app/create-app-modal/ @JzoNgKVO @iamjoel
|
||||||
web/app/components/app/create-from-dsl-modal/ @JzoNgKVO @iamjoel
|
/web/app/components/app/create-from-dsl-modal/ @JzoNgKVO @iamjoel
|
||||||
|
|
||||||
# Frontend - App - API Documentation
|
# Frontend - App - API Documentation
|
||||||
web/app/components/develop/ @JzoNgKVO @iamjoel
|
/web/app/components/develop/ @JzoNgKVO @iamjoel
|
||||||
|
|
||||||
# Frontend - App - Logs and Annotations
|
# Frontend - App - Logs and Annotations
|
||||||
web/app/components/app/workflow-log/ @JzoNgKVO @iamjoel
|
/web/app/components/app/workflow-log/ @JzoNgKVO @iamjoel
|
||||||
web/app/components/app/log/ @JzoNgKVO @iamjoel
|
/web/app/components/app/log/ @JzoNgKVO @iamjoel
|
||||||
web/app/components/app/log-annotation/ @JzoNgKVO @iamjoel
|
/web/app/components/app/log-annotation/ @JzoNgKVO @iamjoel
|
||||||
web/app/components/app/annotation/ @JzoNgKVO @iamjoel
|
/web/app/components/app/annotation/ @JzoNgKVO @iamjoel
|
||||||
|
|
||||||
# Frontend - App - Monitoring
|
# Frontend - App - Monitoring
|
||||||
web/app/(commonLayout)/app/(appDetailLayout)/\[appId\]/overview/ @JzoNgKVO @iamjoel
|
/web/app/(commonLayout)/app/(appDetailLayout)/\[appId\]/overview/ @JzoNgKVO @iamjoel
|
||||||
web/app/components/app/overview/ @JzoNgKVO @iamjoel
|
/web/app/components/app/overview/ @JzoNgKVO @iamjoel
|
||||||
|
|
||||||
# Frontend - App - Settings
|
# Frontend - App - Settings
|
||||||
web/app/components/app-sidebar/ @JzoNgKVO @iamjoel
|
/web/app/components/app-sidebar/ @JzoNgKVO @iamjoel
|
||||||
|
|
||||||
# Frontend - RAG - Hit Testing
|
# Frontend - RAG - Hit Testing
|
||||||
web/app/components/datasets/hit-testing/ @JzoNgKVO @iamjoel
|
/web/app/components/datasets/hit-testing/ @JzoNgKVO @iamjoel
|
||||||
|
|
||||||
# Frontend - RAG - List and Creation
|
# Frontend - RAG - List and Creation
|
||||||
web/app/components/datasets/list/ @iamjoel @WTW0313
|
/web/app/components/datasets/list/ @iamjoel @WTW0313
|
||||||
web/app/components/datasets/create/ @iamjoel @WTW0313
|
/web/app/components/datasets/create/ @iamjoel @WTW0313
|
||||||
web/app/components/datasets/create-from-pipeline/ @iamjoel @WTW0313
|
/web/app/components/datasets/create-from-pipeline/ @iamjoel @WTW0313
|
||||||
web/app/components/datasets/external-knowledge-base/ @iamjoel @WTW0313
|
/web/app/components/datasets/external-knowledge-base/ @iamjoel @WTW0313
|
||||||
|
|
||||||
# Frontend - RAG - Orchestration (general rule first, specific rules below override)
|
# Frontend - RAG - Orchestration (general rule first, specific rules below override)
|
||||||
web/app/components/rag-pipeline/ @iamjoel @WTW0313
|
/web/app/components/rag-pipeline/ @iamjoel @WTW0313
|
||||||
web/app/components/rag-pipeline/components/rag-pipeline-main.tsx @iamjoel @zxhlyh
|
/web/app/components/rag-pipeline/components/rag-pipeline-main.tsx @iamjoel @zxhlyh
|
||||||
web/app/components/rag-pipeline/store/ @iamjoel @zxhlyh
|
/web/app/components/rag-pipeline/store/ @iamjoel @zxhlyh
|
||||||
|
|
||||||
# Frontend - RAG - Documents List
|
# Frontend - RAG - Documents List
|
||||||
web/app/components/datasets/documents/list.tsx @iamjoel @WTW0313
|
/web/app/components/datasets/documents/list.tsx @iamjoel @WTW0313
|
||||||
web/app/components/datasets/documents/create-from-pipeline/ @iamjoel @WTW0313
|
/web/app/components/datasets/documents/create-from-pipeline/ @iamjoel @WTW0313
|
||||||
|
|
||||||
# Frontend - RAG - Segments List
|
# Frontend - RAG - Segments List
|
||||||
web/app/components/datasets/documents/detail/ @iamjoel @WTW0313
|
/web/app/components/datasets/documents/detail/ @iamjoel @WTW0313
|
||||||
|
|
||||||
# Frontend - RAG - Settings
|
# Frontend - RAG - Settings
|
||||||
web/app/components/datasets/settings/ @iamjoel @WTW0313
|
/web/app/components/datasets/settings/ @iamjoel @WTW0313
|
||||||
|
|
||||||
# Frontend - Ecosystem - Plugins
|
# Frontend - Ecosystem - Plugins
|
||||||
web/app/components/plugins/ @iamjoel @zhsama
|
/web/app/components/plugins/ @iamjoel @zhsama
|
||||||
|
|
||||||
# Frontend - Ecosystem - Tools
|
# Frontend - Ecosystem - Tools
|
||||||
web/app/components/tools/ @iamjoel @Yessenia-d
|
/web/app/components/tools/ @iamjoel @Yessenia-d
|
||||||
|
|
||||||
# Frontend - Ecosystem - MarketPlace
|
# Frontend - Ecosystem - MarketPlace
|
||||||
web/app/components/plugins/marketplace/ @iamjoel @Yessenia-d
|
/web/app/components/plugins/marketplace/ @iamjoel @Yessenia-d
|
||||||
|
|
||||||
# Frontend - Login and Registration
|
# Frontend - Login and Registration
|
||||||
web/app/signin/ @douxc @iamjoel
|
/web/app/signin/ @douxc @iamjoel
|
||||||
web/app/signup/ @douxc @iamjoel
|
/web/app/signup/ @douxc @iamjoel
|
||||||
web/app/reset-password/ @douxc @iamjoel
|
/web/app/reset-password/ @douxc @iamjoel
|
||||||
web/app/install/ @douxc @iamjoel
|
/web/app/install/ @douxc @iamjoel
|
||||||
web/app/init/ @douxc @iamjoel
|
/web/app/init/ @douxc @iamjoel
|
||||||
web/app/forgot-password/ @douxc @iamjoel
|
/web/app/forgot-password/ @douxc @iamjoel
|
||||||
web/app/account/ @douxc @iamjoel
|
/web/app/account/ @douxc @iamjoel
|
||||||
|
|
||||||
# Frontend - Service Authentication
|
# Frontend - Service Authentication
|
||||||
web/service/base.ts @douxc @iamjoel
|
/web/service/base.ts @douxc @iamjoel
|
||||||
|
|
||||||
# Frontend - WebApp Authentication and Access Control
|
# Frontend - WebApp Authentication and Access Control
|
||||||
web/app/(shareLayout)/components/ @douxc @iamjoel
|
/web/app/(shareLayout)/components/ @douxc @iamjoel
|
||||||
web/app/(shareLayout)/webapp-signin/ @douxc @iamjoel
|
/web/app/(shareLayout)/webapp-signin/ @douxc @iamjoel
|
||||||
web/app/(shareLayout)/webapp-reset-password/ @douxc @iamjoel
|
/web/app/(shareLayout)/webapp-reset-password/ @douxc @iamjoel
|
||||||
web/app/components/app/app-access-control/ @douxc @iamjoel
|
/web/app/components/app/app-access-control/ @douxc @iamjoel
|
||||||
|
|
||||||
# Frontend - Explore Page
|
# Frontend - Explore Page
|
||||||
web/app/components/explore/ @CodingOnStar @iamjoel
|
/web/app/components/explore/ @CodingOnStar @iamjoel
|
||||||
|
|
||||||
# Frontend - Personal Settings
|
# Frontend - Personal Settings
|
||||||
web/app/components/header/account-setting/ @CodingOnStar @iamjoel
|
/web/app/components/header/account-setting/ @CodingOnStar @iamjoel
|
||||||
web/app/components/header/account-dropdown/ @CodingOnStar @iamjoel
|
/web/app/components/header/account-dropdown/ @CodingOnStar @iamjoel
|
||||||
|
|
||||||
# Frontend - Analytics
|
# Frontend - Analytics
|
||||||
web/app/components/base/ga/ @CodingOnStar @iamjoel
|
/web/app/components/base/ga/ @CodingOnStar @iamjoel
|
||||||
|
|
||||||
# Frontend - Base Components
|
# Frontend - Base Components
|
||||||
web/app/components/base/ @iamjoel @zxhlyh
|
/web/app/components/base/ @iamjoel @zxhlyh
|
||||||
|
|
||||||
# Frontend - Utils and Hooks
|
# Frontend - Utils and Hooks
|
||||||
web/utils/classnames.ts @iamjoel @zxhlyh
|
/web/utils/classnames.ts @iamjoel @zxhlyh
|
||||||
web/utils/time.ts @iamjoel @zxhlyh
|
/web/utils/time.ts @iamjoel @zxhlyh
|
||||||
web/utils/format.ts @iamjoel @zxhlyh
|
/web/utils/format.ts @iamjoel @zxhlyh
|
||||||
web/utils/clipboard.ts @iamjoel @zxhlyh
|
/web/utils/clipboard.ts @iamjoel @zxhlyh
|
||||||
web/hooks/use-document-title.ts @iamjoel @zxhlyh
|
/web/hooks/use-document-title.ts @iamjoel @zxhlyh
|
||||||
|
|
||||||
# Frontend - Billing and Education
|
# Frontend - Billing and Education
|
||||||
web/app/components/billing/ @iamjoel @zxhlyh
|
/web/app/components/billing/ @iamjoel @zxhlyh
|
||||||
web/app/education-apply/ @iamjoel @zxhlyh
|
/web/app/education-apply/ @iamjoel @zxhlyh
|
||||||
|
|
||||||
# Frontend - Workspace
|
# Frontend - Workspace
|
||||||
web/app/components/header/account-dropdown/workplace-selector/ @iamjoel @zxhlyh
|
/web/app/components/header/account-dropdown/workplace-selector/ @iamjoel @zxhlyh
|
||||||
|
|
||||||
|
# Docker
|
||||||
|
/docker/* @laipz8200
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,6 @@
|
||||||
name: "✨ Refactor"
|
name: "✨ Refactor or Chore"
|
||||||
description: Refactor existing code for improved readability and maintainability.
|
description: Refactor existing code or perform maintenance chores to improve readability and reliability.
|
||||||
title: "[Chore/Refactor] "
|
title: "[Refactor/Chore] "
|
||||||
labels:
|
|
||||||
- refactor
|
|
||||||
body:
|
body:
|
||||||
- type: checkboxes
|
- type: checkboxes
|
||||||
attributes:
|
attributes:
|
||||||
|
|
@ -11,7 +9,7 @@ body:
|
||||||
options:
|
options:
|
||||||
- label: I have read the [Contributing Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) and [Language Policy](https://github.com/langgenius/dify/issues/1542).
|
- label: I have read the [Contributing Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) and [Language Policy](https://github.com/langgenius/dify/issues/1542).
|
||||||
required: true
|
required: true
|
||||||
- label: This is only for refactoring, if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general).
|
- label: This is only for refactors or chores; if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general).
|
||||||
required: true
|
required: true
|
||||||
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
|
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
|
||||||
required: true
|
required: true
|
||||||
|
|
@ -25,14 +23,14 @@ body:
|
||||||
id: description
|
id: description
|
||||||
attributes:
|
attributes:
|
||||||
label: Description
|
label: Description
|
||||||
placeholder: "Describe the refactor you are proposing."
|
placeholder: "Describe the refactor or chore you are proposing."
|
||||||
validations:
|
validations:
|
||||||
required: true
|
required: true
|
||||||
- type: textarea
|
- type: textarea
|
||||||
id: motivation
|
id: motivation
|
||||||
attributes:
|
attributes:
|
||||||
label: Motivation
|
label: Motivation
|
||||||
placeholder: "Explain why this refactor is necessary."
|
placeholder: "Explain why this refactor or chore is necessary."
|
||||||
validations:
|
validations:
|
||||||
required: false
|
required: false
|
||||||
- type: textarea
|
- type: textarea
|
||||||
|
|
|
||||||
|
|
@ -1,13 +0,0 @@
|
||||||
name: "👾 Tracker"
|
|
||||||
description: For inner usages, please do not use this template.
|
|
||||||
title: "[Tracker] "
|
|
||||||
labels:
|
|
||||||
- tracker
|
|
||||||
body:
|
|
||||||
- type: textarea
|
|
||||||
id: content
|
|
||||||
attributes:
|
|
||||||
label: Blockers
|
|
||||||
placeholder: "- [ ] ..."
|
|
||||||
validations:
|
|
||||||
required: true
|
|
||||||
|
|
@ -1,12 +0,0 @@
|
||||||
# Copilot Instructions
|
|
||||||
|
|
||||||
GitHub Copilot must follow the unified frontend testing requirements documented in `web/testing/testing.md`.
|
|
||||||
|
|
||||||
Key reminders:
|
|
||||||
|
|
||||||
- Generate tests using the mandated tech stack, naming, and code style (AAA pattern, `fireEvent`, descriptive test names, cleans up mocks).
|
|
||||||
- Cover rendering, prop combinations, and edge cases by default; extend coverage for hooks, routing, async flows, and domain-specific components when applicable.
|
|
||||||
- Target >95% line and branch coverage and 100% function/statement coverage.
|
|
||||||
- Apply the project's mocking conventions for i18n, toast notifications, and Next.js utilities.
|
|
||||||
|
|
||||||
Any suggestions from Copilot that conflict with `web/testing/testing.md` should be revised before acceptance.
|
|
||||||
|
|
@ -71,18 +71,18 @@ jobs:
|
||||||
run: |
|
run: |
|
||||||
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
|
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
|
||||||
|
|
||||||
- name: Run Workflow
|
- name: Run API Tests
|
||||||
run: uv run --project api bash dev/pytest/pytest_workflow.sh
|
env:
|
||||||
|
STORAGE_TYPE: opendal
|
||||||
- name: Run Tool
|
OPENDAL_SCHEME: fs
|
||||||
run: uv run --project api bash dev/pytest/pytest_tools.sh
|
OPENDAL_FS_ROOT: /tmp/dify-storage
|
||||||
|
|
||||||
- name: Run TestContainers
|
|
||||||
run: uv run --project api bash dev/pytest/pytest_testcontainers.sh
|
|
||||||
|
|
||||||
- name: Run Unit tests
|
|
||||||
run: |
|
run: |
|
||||||
uv run --project api bash dev/pytest/pytest_unit_tests.sh
|
uv run --project api pytest \
|
||||||
|
--timeout "${PYTEST_TIMEOUT:-180}" \
|
||||||
|
api/tests/integration_tests/workflow \
|
||||||
|
api/tests/integration_tests/tools \
|
||||||
|
api/tests/test_containers_integration_tests \
|
||||||
|
api/tests/unit_tests
|
||||||
|
|
||||||
- name: Coverage Summary
|
- name: Coverage Summary
|
||||||
run: |
|
run: |
|
||||||
|
|
@ -93,5 +93,12 @@ jobs:
|
||||||
# Create a detailed coverage summary
|
# Create a detailed coverage summary
|
||||||
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
|
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
|
||||||
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
|
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
|
||||||
uv run --project api coverage report --format=markdown >> $GITHUB_STEP_SUMMARY
|
{
|
||||||
|
echo ""
|
||||||
|
echo "<details><summary>File-level coverage (click to expand)</summary>"
|
||||||
|
echo ""
|
||||||
|
echo '```'
|
||||||
|
uv run --project api coverage report -m
|
||||||
|
echo '```'
|
||||||
|
echo "</details>"
|
||||||
|
} >> $GITHUB_STEP_SUMMARY
|
||||||
|
|
|
||||||
|
|
@ -14,10 +14,27 @@ jobs:
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
# Use uv to ensure we have the same ruff version in CI and locally.
|
- name: Check Docker Compose inputs
|
||||||
- uses: astral-sh/setup-uv@v6
|
id: docker-compose-changes
|
||||||
|
uses: tj-actions/changed-files@v46
|
||||||
|
with:
|
||||||
|
files: |
|
||||||
|
docker/generate_docker_compose
|
||||||
|
docker/.env.example
|
||||||
|
docker/docker-compose-template.yaml
|
||||||
|
docker/docker-compose.yaml
|
||||||
|
- uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: "3.11"
|
python-version: "3.11"
|
||||||
|
|
||||||
|
- uses: astral-sh/setup-uv@v6
|
||||||
|
|
||||||
|
- name: Generate Docker Compose
|
||||||
|
if: steps.docker-compose-changes.outputs.any_changed == 'true'
|
||||||
|
run: |
|
||||||
|
cd docker
|
||||||
|
./generate_docker_compose
|
||||||
|
|
||||||
- run: |
|
- run: |
|
||||||
cd api
|
cd api
|
||||||
uv sync --dev
|
uv sync --dev
|
||||||
|
|
@ -35,10 +52,11 @@ jobs:
|
||||||
|
|
||||||
- name: ast-grep
|
- name: ast-grep
|
||||||
run: |
|
run: |
|
||||||
uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all
|
# ast-grep exits 1 if no matches are found; allow idempotent runs.
|
||||||
uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all
|
uvx --from ast-grep-cli ast-grep --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all || true
|
||||||
uvx --from ast-grep-cli sg -p '$A = db.Column($$$B)' -r '$A = mapped_column($$$B)' -l py --update-all
|
uvx --from ast-grep-cli ast-grep --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all || true
|
||||||
uvx --from ast-grep-cli sg -p '$A : $T = db.Column($$$B)' -r '$A : $T = mapped_column($$$B)' -l py --update-all
|
uvx --from ast-grep-cli ast-grep -p '$A = db.Column($$$B)' -r '$A = mapped_column($$$B)' -l py --update-all || true
|
||||||
|
uvx --from ast-grep-cli ast-grep -p '$A : $T = db.Column($$$B)' -r '$A : $T = mapped_column($$$B)' -l py --update-all || true
|
||||||
# Convert Optional[T] to T | None (ignoring quoted types)
|
# Convert Optional[T] to T | None (ignoring quoted types)
|
||||||
cat > /tmp/optional-rule.yml << 'EOF'
|
cat > /tmp/optional-rule.yml << 'EOF'
|
||||||
id: convert-optional-to-union
|
id: convert-optional-to-union
|
||||||
|
|
@ -56,35 +74,14 @@ jobs:
|
||||||
pattern: $T
|
pattern: $T
|
||||||
fix: $T | None
|
fix: $T | None
|
||||||
EOF
|
EOF
|
||||||
uvx --from ast-grep-cli sg scan --inline-rules "$(cat /tmp/optional-rule.yml)" --update-all
|
uvx --from ast-grep-cli ast-grep scan . --inline-rules "$(cat /tmp/optional-rule.yml)" --update-all
|
||||||
# Fix forward references that were incorrectly converted (Python doesn't support "Type" | None syntax)
|
# Fix forward references that were incorrectly converted (Python doesn't support "Type" | None syntax)
|
||||||
find . -name "*.py" -type f -exec sed -i.bak -E 's/"([^"]+)" \| None/Optional["\1"]/g; s/'"'"'([^'"'"']+)'"'"' \| None/Optional['"'"'\1'"'"']/g' {} \;
|
find . -name "*.py" -type f -exec sed -i.bak -E 's/"([^"]+)" \| None/Optional["\1"]/g; s/'"'"'([^'"'"']+)'"'"' \| None/Optional['"'"'\1'"'"']/g' {} \;
|
||||||
find . -name "*.py.bak" -type f -delete
|
find . -name "*.py.bak" -type f -delete
|
||||||
|
|
||||||
|
# mdformat breaks YAML front matter in markdown files. Add --exclude for directories containing YAML front matter.
|
||||||
- name: mdformat
|
- name: mdformat
|
||||||
run: |
|
run: |
|
||||||
uvx mdformat .
|
uvx --python 3.13 mdformat . --exclude ".claude/skills/**/SKILL.md"
|
||||||
|
|
||||||
- name: Install pnpm
|
|
||||||
uses: pnpm/action-setup@v4
|
|
||||||
with:
|
|
||||||
package_json_file: web/package.json
|
|
||||||
run_install: false
|
|
||||||
|
|
||||||
- name: Setup NodeJS
|
|
||||||
uses: actions/setup-node@v4
|
|
||||||
with:
|
|
||||||
node-version: 22
|
|
||||||
cache: pnpm
|
|
||||||
cache-dependency-path: ./web/package.json
|
|
||||||
|
|
||||||
- name: Web dependencies
|
|
||||||
working-directory: ./web
|
|
||||||
run: pnpm install --frozen-lockfile
|
|
||||||
|
|
||||||
- name: oxlint
|
|
||||||
working-directory: ./web
|
|
||||||
run: |
|
|
||||||
pnpx oxlint --fix
|
|
||||||
|
|
||||||
- uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27
|
- uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,21 @@
|
||||||
|
name: Semantic Pull Request
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
types:
|
||||||
|
- opened
|
||||||
|
- edited
|
||||||
|
- reopened
|
||||||
|
- synchronize
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
lint:
|
||||||
|
name: Validate PR title
|
||||||
|
permissions:
|
||||||
|
pull-requests: read
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Check title
|
||||||
|
uses: amannn/action-semantic-pull-request@v6.1.1
|
||||||
|
env:
|
||||||
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
|
@ -90,7 +90,7 @@ jobs:
|
||||||
with:
|
with:
|
||||||
node-version: 22
|
node-version: 22
|
||||||
cache: pnpm
|
cache: pnpm
|
||||||
cache-dependency-path: ./web/package.json
|
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||||
|
|
||||||
- name: Web dependencies
|
- name: Web dependencies
|
||||||
if: steps.changed-files.outputs.any_changed == 'true'
|
if: steps.changed-files.outputs.any_changed == 'true'
|
||||||
|
|
@ -106,37 +106,7 @@ jobs:
|
||||||
- name: Web type check
|
- name: Web type check
|
||||||
if: steps.changed-files.outputs.any_changed == 'true'
|
if: steps.changed-files.outputs.any_changed == 'true'
|
||||||
working-directory: ./web
|
working-directory: ./web
|
||||||
run: pnpm run type-check
|
run: pnpm run type-check:tsgo
|
||||||
|
|
||||||
docker-compose-template:
|
|
||||||
name: Docker Compose Template
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
with:
|
|
||||||
persist-credentials: false
|
|
||||||
|
|
||||||
- name: Check changed files
|
|
||||||
id: changed-files
|
|
||||||
uses: tj-actions/changed-files@v46
|
|
||||||
with:
|
|
||||||
files: |
|
|
||||||
docker/generate_docker_compose
|
|
||||||
docker/.env.example
|
|
||||||
docker/docker-compose-template.yaml
|
|
||||||
docker/docker-compose.yaml
|
|
||||||
|
|
||||||
- name: Generate Docker Compose
|
|
||||||
if: steps.changed-files.outputs.any_changed == 'true'
|
|
||||||
run: |
|
|
||||||
cd docker
|
|
||||||
./generate_docker_compose
|
|
||||||
|
|
||||||
- name: Check for changes
|
|
||||||
if: steps.changed-files.outputs.any_changed == 'true'
|
|
||||||
run: git diff --exit-code
|
|
||||||
|
|
||||||
superlinter:
|
superlinter:
|
||||||
name: SuperLinter
|
name: SuperLinter
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
name: Check i18n Files and Create PR
|
name: Translate i18n Files Based on English
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
|
|
@ -55,7 +55,7 @@ jobs:
|
||||||
with:
|
with:
|
||||||
node-version: 'lts/*'
|
node-version: 'lts/*'
|
||||||
cache: pnpm
|
cache: pnpm
|
||||||
cache-dependency-path: ./web/package.json
|
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
if: env.FILES_CHANGED == 'true'
|
if: env.FILES_CHANGED == 'true'
|
||||||
|
|
@ -67,25 +67,19 @@ jobs:
|
||||||
working-directory: ./web
|
working-directory: ./web
|
||||||
run: pnpm run auto-gen-i18n ${{ env.FILE_ARGS }}
|
run: pnpm run auto-gen-i18n ${{ env.FILE_ARGS }}
|
||||||
|
|
||||||
- name: Generate i18n type definitions
|
|
||||||
if: env.FILES_CHANGED == 'true'
|
|
||||||
working-directory: ./web
|
|
||||||
run: pnpm run gen:i18n-types
|
|
||||||
|
|
||||||
- name: Create Pull Request
|
- name: Create Pull Request
|
||||||
if: env.FILES_CHANGED == 'true'
|
if: env.FILES_CHANGED == 'true'
|
||||||
uses: peter-evans/create-pull-request@v6
|
uses: peter-evans/create-pull-request@v6
|
||||||
with:
|
with:
|
||||||
token: ${{ secrets.GITHUB_TOKEN }}
|
token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
commit-message: 'chore(i18n): update translations based on en-US changes'
|
commit-message: 'chore(i18n): update translations based on en-US changes'
|
||||||
title: 'chore(i18n): translate i18n files and update type definitions'
|
title: 'chore(i18n): translate i18n files based on en-US changes'
|
||||||
body: |
|
body: |
|
||||||
This PR was automatically created to update i18n files and TypeScript type definitions based on changes in en-US locale.
|
This PR was automatically created to update i18n translation files based on changes in en-US locale.
|
||||||
|
|
||||||
**Triggered by:** ${{ github.sha }}
|
**Triggered by:** ${{ github.sha }}
|
||||||
|
|
||||||
**Changes included:**
|
**Changes included:**
|
||||||
- Updated translation files for all locales
|
- Updated translation files for all locales
|
||||||
- Regenerated TypeScript type definitions for type safety
|
|
||||||
branch: chore/automated-i18n-updates-${{ github.sha }}
|
branch: chore/automated-i18n-updates-${{ github.sha }}
|
||||||
delete-branch: true
|
delete-branch: true
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@ jobs:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
defaults:
|
defaults:
|
||||||
run:
|
run:
|
||||||
|
shell: bash
|
||||||
working-directory: ./web
|
working-directory: ./web
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
|
|
@ -21,14 +22,7 @@ jobs:
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Check changed files
|
|
||||||
id: changed-files
|
|
||||||
uses: tj-actions/changed-files@v46
|
|
||||||
with:
|
|
||||||
files: web/**
|
|
||||||
|
|
||||||
- name: Install pnpm
|
- name: Install pnpm
|
||||||
if: steps.changed-files.outputs.any_changed == 'true'
|
|
||||||
uses: pnpm/action-setup@v4
|
uses: pnpm/action-setup@v4
|
||||||
with:
|
with:
|
||||||
package_json_file: web/package.json
|
package_json_file: web/package.json
|
||||||
|
|
@ -36,23 +30,339 @@ jobs:
|
||||||
|
|
||||||
- name: Setup Node.js
|
- name: Setup Node.js
|
||||||
uses: actions/setup-node@v4
|
uses: actions/setup-node@v4
|
||||||
if: steps.changed-files.outputs.any_changed == 'true'
|
|
||||||
with:
|
with:
|
||||||
node-version: 22
|
node-version: 22
|
||||||
cache: pnpm
|
cache: pnpm
|
||||||
cache-dependency-path: ./web/package.json
|
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
if: steps.changed-files.outputs.any_changed == 'true'
|
|
||||||
working-directory: ./web
|
|
||||||
run: pnpm install --frozen-lockfile
|
run: pnpm install --frozen-lockfile
|
||||||
|
|
||||||
- name: Check i18n types synchronization
|
|
||||||
if: steps.changed-files.outputs.any_changed == 'true'
|
|
||||||
working-directory: ./web
|
|
||||||
run: pnpm run check:i18n-types
|
|
||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
if: steps.changed-files.outputs.any_changed == 'true'
|
run: pnpm test:coverage
|
||||||
working-directory: ./web
|
|
||||||
run: pnpm test
|
- name: Coverage Summary
|
||||||
|
if: always()
|
||||||
|
id: coverage-summary
|
||||||
|
run: |
|
||||||
|
set -eo pipefail
|
||||||
|
|
||||||
|
COVERAGE_FILE="coverage/coverage-final.json"
|
||||||
|
COVERAGE_SUMMARY_FILE="coverage/coverage-summary.json"
|
||||||
|
|
||||||
|
if [ ! -f "$COVERAGE_FILE" ] && [ ! -f "$COVERAGE_SUMMARY_FILE" ]; then
|
||||||
|
echo "has_coverage=false" >> "$GITHUB_OUTPUT"
|
||||||
|
echo "### 🚨 Test Coverage Report :test_tube:" >> "$GITHUB_STEP_SUMMARY"
|
||||||
|
echo "Coverage data not found. Ensure Vitest runs with coverage enabled." >> "$GITHUB_STEP_SUMMARY"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "has_coverage=true" >> "$GITHUB_OUTPUT"
|
||||||
|
|
||||||
|
node <<'NODE' >> "$GITHUB_STEP_SUMMARY"
|
||||||
|
const fs = require('fs');
|
||||||
|
const path = require('path');
|
||||||
|
let libCoverage = null;
|
||||||
|
|
||||||
|
try {
|
||||||
|
libCoverage = require('istanbul-lib-coverage');
|
||||||
|
} catch (error) {
|
||||||
|
libCoverage = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
const summaryPath = path.join('coverage', 'coverage-summary.json');
|
||||||
|
const finalPath = path.join('coverage', 'coverage-final.json');
|
||||||
|
|
||||||
|
const hasSummary = fs.existsSync(summaryPath);
|
||||||
|
const hasFinal = fs.existsSync(finalPath);
|
||||||
|
|
||||||
|
if (!hasSummary && !hasFinal) {
|
||||||
|
console.log('### Test Coverage Summary :test_tube:');
|
||||||
|
console.log('');
|
||||||
|
console.log('No coverage data found.');
|
||||||
|
process.exit(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
const summary = hasSummary
|
||||||
|
? JSON.parse(fs.readFileSync(summaryPath, 'utf8'))
|
||||||
|
: null;
|
||||||
|
const coverage = hasFinal
|
||||||
|
? JSON.parse(fs.readFileSync(finalPath, 'utf8'))
|
||||||
|
: null;
|
||||||
|
|
||||||
|
const getLineCoverageFromStatements = (statementMap, statementHits) => {
|
||||||
|
const lineHits = {};
|
||||||
|
|
||||||
|
if (!statementMap || !statementHits) {
|
||||||
|
return lineHits;
|
||||||
|
}
|
||||||
|
|
||||||
|
Object.entries(statementMap).forEach(([key, statement]) => {
|
||||||
|
const line = statement?.start?.line;
|
||||||
|
if (!line) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const hits = statementHits[key] ?? 0;
|
||||||
|
const previous = lineHits[line];
|
||||||
|
lineHits[line] = previous === undefined ? hits : Math.max(previous, hits);
|
||||||
|
});
|
||||||
|
|
||||||
|
return lineHits;
|
||||||
|
};
|
||||||
|
|
||||||
|
const getFileCoverage = (entry) => (
|
||||||
|
libCoverage ? libCoverage.createFileCoverage(entry) : null
|
||||||
|
);
|
||||||
|
|
||||||
|
const getLineHits = (entry, fileCoverage) => {
|
||||||
|
const lineHits = entry.l ?? {};
|
||||||
|
if (Object.keys(lineHits).length > 0) {
|
||||||
|
return lineHits;
|
||||||
|
}
|
||||||
|
if (fileCoverage) {
|
||||||
|
return fileCoverage.getLineCoverage();
|
||||||
|
}
|
||||||
|
return getLineCoverageFromStatements(entry.statementMap ?? {}, entry.s ?? {});
|
||||||
|
};
|
||||||
|
|
||||||
|
const getUncoveredLines = (entry, fileCoverage, lineHits) => {
|
||||||
|
if (lineHits && Object.keys(lineHits).length > 0) {
|
||||||
|
return Object.entries(lineHits)
|
||||||
|
.filter(([, count]) => count === 0)
|
||||||
|
.map(([line]) => Number(line))
|
||||||
|
.sort((a, b) => a - b);
|
||||||
|
}
|
||||||
|
if (fileCoverage) {
|
||||||
|
return fileCoverage.getUncoveredLines();
|
||||||
|
}
|
||||||
|
return [];
|
||||||
|
};
|
||||||
|
|
||||||
|
const totals = {
|
||||||
|
lines: { covered: 0, total: 0 },
|
||||||
|
statements: { covered: 0, total: 0 },
|
||||||
|
branches: { covered: 0, total: 0 },
|
||||||
|
functions: { covered: 0, total: 0 },
|
||||||
|
};
|
||||||
|
const fileSummaries = [];
|
||||||
|
|
||||||
|
if (summary) {
|
||||||
|
const totalEntry = summary.total ?? {};
|
||||||
|
['lines', 'statements', 'branches', 'functions'].forEach((key) => {
|
||||||
|
if (totalEntry[key]) {
|
||||||
|
totals[key].covered = totalEntry[key].covered ?? 0;
|
||||||
|
totals[key].total = totalEntry[key].total ?? 0;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
Object.entries(summary)
|
||||||
|
.filter(([file]) => file !== 'total')
|
||||||
|
.forEach(([file, data]) => {
|
||||||
|
fileSummaries.push({
|
||||||
|
file,
|
||||||
|
pct: data.lines?.pct ?? data.statements?.pct ?? 0,
|
||||||
|
lines: {
|
||||||
|
covered: data.lines?.covered ?? 0,
|
||||||
|
total: data.lines?.total ?? 0,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
});
|
||||||
|
} else if (coverage) {
|
||||||
|
Object.entries(coverage).forEach(([file, entry]) => {
|
||||||
|
const fileCoverage = getFileCoverage(entry);
|
||||||
|
const lineHits = getLineHits(entry, fileCoverage);
|
||||||
|
const statementHits = entry.s ?? {};
|
||||||
|
const branchHits = entry.b ?? {};
|
||||||
|
const functionHits = entry.f ?? {};
|
||||||
|
|
||||||
|
const lineTotal = Object.keys(lineHits).length;
|
||||||
|
const lineCovered = Object.values(lineHits).filter((n) => n > 0).length;
|
||||||
|
|
||||||
|
const statementTotal = Object.keys(statementHits).length;
|
||||||
|
const statementCovered = Object.values(statementHits).filter((n) => n > 0).length;
|
||||||
|
|
||||||
|
const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0);
|
||||||
|
const branchCovered = Object.values(branchHits).reduce(
|
||||||
|
(acc, branches) => acc + branches.filter((n) => n > 0).length,
|
||||||
|
0,
|
||||||
|
);
|
||||||
|
|
||||||
|
const functionTotal = Object.keys(functionHits).length;
|
||||||
|
const functionCovered = Object.values(functionHits).filter((n) => n > 0).length;
|
||||||
|
|
||||||
|
totals.lines.total += lineTotal;
|
||||||
|
totals.lines.covered += lineCovered;
|
||||||
|
totals.statements.total += statementTotal;
|
||||||
|
totals.statements.covered += statementCovered;
|
||||||
|
totals.branches.total += branchTotal;
|
||||||
|
totals.branches.covered += branchCovered;
|
||||||
|
totals.functions.total += functionTotal;
|
||||||
|
totals.functions.covered += functionCovered;
|
||||||
|
|
||||||
|
const pct = (covered, tot) => (tot > 0 ? (covered / tot) * 100 : 0);
|
||||||
|
|
||||||
|
fileSummaries.push({
|
||||||
|
file,
|
||||||
|
pct: pct(lineCovered || statementCovered, lineTotal || statementTotal),
|
||||||
|
lines: {
|
||||||
|
covered: lineCovered || statementCovered,
|
||||||
|
total: lineTotal || statementTotal,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
const pct = (covered, tot) => (tot > 0 ? ((covered / tot) * 100).toFixed(2) : '0.00');
|
||||||
|
|
||||||
|
console.log('### Test Coverage Summary :test_tube:');
|
||||||
|
console.log('');
|
||||||
|
console.log('| Metric | Coverage | Covered / Total |');
|
||||||
|
console.log('|--------|----------|-----------------|');
|
||||||
|
console.log(`| Lines | ${pct(totals.lines.covered, totals.lines.total)}% | ${totals.lines.covered} / ${totals.lines.total} |`);
|
||||||
|
console.log(`| Statements | ${pct(totals.statements.covered, totals.statements.total)}% | ${totals.statements.covered} / ${totals.statements.total} |`);
|
||||||
|
console.log(`| Branches | ${pct(totals.branches.covered, totals.branches.total)}% | ${totals.branches.covered} / ${totals.branches.total} |`);
|
||||||
|
console.log(`| Functions | ${pct(totals.functions.covered, totals.functions.total)}% | ${totals.functions.covered} / ${totals.functions.total} |`);
|
||||||
|
|
||||||
|
console.log('');
|
||||||
|
console.log('<details><summary>File coverage (lowest lines first)</summary>');
|
||||||
|
console.log('');
|
||||||
|
console.log('```');
|
||||||
|
fileSummaries
|
||||||
|
.sort((a, b) => (a.pct - b.pct) || (b.lines.total - a.lines.total))
|
||||||
|
.slice(0, 25)
|
||||||
|
.forEach(({ file, pct, lines }) => {
|
||||||
|
console.log(`${pct.toFixed(2)}%\t${lines.covered}/${lines.total}\t${file}`);
|
||||||
|
});
|
||||||
|
console.log('```');
|
||||||
|
console.log('</details>');
|
||||||
|
|
||||||
|
if (coverage) {
|
||||||
|
const pctValue = (covered, tot) => {
|
||||||
|
if (tot === 0) {
|
||||||
|
return '0';
|
||||||
|
}
|
||||||
|
return ((covered / tot) * 100)
|
||||||
|
.toFixed(2)
|
||||||
|
.replace(/\.?0+$/, '');
|
||||||
|
};
|
||||||
|
|
||||||
|
const formatLineRanges = (lines) => {
|
||||||
|
if (lines.length === 0) {
|
||||||
|
return '';
|
||||||
|
}
|
||||||
|
const ranges = [];
|
||||||
|
let start = lines[0];
|
||||||
|
let end = lines[0];
|
||||||
|
|
||||||
|
for (let i = 1; i < lines.length; i += 1) {
|
||||||
|
const current = lines[i];
|
||||||
|
if (current === end + 1) {
|
||||||
|
end = current;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
ranges.push(start === end ? `${start}` : `${start}-${end}`);
|
||||||
|
start = current;
|
||||||
|
end = current;
|
||||||
|
}
|
||||||
|
ranges.push(start === end ? `${start}` : `${start}-${end}`);
|
||||||
|
return ranges.join(',');
|
||||||
|
};
|
||||||
|
|
||||||
|
const tableTotals = {
|
||||||
|
statements: { covered: 0, total: 0 },
|
||||||
|
branches: { covered: 0, total: 0 },
|
||||||
|
functions: { covered: 0, total: 0 },
|
||||||
|
lines: { covered: 0, total: 0 },
|
||||||
|
};
|
||||||
|
const tableRows = Object.entries(coverage)
|
||||||
|
.map(([file, entry]) => {
|
||||||
|
const fileCoverage = getFileCoverage(entry);
|
||||||
|
const lineHits = getLineHits(entry, fileCoverage);
|
||||||
|
const statementHits = entry.s ?? {};
|
||||||
|
const branchHits = entry.b ?? {};
|
||||||
|
const functionHits = entry.f ?? {};
|
||||||
|
|
||||||
|
const lineTotal = Object.keys(lineHits).length;
|
||||||
|
const lineCovered = Object.values(lineHits).filter((n) => n > 0).length;
|
||||||
|
const statementTotal = Object.keys(statementHits).length;
|
||||||
|
const statementCovered = Object.values(statementHits).filter((n) => n > 0).length;
|
||||||
|
const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0);
|
||||||
|
const branchCovered = Object.values(branchHits).reduce(
|
||||||
|
(acc, branches) => acc + branches.filter((n) => n > 0).length,
|
||||||
|
0,
|
||||||
|
);
|
||||||
|
const functionTotal = Object.keys(functionHits).length;
|
||||||
|
const functionCovered = Object.values(functionHits).filter((n) => n > 0).length;
|
||||||
|
|
||||||
|
tableTotals.lines.total += lineTotal;
|
||||||
|
tableTotals.lines.covered += lineCovered;
|
||||||
|
tableTotals.statements.total += statementTotal;
|
||||||
|
tableTotals.statements.covered += statementCovered;
|
||||||
|
tableTotals.branches.total += branchTotal;
|
||||||
|
tableTotals.branches.covered += branchCovered;
|
||||||
|
tableTotals.functions.total += functionTotal;
|
||||||
|
tableTotals.functions.covered += functionCovered;
|
||||||
|
|
||||||
|
const uncoveredLines = getUncoveredLines(entry, fileCoverage, lineHits);
|
||||||
|
|
||||||
|
const filePath = entry.path ?? file;
|
||||||
|
const relativePath = path.isAbsolute(filePath)
|
||||||
|
? path.relative(process.cwd(), filePath)
|
||||||
|
: filePath;
|
||||||
|
|
||||||
|
return {
|
||||||
|
file: relativePath || file,
|
||||||
|
statements: pctValue(statementCovered, statementTotal),
|
||||||
|
branches: pctValue(branchCovered, branchTotal),
|
||||||
|
functions: pctValue(functionCovered, functionTotal),
|
||||||
|
lines: pctValue(lineCovered, lineTotal),
|
||||||
|
uncovered: formatLineRanges(uncoveredLines),
|
||||||
|
};
|
||||||
|
})
|
||||||
|
.sort((a, b) => a.file.localeCompare(b.file));
|
||||||
|
|
||||||
|
const columns = [
|
||||||
|
{ key: 'file', header: 'File', align: 'left' },
|
||||||
|
{ key: 'statements', header: '% Stmts', align: 'right' },
|
||||||
|
{ key: 'branches', header: '% Branch', align: 'right' },
|
||||||
|
{ key: 'functions', header: '% Funcs', align: 'right' },
|
||||||
|
{ key: 'lines', header: '% Lines', align: 'right' },
|
||||||
|
{ key: 'uncovered', header: 'Uncovered Line #s', align: 'left' },
|
||||||
|
];
|
||||||
|
|
||||||
|
const allFilesRow = {
|
||||||
|
file: 'All files',
|
||||||
|
statements: pctValue(tableTotals.statements.covered, tableTotals.statements.total),
|
||||||
|
branches: pctValue(tableTotals.branches.covered, tableTotals.branches.total),
|
||||||
|
functions: pctValue(tableTotals.functions.covered, tableTotals.functions.total),
|
||||||
|
lines: pctValue(tableTotals.lines.covered, tableTotals.lines.total),
|
||||||
|
uncovered: '',
|
||||||
|
};
|
||||||
|
|
||||||
|
const rowsForOutput = [allFilesRow, ...tableRows];
|
||||||
|
const formatRow = (row) => `| ${columns
|
||||||
|
.map(({ key }) => String(row[key] ?? ''))
|
||||||
|
.join(' | ')} |`;
|
||||||
|
const headerRow = `| ${columns.map(({ header }) => header).join(' | ')} |`;
|
||||||
|
const dividerRow = `| ${columns
|
||||||
|
.map(({ align }) => (align === 'right' ? '---:' : ':---'))
|
||||||
|
.join(' | ')} |`;
|
||||||
|
|
||||||
|
console.log('');
|
||||||
|
console.log('<details><summary>Vitest coverage table</summary>');
|
||||||
|
console.log('');
|
||||||
|
console.log(headerRow);
|
||||||
|
console.log(dividerRow);
|
||||||
|
rowsForOutput.forEach((row) => console.log(formatRow(row)));
|
||||||
|
console.log('</details>');
|
||||||
|
}
|
||||||
|
NODE
|
||||||
|
|
||||||
|
- name: Upload Coverage Artifact
|
||||||
|
if: steps.coverage-summary.outputs.has_coverage == 'true'
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: web-coverage-report
|
||||||
|
path: web/coverage
|
||||||
|
retention-days: 30
|
||||||
|
if-no-files-found: error
|
||||||
|
|
|
||||||
|
|
@ -139,7 +139,6 @@ pyrightconfig.json
|
||||||
.idea/'
|
.idea/'
|
||||||
|
|
||||||
.DS_Store
|
.DS_Store
|
||||||
web/.vscode/settings.json
|
|
||||||
|
|
||||||
# Intellij IDEA Files
|
# Intellij IDEA Files
|
||||||
.idea/*
|
.idea/*
|
||||||
|
|
@ -189,12 +188,14 @@ docker/volumes/matrixone/*
|
||||||
docker/volumes/mysql/*
|
docker/volumes/mysql/*
|
||||||
docker/volumes/seekdb/*
|
docker/volumes/seekdb/*
|
||||||
!docker/volumes/oceanbase/init.d
|
!docker/volumes/oceanbase/init.d
|
||||||
|
docker/volumes/iris/*
|
||||||
|
|
||||||
docker/nginx/conf.d/default.conf
|
docker/nginx/conf.d/default.conf
|
||||||
docker/nginx/ssl/*
|
docker/nginx/ssl/*
|
||||||
!docker/nginx/ssl/.gitkeep
|
!docker/nginx/ssl/.gitkeep
|
||||||
docker/middleware.env
|
docker/middleware.env
|
||||||
docker/docker-compose.override.yaml
|
docker/docker-compose.override.yaml
|
||||||
|
docker/env-backup/*
|
||||||
|
|
||||||
sdks/python-client/build
|
sdks/python-client/build
|
||||||
sdks/python-client/dist
|
sdks/python-client/dist
|
||||||
|
|
@ -204,7 +205,6 @@ sdks/python-client/dify_client.egg-info
|
||||||
!.vscode/launch.json.template
|
!.vscode/launch.json.template
|
||||||
!.vscode/README.md
|
!.vscode/README.md
|
||||||
api/.vscode
|
api/.vscode
|
||||||
web/.vscode
|
|
||||||
# vscode Code History Extension
|
# vscode Code History Extension
|
||||||
.history
|
.history
|
||||||
|
|
||||||
|
|
@ -219,15 +219,6 @@ plugins.jsonl
|
||||||
# mise
|
# mise
|
||||||
mise.toml
|
mise.toml
|
||||||
|
|
||||||
# Next.js build output
|
|
||||||
.next/
|
|
||||||
|
|
||||||
# PWA generated files
|
|
||||||
web/public/sw.js
|
|
||||||
web/public/sw.js.map
|
|
||||||
web/public/workbox-*.js
|
|
||||||
web/public/workbox-*.js.map
|
|
||||||
web/public/fallback-*.js
|
|
||||||
|
|
||||||
# AI Assistant
|
# AI Assistant
|
||||||
.roo/
|
.roo/
|
||||||
|
|
|
||||||
|
|
@ -37,7 +37,7 @@
|
||||||
"-c",
|
"-c",
|
||||||
"1",
|
"1",
|
||||||
"-Q",
|
"-Q",
|
||||||
"dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor",
|
"dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention",
|
||||||
"--loglevel",
|
"--loglevel",
|
||||||
"INFO"
|
"INFO"
|
||||||
],
|
],
|
||||||
|
|
|
||||||
|
|
@ -1,5 +0,0 @@
|
||||||
# Windsurf Testing Rules
|
|
||||||
|
|
||||||
- Use `web/testing/testing.md` as the single source of truth for frontend automated testing.
|
|
||||||
- Honor every requirement in that document when generating or accepting tests.
|
|
||||||
- When proposing or saving tests, re-read that document and follow every requirement.
|
|
||||||
|
|
@ -24,8 +24,8 @@ The codebase is split into:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cd web
|
cd web
|
||||||
pnpm lint
|
|
||||||
pnpm lint:fix
|
pnpm lint:fix
|
||||||
|
pnpm type-check:tsgo
|
||||||
pnpm test
|
pnpm test
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
@ -39,7 +39,7 @@ pnpm test
|
||||||
## Language Style
|
## Language Style
|
||||||
|
|
||||||
- **Python**: Keep type hints on functions and attributes, and implement relevant special methods (e.g., `__repr__`, `__str__`).
|
- **Python**: Keep type hints on functions and attributes, and implement relevant special methods (e.g., `__repr__`, `__str__`).
|
||||||
- **TypeScript**: Use the strict config, lean on ESLint + Prettier workflows, and avoid `any` types.
|
- **TypeScript**: Use the strict config, rely on ESLint (`pnpm lint:fix` preferred) plus `pnpm type-check:tsgo`, and avoid `any` types.
|
||||||
|
|
||||||
## General Practices
|
## General Practices
|
||||||
|
|
||||||
|
|
|
||||||
13
README.md
13
README.md
|
|
@ -139,6 +139,19 @@ Star Dify on GitHub and be instantly notified of new releases.
|
||||||
|
|
||||||
If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker-compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
|
If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker-compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
|
||||||
|
|
||||||
|
#### Customizing Suggested Questions
|
||||||
|
|
||||||
|
You can now customize the "Suggested Questions After Answer" feature to better fit your use case. For example, to generate longer, more technical questions:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# In your .env file
|
||||||
|
SUGGESTED_QUESTIONS_PROMPT='Please help me predict the five most likely technical follow-up questions a developer would ask. Focus on implementation details, best practices, and architecture considerations. Keep each question between 40-60 characters. Output must be JSON array: ["question1","question2","question3","question4","question5"]'
|
||||||
|
SUGGESTED_QUESTIONS_MAX_TOKENS=512
|
||||||
|
SUGGESTED_QUESTIONS_TEMPERATURE=0.3
|
||||||
|
```
|
||||||
|
|
||||||
|
See the [Suggested Questions Configuration Guide](docs/suggested-questions-configuration.md) for detailed examples and usage instructions.
|
||||||
|
|
||||||
### Metrics Monitoring with Grafana
|
### Metrics Monitoring with Grafana
|
||||||
|
|
||||||
Import the dashboard to Grafana, using Dify's PostgreSQL database as data source, to monitor metrics in granularity of apps, tenants, messages, and more.
|
Import the dashboard to Grafana, using Dify's PostgreSQL database as data source, to monitor metrics in granularity of apps, tenants, messages, and more.
|
||||||
|
|
|
||||||
|
|
@ -116,6 +116,7 @@ ALIYUN_OSS_AUTH_VERSION=v1
|
||||||
ALIYUN_OSS_REGION=your-region
|
ALIYUN_OSS_REGION=your-region
|
||||||
# Don't start with '/'. OSS doesn't support leading slash in object names.
|
# Don't start with '/'. OSS doesn't support leading slash in object names.
|
||||||
ALIYUN_OSS_PATH=your-path
|
ALIYUN_OSS_PATH=your-path
|
||||||
|
ALIYUN_CLOUDBOX_ID=your-cloudbox-id
|
||||||
|
|
||||||
# Google Storage configuration
|
# Google Storage configuration
|
||||||
GOOGLE_STORAGE_BUCKET_NAME=your-bucket-name
|
GOOGLE_STORAGE_BUCKET_NAME=your-bucket-name
|
||||||
|
|
@ -133,6 +134,7 @@ HUAWEI_OBS_BUCKET_NAME=your-bucket-name
|
||||||
HUAWEI_OBS_SECRET_KEY=your-secret-key
|
HUAWEI_OBS_SECRET_KEY=your-secret-key
|
||||||
HUAWEI_OBS_ACCESS_KEY=your-access-key
|
HUAWEI_OBS_ACCESS_KEY=your-access-key
|
||||||
HUAWEI_OBS_SERVER=your-server-url
|
HUAWEI_OBS_SERVER=your-server-url
|
||||||
|
HUAWEI_OBS_PATH_STYLE=false
|
||||||
|
|
||||||
# Baidu OBS Storage Configuration
|
# Baidu OBS Storage Configuration
|
||||||
BAIDU_OBS_BUCKET_NAME=your-bucket-name
|
BAIDU_OBS_BUCKET_NAME=your-bucket-name
|
||||||
|
|
@ -543,6 +545,25 @@ APP_MAX_EXECUTION_TIME=1200
|
||||||
APP_DEFAULT_ACTIVE_REQUESTS=0
|
APP_DEFAULT_ACTIVE_REQUESTS=0
|
||||||
APP_MAX_ACTIVE_REQUESTS=0
|
APP_MAX_ACTIVE_REQUESTS=0
|
||||||
|
|
||||||
|
# Aliyun SLS Logstore Configuration
|
||||||
|
# Aliyun Access Key ID
|
||||||
|
ALIYUN_SLS_ACCESS_KEY_ID=
|
||||||
|
# Aliyun Access Key Secret
|
||||||
|
ALIYUN_SLS_ACCESS_KEY_SECRET=
|
||||||
|
# Aliyun SLS Endpoint (e.g., cn-hangzhou.log.aliyuncs.com)
|
||||||
|
ALIYUN_SLS_ENDPOINT=
|
||||||
|
# Aliyun SLS Region (e.g., cn-hangzhou)
|
||||||
|
ALIYUN_SLS_REGION=
|
||||||
|
# Aliyun SLS Project Name
|
||||||
|
ALIYUN_SLS_PROJECT_NAME=
|
||||||
|
# Number of days to retain workflow run logs (default: 365 days, 3650 for permanent storage)
|
||||||
|
ALIYUN_SLS_LOGSTORE_TTL=365
|
||||||
|
# Enable dual-write to both SLS LogStore and SQL database (default: false)
|
||||||
|
LOGSTORE_DUAL_WRITE_ENABLED=false
|
||||||
|
# Enable dual-read fallback to SQL database when LogStore returns no results (default: true)
|
||||||
|
# Useful for migration scenarios where historical data exists only in SQL database
|
||||||
|
LOGSTORE_DUAL_READ_ENABLED=true
|
||||||
|
|
||||||
# Celery beat configuration
|
# Celery beat configuration
|
||||||
CELERY_BEAT_SCHEDULER_TIME=1
|
CELERY_BEAT_SCHEDULER_TIME=1
|
||||||
|
|
||||||
|
|
@ -633,8 +654,45 @@ SWAGGER_UI_PATH=/swagger-ui.html
|
||||||
# Set to false to export dataset IDs as plain text for easier cross-environment import
|
# Set to false to export dataset IDs as plain text for easier cross-environment import
|
||||||
DSL_EXPORT_ENCRYPT_DATASET_ID=true
|
DSL_EXPORT_ENCRYPT_DATASET_ID=true
|
||||||
|
|
||||||
|
# Suggested Questions After Answer Configuration
|
||||||
|
# These environment variables allow customization of the suggested questions feature
|
||||||
|
#
|
||||||
|
# Custom prompt for generating suggested questions (optional)
|
||||||
|
# If not set, uses the default prompt that generates 3 questions under 20 characters each
|
||||||
|
# Example: "Please help me predict the five most likely technical follow-up questions a developer would ask. Focus on implementation details, best practices, and architecture considerations. Keep each question between 40-60 characters. Output must be JSON array: [\"question1\",\"question2\",\"question3\",\"question4\",\"question5\"]"
|
||||||
|
# SUGGESTED_QUESTIONS_PROMPT=
|
||||||
|
|
||||||
|
# Maximum number of tokens for suggested questions generation (default: 256)
|
||||||
|
# Adjust this value for longer questions or more questions
|
||||||
|
# SUGGESTED_QUESTIONS_MAX_TOKENS=256
|
||||||
|
|
||||||
|
# Temperature for suggested questions generation (default: 0.0)
|
||||||
|
# Higher values (0.5-1.0) produce more creative questions, lower values (0.0-0.3) produce more focused questions
|
||||||
|
# SUGGESTED_QUESTIONS_TEMPERATURE=0
|
||||||
|
|
||||||
# Tenant isolated task queue configuration
|
# Tenant isolated task queue configuration
|
||||||
TENANT_ISOLATED_TASK_CONCURRENCY=1
|
TENANT_ISOLATED_TASK_CONCURRENCY=1
|
||||||
|
|
||||||
# Maximum number of segments for dataset segments API (0 for unlimited)
|
# Maximum number of segments for dataset segments API (0 for unlimited)
|
||||||
DATASET_MAX_SEGMENTS_PER_REQUEST=0
|
DATASET_MAX_SEGMENTS_PER_REQUEST=0
|
||||||
|
|
||||||
|
# Multimodal knowledgebase limit
|
||||||
|
SINGLE_CHUNK_ATTACHMENT_LIMIT=10
|
||||||
|
ATTACHMENT_IMAGE_FILE_SIZE_LIMIT=2
|
||||||
|
ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT=60
|
||||||
|
IMAGE_FILE_BATCH_LIMIT=10
|
||||||
|
|
||||||
|
# Maximum allowed CSV file size for annotation import in megabytes
|
||||||
|
ANNOTATION_IMPORT_FILE_SIZE_LIMIT=2
|
||||||
|
#Maximum number of annotation records allowed in a single import
|
||||||
|
ANNOTATION_IMPORT_MAX_RECORDS=10000
|
||||||
|
# Minimum number of annotation records required in a single import
|
||||||
|
ANNOTATION_IMPORT_MIN_RECORDS=1
|
||||||
|
ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE=5
|
||||||
|
ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR=20
|
||||||
|
# Maximum number of concurrent annotation import tasks per tenant
|
||||||
|
ANNOTATION_IMPORT_MAX_CONCURRENT=5
|
||||||
|
# Sandbox expired records clean configuration
|
||||||
|
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21
|
||||||
|
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000
|
||||||
|
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30
|
||||||
|
|
|
||||||
|
|
@ -36,17 +36,20 @@ select = [
|
||||||
"UP", # pyupgrade rules
|
"UP", # pyupgrade rules
|
||||||
"W191", # tab-indentation
|
"W191", # tab-indentation
|
||||||
"W605", # invalid-escape-sequence
|
"W605", # invalid-escape-sequence
|
||||||
|
"G001", # don't use str format to logging messages
|
||||||
|
"G003", # don't use + in logging messages
|
||||||
|
"G004", # don't use f-strings to format logging messages
|
||||||
|
"UP042", # use StrEnum,
|
||||||
|
"S110", # disallow the try-except-pass pattern.
|
||||||
|
|
||||||
# security related linting rules
|
# security related linting rules
|
||||||
# RCE proctection (sort of)
|
# RCE proctection (sort of)
|
||||||
"S102", # exec-builtin, disallow use of `exec`
|
"S102", # exec-builtin, disallow use of `exec`
|
||||||
"S307", # suspicious-eval-usage, disallow use of `eval` and `ast.literal_eval`
|
"S307", # suspicious-eval-usage, disallow use of `eval` and `ast.literal_eval`
|
||||||
"S301", # suspicious-pickle-usage, disallow use of `pickle` and its wrappers.
|
"S301", # suspicious-pickle-usage, disallow use of `pickle` and its wrappers.
|
||||||
"S302", # suspicious-marshal-usage, disallow use of `marshal` module
|
"S302", # suspicious-marshal-usage, disallow use of `marshal` module
|
||||||
"S311", # suspicious-non-cryptographic-random-usage
|
"S311", # suspicious-non-cryptographic-random-usage,
|
||||||
"G001", # don't use str format to logging messages
|
|
||||||
"G003", # don't use + in logging messages
|
|
||||||
"G004", # don't use f-strings to format logging messages
|
|
||||||
"UP042", # use StrEnum
|
|
||||||
]
|
]
|
||||||
|
|
||||||
ignore = [
|
ignore = [
|
||||||
|
|
@ -91,18 +94,16 @@ ignore = [
|
||||||
"configs/*" = [
|
"configs/*" = [
|
||||||
"N802", # invalid-function-name
|
"N802", # invalid-function-name
|
||||||
]
|
]
|
||||||
"core/model_runtime/callbacks/base_callback.py" = [
|
"core/model_runtime/callbacks/base_callback.py" = ["T201"]
|
||||||
"T201",
|
"core/workflow/callbacks/workflow_logging_callback.py" = ["T201"]
|
||||||
]
|
|
||||||
"core/workflow/callbacks/workflow_logging_callback.py" = [
|
|
||||||
"T201",
|
|
||||||
]
|
|
||||||
"libs/gmpy2_pkcs10aep_cipher.py" = [
|
"libs/gmpy2_pkcs10aep_cipher.py" = [
|
||||||
"N803", # invalid-argument-name
|
"N803", # invalid-argument-name
|
||||||
]
|
]
|
||||||
"tests/*" = [
|
"tests/*" = [
|
||||||
"F811", # redefined-while-unused
|
"F811", # redefined-while-unused
|
||||||
"T201", # allow print in tests
|
"T201", # allow print in tests,
|
||||||
|
"S110", # allow ignoring exceptions in tests code (currently)
|
||||||
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[lint.pyflakes]
|
[lint.pyflakes]
|
||||||
|
|
|
||||||
|
|
@ -84,7 +84,7 @@
|
||||||
1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
|
1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor
|
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention
|
||||||
```
|
```
|
||||||
|
|
||||||
Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service:
|
Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service:
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,8 @@
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
from opentelemetry.trace import get_current_span
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from contexts.wrapper import RecyclableContextVar
|
from contexts.wrapper import RecyclableContextVar
|
||||||
from dify_app import DifyApp
|
from dify_app import DifyApp
|
||||||
|
|
@ -26,8 +28,25 @@ def create_flask_app_with_configs() -> DifyApp:
|
||||||
# add an unique identifier to each request
|
# add an unique identifier to each request
|
||||||
RecyclableContextVar.increment_thread_recycles()
|
RecyclableContextVar.increment_thread_recycles()
|
||||||
|
|
||||||
|
# add after request hook for injecting X-Trace-Id header from OpenTelemetry span context
|
||||||
|
@dify_app.after_request
|
||||||
|
def add_trace_id_header(response):
|
||||||
|
try:
|
||||||
|
span = get_current_span()
|
||||||
|
ctx = span.get_span_context() if span else None
|
||||||
|
if ctx and ctx.is_valid:
|
||||||
|
trace_id_hex = format(ctx.trace_id, "032x")
|
||||||
|
# Avoid duplicates if some middleware added it
|
||||||
|
if "X-Trace-Id" not in response.headers:
|
||||||
|
response.headers["X-Trace-Id"] = trace_id_hex
|
||||||
|
except Exception:
|
||||||
|
# Never break the response due to tracing header injection
|
||||||
|
logger.warning("Failed to add trace ID to response header", exc_info=True)
|
||||||
|
return response
|
||||||
|
|
||||||
# Capture the decorator's return value to avoid pyright reportUnusedFunction
|
# Capture the decorator's return value to avoid pyright reportUnusedFunction
|
||||||
_ = before_request
|
_ = before_request
|
||||||
|
_ = add_trace_id_header
|
||||||
|
|
||||||
return dify_app
|
return dify_app
|
||||||
|
|
||||||
|
|
@ -51,10 +70,12 @@ def initialize_extensions(app: DifyApp):
|
||||||
ext_commands,
|
ext_commands,
|
||||||
ext_compress,
|
ext_compress,
|
||||||
ext_database,
|
ext_database,
|
||||||
|
ext_forward_refs,
|
||||||
ext_hosting_provider,
|
ext_hosting_provider,
|
||||||
ext_import_modules,
|
ext_import_modules,
|
||||||
ext_logging,
|
ext_logging,
|
||||||
ext_login,
|
ext_login,
|
||||||
|
ext_logstore,
|
||||||
ext_mail,
|
ext_mail,
|
||||||
ext_migrate,
|
ext_migrate,
|
||||||
ext_orjson,
|
ext_orjson,
|
||||||
|
|
@ -63,6 +84,7 @@ def initialize_extensions(app: DifyApp):
|
||||||
ext_redis,
|
ext_redis,
|
||||||
ext_request_logging,
|
ext_request_logging,
|
||||||
ext_sentry,
|
ext_sentry,
|
||||||
|
ext_session_factory,
|
||||||
ext_set_secretkey,
|
ext_set_secretkey,
|
||||||
ext_storage,
|
ext_storage,
|
||||||
ext_timezone,
|
ext_timezone,
|
||||||
|
|
@ -75,6 +97,7 @@ def initialize_extensions(app: DifyApp):
|
||||||
ext_warnings,
|
ext_warnings,
|
||||||
ext_import_modules,
|
ext_import_modules,
|
||||||
ext_orjson,
|
ext_orjson,
|
||||||
|
ext_forward_refs,
|
||||||
ext_set_secretkey,
|
ext_set_secretkey,
|
||||||
ext_compress,
|
ext_compress,
|
||||||
ext_code_based_extension,
|
ext_code_based_extension,
|
||||||
|
|
@ -83,6 +106,7 @@ def initialize_extensions(app: DifyApp):
|
||||||
ext_migrate,
|
ext_migrate,
|
||||||
ext_redis,
|
ext_redis,
|
||||||
ext_storage,
|
ext_storage,
|
||||||
|
ext_logstore, # Initialize logstore after storage, before celery
|
||||||
ext_celery,
|
ext_celery,
|
||||||
ext_login,
|
ext_login,
|
||||||
ext_mail,
|
ext_mail,
|
||||||
|
|
@ -93,6 +117,7 @@ def initialize_extensions(app: DifyApp):
|
||||||
ext_commands,
|
ext_commands,
|
||||||
ext_otel,
|
ext_otel,
|
||||||
ext_request_logging,
|
ext_request_logging,
|
||||||
|
ext_session_factory,
|
||||||
]
|
]
|
||||||
for ext in extensions:
|
for ext in extensions:
|
||||||
short_name = ext.__name__.split(".")[-1]
|
short_name = ext.__name__.split(".")[-1]
|
||||||
|
|
|
||||||
|
|
@ -1139,6 +1139,7 @@ def remove_orphaned_files_on_storage(force: bool):
|
||||||
click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white"))
|
click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white"))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
click.echo(click.style(f"Error fetching keys: {str(e)}", fg="red"))
|
click.echo(click.style(f"Error fetching keys: {str(e)}", fg="red"))
|
||||||
|
return
|
||||||
|
|
||||||
all_files_on_storage = []
|
all_files_on_storage = []
|
||||||
for storage_path in storage_paths:
|
for storage_path in storage_paths:
|
||||||
|
|
|
||||||
|
|
@ -218,7 +218,7 @@ class PluginConfig(BaseSettings):
|
||||||
|
|
||||||
PLUGIN_DAEMON_TIMEOUT: PositiveFloat | None = Field(
|
PLUGIN_DAEMON_TIMEOUT: PositiveFloat | None = Field(
|
||||||
description="Timeout in seconds for requests to the plugin daemon (set to None to disable)",
|
description="Timeout in seconds for requests to the plugin daemon (set to None to disable)",
|
||||||
default=300.0,
|
default=600.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key")
|
INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key")
|
||||||
|
|
@ -360,6 +360,57 @@ class FileUploadConfig(BaseSettings):
|
||||||
default=10,
|
default=10,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
IMAGE_FILE_BATCH_LIMIT: PositiveInt = Field(
|
||||||
|
description="Maximum number of files allowed in a image batch upload operation",
|
||||||
|
default=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
SINGLE_CHUNK_ATTACHMENT_LIMIT: PositiveInt = Field(
|
||||||
|
description="Maximum number of files allowed in a single chunk attachment",
|
||||||
|
default=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
ATTACHMENT_IMAGE_FILE_SIZE_LIMIT: NonNegativeInt = Field(
|
||||||
|
description="Maximum allowed image file size for attachments in megabytes",
|
||||||
|
default=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT: NonNegativeInt = Field(
|
||||||
|
description="Timeout for downloading image attachments in seconds",
|
||||||
|
default=60,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Annotation Import Security Configurations
|
||||||
|
ANNOTATION_IMPORT_FILE_SIZE_LIMIT: NonNegativeInt = Field(
|
||||||
|
description="Maximum allowed CSV file size for annotation import in megabytes",
|
||||||
|
default=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
ANNOTATION_IMPORT_MAX_RECORDS: PositiveInt = Field(
|
||||||
|
description="Maximum number of annotation records allowed in a single import",
|
||||||
|
default=10000,
|
||||||
|
)
|
||||||
|
|
||||||
|
ANNOTATION_IMPORT_MIN_RECORDS: PositiveInt = Field(
|
||||||
|
description="Minimum number of annotation records required in a single import",
|
||||||
|
default=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE: PositiveInt = Field(
|
||||||
|
description="Maximum number of annotation import requests per minute per tenant",
|
||||||
|
default=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR: PositiveInt = Field(
|
||||||
|
description="Maximum number of annotation import requests per hour per tenant",
|
||||||
|
default=20,
|
||||||
|
)
|
||||||
|
|
||||||
|
ANNOTATION_IMPORT_MAX_CONCURRENT: PositiveInt = Field(
|
||||||
|
description="Maximum number of concurrent annotation import tasks per tenant",
|
||||||
|
default=2,
|
||||||
|
)
|
||||||
|
|
||||||
inner_UPLOAD_FILE_EXTENSION_BLACKLIST: str = Field(
|
inner_UPLOAD_FILE_EXTENSION_BLACKLIST: str = Field(
|
||||||
description=(
|
description=(
|
||||||
"Comma-separated list of file extensions that are blocked from upload. "
|
"Comma-separated list of file extensions that are blocked from upload. "
|
||||||
|
|
@ -553,7 +604,10 @@ class LoggingConfig(BaseSettings):
|
||||||
|
|
||||||
LOG_FORMAT: str = Field(
|
LOG_FORMAT: str = Field(
|
||||||
description="Format string for log messages",
|
description="Format string for log messages",
|
||||||
default="%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s",
|
default=(
|
||||||
|
"%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] "
|
||||||
|
"[%(filename)s:%(lineno)d] %(trace_id)s - %(message)s"
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
LOG_DATEFORMAT: str | None = Field(
|
LOG_DATEFORMAT: str | None = Field(
|
||||||
|
|
@ -1216,6 +1270,21 @@ class TenantIsolatedTaskQueueConfig(BaseSettings):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SandboxExpiredRecordsCleanConfig(BaseSettings):
|
||||||
|
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD: NonNegativeInt = Field(
|
||||||
|
description="Graceful period in days for sandbox records clean after subscription expiration",
|
||||||
|
default=21,
|
||||||
|
)
|
||||||
|
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE: PositiveInt = Field(
|
||||||
|
description="Maximum number of records to process in each batch",
|
||||||
|
default=1000,
|
||||||
|
)
|
||||||
|
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS: PositiveInt = Field(
|
||||||
|
description="Retention days for sandbox expired workflow_run records and message records",
|
||||||
|
default=30,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class FeatureConfig(
|
class FeatureConfig(
|
||||||
# place the configs in alphabet order
|
# place the configs in alphabet order
|
||||||
AppExecutionConfig,
|
AppExecutionConfig,
|
||||||
|
|
@ -1241,6 +1310,7 @@ class FeatureConfig(
|
||||||
PositionConfig,
|
PositionConfig,
|
||||||
RagEtlConfig,
|
RagEtlConfig,
|
||||||
RepositoryConfig,
|
RepositoryConfig,
|
||||||
|
SandboxExpiredRecordsCleanConfig,
|
||||||
SecurityConfig,
|
SecurityConfig,
|
||||||
TenantIsolatedTaskQueueConfig,
|
TenantIsolatedTaskQueueConfig,
|
||||||
ToolConfig,
|
ToolConfig,
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,7 @@ from .vdb.clickzetta_config import ClickzettaConfig
|
||||||
from .vdb.couchbase_config import CouchbaseConfig
|
from .vdb.couchbase_config import CouchbaseConfig
|
||||||
from .vdb.elasticsearch_config import ElasticsearchConfig
|
from .vdb.elasticsearch_config import ElasticsearchConfig
|
||||||
from .vdb.huawei_cloud_config import HuaweiCloudConfig
|
from .vdb.huawei_cloud_config import HuaweiCloudConfig
|
||||||
|
from .vdb.iris_config import IrisVectorConfig
|
||||||
from .vdb.lindorm_config import LindormConfig
|
from .vdb.lindorm_config import LindormConfig
|
||||||
from .vdb.matrixone_config import MatrixoneConfig
|
from .vdb.matrixone_config import MatrixoneConfig
|
||||||
from .vdb.milvus_config import MilvusConfig
|
from .vdb.milvus_config import MilvusConfig
|
||||||
|
|
@ -106,7 +107,7 @@ class KeywordStoreConfig(BaseSettings):
|
||||||
|
|
||||||
class DatabaseConfig(BaseSettings):
|
class DatabaseConfig(BaseSettings):
|
||||||
# Database type selector
|
# Database type selector
|
||||||
DB_TYPE: Literal["postgresql", "mysql", "oceanbase"] = Field(
|
DB_TYPE: Literal["postgresql", "mysql", "oceanbase", "seekdb"] = Field(
|
||||||
description="Database type to use. OceanBase is MySQL-compatible.",
|
description="Database type to use. OceanBase is MySQL-compatible.",
|
||||||
default="postgresql",
|
default="postgresql",
|
||||||
)
|
)
|
||||||
|
|
@ -336,6 +337,7 @@ class MiddlewareConfig(
|
||||||
ChromaConfig,
|
ChromaConfig,
|
||||||
ClickzettaConfig,
|
ClickzettaConfig,
|
||||||
HuaweiCloudConfig,
|
HuaweiCloudConfig,
|
||||||
|
IrisVectorConfig,
|
||||||
MilvusConfig,
|
MilvusConfig,
|
||||||
AlibabaCloudMySQLConfig,
|
AlibabaCloudMySQLConfig,
|
||||||
MyScaleConfig,
|
MyScaleConfig,
|
||||||
|
|
|
||||||
|
|
@ -41,3 +41,8 @@ class AliyunOSSStorageConfig(BaseSettings):
|
||||||
description="Base path within the bucket to store objects (e.g., 'my-app-data/')",
|
description="Base path within the bucket to store objects (e.g., 'my-app-data/')",
|
||||||
default=None,
|
default=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
ALIYUN_CLOUDBOX_ID: str | None = Field(
|
||||||
|
description="Cloudbox id for aliyun cloudbox service",
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -26,3 +26,8 @@ class HuaweiCloudOBSStorageConfig(BaseSettings):
|
||||||
description="Endpoint URL for Huawei Cloud OBS (e.g., 'https://obs.cn-north-4.myhuaweicloud.com')",
|
description="Endpoint URL for Huawei Cloud OBS (e.g., 'https://obs.cn-north-4.myhuaweicloud.com')",
|
||||||
default=None,
|
default=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
HUAWEI_OBS_PATH_STYLE: bool = Field(
|
||||||
|
description="Flag to indicate whether to use path-style URLs for OBS requests",
|
||||||
|
default=False,
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,91 @@
|
||||||
|
"""Configuration for InterSystems IRIS vector database."""
|
||||||
|
|
||||||
|
from pydantic import Field, PositiveInt, model_validator
|
||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
|
|
||||||
|
class IrisVectorConfig(BaseSettings):
|
||||||
|
"""Configuration settings for IRIS vector database connection and pooling."""
|
||||||
|
|
||||||
|
IRIS_HOST: str | None = Field(
|
||||||
|
description="Hostname or IP address of the IRIS server.",
|
||||||
|
default="localhost",
|
||||||
|
)
|
||||||
|
|
||||||
|
IRIS_SUPER_SERVER_PORT: PositiveInt | None = Field(
|
||||||
|
description="Port number for IRIS connection.",
|
||||||
|
default=1972,
|
||||||
|
)
|
||||||
|
|
||||||
|
IRIS_USER: str | None = Field(
|
||||||
|
description="Username for IRIS authentication.",
|
||||||
|
default="_SYSTEM",
|
||||||
|
)
|
||||||
|
|
||||||
|
IRIS_PASSWORD: str | None = Field(
|
||||||
|
description="Password for IRIS authentication.",
|
||||||
|
default="Dify@1234",
|
||||||
|
)
|
||||||
|
|
||||||
|
IRIS_SCHEMA: str | None = Field(
|
||||||
|
description="Schema name for IRIS tables.",
|
||||||
|
default="dify",
|
||||||
|
)
|
||||||
|
|
||||||
|
IRIS_DATABASE: str | None = Field(
|
||||||
|
description="Database namespace for IRIS connection.",
|
||||||
|
default="USER",
|
||||||
|
)
|
||||||
|
|
||||||
|
IRIS_CONNECTION_URL: str | None = Field(
|
||||||
|
description="Full connection URL for IRIS (overrides individual fields if provided).",
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
IRIS_MIN_CONNECTION: PositiveInt = Field(
|
||||||
|
description="Minimum number of connections in the pool.",
|
||||||
|
default=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
IRIS_MAX_CONNECTION: PositiveInt = Field(
|
||||||
|
description="Maximum number of connections in the pool.",
|
||||||
|
default=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
IRIS_TEXT_INDEX: bool = Field(
|
||||||
|
description="Enable full-text search index using %iFind.Index.Basic.",
|
||||||
|
default=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
IRIS_TEXT_INDEX_LANGUAGE: str = Field(
|
||||||
|
description="Language for full-text search index (e.g., 'en', 'ja', 'zh', 'de').",
|
||||||
|
default="en",
|
||||||
|
)
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def validate_config(cls, values: dict) -> dict:
|
||||||
|
"""Validate IRIS configuration values.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
values: Configuration dictionary
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Validated configuration dictionary
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If required fields are missing or pool settings are invalid
|
||||||
|
"""
|
||||||
|
# Only validate required fields if IRIS is being used as the vector store
|
||||||
|
# This allows the config to be loaded even when IRIS is not in use
|
||||||
|
|
||||||
|
# vector_store = os.environ.get("VECTOR_STORE", "")
|
||||||
|
# We rely on Pydantic defaults for required fields if they are missing from env.
|
||||||
|
# Strict existence check is removed to allow defaults to work.
|
||||||
|
|
||||||
|
min_conn = values.get("IRIS_MIN_CONNECTION", 1)
|
||||||
|
max_conn = values.get("IRIS_MAX_CONNECTION", 3)
|
||||||
|
if min_conn > max_conn:
|
||||||
|
raise ValueError("IRIS_MIN_CONNECTION must be less than or equal to IRIS_MAX_CONNECTION")
|
||||||
|
|
||||||
|
return values
|
||||||
|
|
@ -20,6 +20,7 @@ language_timezone_mapping = {
|
||||||
"sl-SI": "Europe/Ljubljana",
|
"sl-SI": "Europe/Ljubljana",
|
||||||
"th-TH": "Asia/Bangkok",
|
"th-TH": "Asia/Bangkok",
|
||||||
"id-ID": "Asia/Jakarta",
|
"id-ID": "Asia/Jakarta",
|
||||||
|
"ar-TN": "Africa/Tunis",
|
||||||
}
|
}
|
||||||
|
|
||||||
languages = list(language_timezone_mapping.keys())
|
languages = list(language_timezone_mapping.keys())
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,57 @@
|
||||||
|
import os
|
||||||
|
from email.message import Message
|
||||||
|
from urllib.parse import quote
|
||||||
|
|
||||||
|
from flask import Response
|
||||||
|
|
||||||
|
HTML_MIME_TYPES = frozenset({"text/html", "application/xhtml+xml"})
|
||||||
|
HTML_EXTENSIONS = frozenset({"html", "htm"})
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_mime_type(mime_type: str | None) -> str:
|
||||||
|
if not mime_type:
|
||||||
|
return ""
|
||||||
|
message = Message()
|
||||||
|
message["Content-Type"] = mime_type
|
||||||
|
return message.get_content_type().strip().lower()
|
||||||
|
|
||||||
|
|
||||||
|
def _is_html_extension(extension: str | None) -> bool:
|
||||||
|
if not extension:
|
||||||
|
return False
|
||||||
|
return extension.lstrip(".").lower() in HTML_EXTENSIONS
|
||||||
|
|
||||||
|
|
||||||
|
def is_html_content(mime_type: str | None, filename: str | None, extension: str | None = None) -> bool:
|
||||||
|
normalized_mime_type = _normalize_mime_type(mime_type)
|
||||||
|
if normalized_mime_type in HTML_MIME_TYPES:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if _is_html_extension(extension):
|
||||||
|
return True
|
||||||
|
|
||||||
|
if filename:
|
||||||
|
return _is_html_extension(os.path.splitext(filename)[1])
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def enforce_download_for_html(
|
||||||
|
response: Response,
|
||||||
|
*,
|
||||||
|
mime_type: str | None,
|
||||||
|
filename: str | None,
|
||||||
|
extension: str | None = None,
|
||||||
|
) -> bool:
|
||||||
|
if not is_html_content(mime_type, filename, extension):
|
||||||
|
return False
|
||||||
|
|
||||||
|
if filename:
|
||||||
|
encoded_filename = quote(filename)
|
||||||
|
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
|
||||||
|
else:
|
||||||
|
response.headers["Content-Disposition"] = "attachment"
|
||||||
|
|
||||||
|
response.headers["Content-Type"] = "application/octet-stream"
|
||||||
|
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||||
|
return True
|
||||||
|
|
@ -0,0 +1,26 @@
|
||||||
|
"""Helpers for registering Pydantic models with Flask-RESTX namespaces."""
|
||||||
|
|
||||||
|
from flask_restx import Namespace
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
|
|
||||||
|
def register_schema_model(namespace: Namespace, model: type[BaseModel]) -> None:
|
||||||
|
"""Register a single BaseModel with a namespace for Swagger documentation."""
|
||||||
|
|
||||||
|
namespace.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||||
|
|
||||||
|
|
||||||
|
def register_schema_models(namespace: Namespace, *models: type[BaseModel]) -> None:
|
||||||
|
"""Register multiple BaseModels with a namespace."""
|
||||||
|
|
||||||
|
for model in models:
|
||||||
|
register_schema_model(namespace, model)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DEFAULT_REF_TEMPLATE_SWAGGER_2_0",
|
||||||
|
"register_schema_model",
|
||||||
|
"register_schema_models",
|
||||||
|
]
|
||||||
|
|
@ -3,21 +3,47 @@ from functools import wraps
|
||||||
from typing import ParamSpec, TypeVar
|
from typing import ParamSpec, TypeVar
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource, fields, reqparse
|
from flask_restx import Resource
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
from werkzeug.exceptions import NotFound, Unauthorized
|
from werkzeug.exceptions import NotFound, Unauthorized
|
||||||
|
|
||||||
P = ParamSpec("P")
|
|
||||||
R = TypeVar("R")
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from constants.languages import supported_language
|
from constants.languages import supported_language
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
from controllers.console.wraps import only_edition_cloud
|
from controllers.console.wraps import only_edition_cloud
|
||||||
|
from core.db.session_factory import session_factory
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.token import extract_access_token
|
from libs.token import extract_access_token
|
||||||
from models.model import App, InstalledApp, RecommendedApp
|
from models.model import App, InstalledApp, RecommendedApp
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
R = TypeVar("R")
|
||||||
|
|
||||||
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
|
|
||||||
|
class InsertExploreAppPayload(BaseModel):
|
||||||
|
app_id: str = Field(...)
|
||||||
|
desc: str | None = None
|
||||||
|
copyright: str | None = None
|
||||||
|
privacy_policy: str | None = None
|
||||||
|
custom_disclaimer: str | None = None
|
||||||
|
language: str = Field(...)
|
||||||
|
category: str = Field(...)
|
||||||
|
position: int = Field(...)
|
||||||
|
|
||||||
|
@field_validator("language")
|
||||||
|
@classmethod
|
||||||
|
def validate_language(cls, value: str) -> str:
|
||||||
|
return supported_language(value)
|
||||||
|
|
||||||
|
|
||||||
|
console_ns.schema_model(
|
||||||
|
InsertExploreAppPayload.__name__,
|
||||||
|
InsertExploreAppPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def admin_required(view: Callable[P, R]):
|
def admin_required(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
|
|
@ -40,59 +66,34 @@ def admin_required(view: Callable[P, R]):
|
||||||
class InsertExploreAppListApi(Resource):
|
class InsertExploreAppListApi(Resource):
|
||||||
@console_ns.doc("insert_explore_app")
|
@console_ns.doc("insert_explore_app")
|
||||||
@console_ns.doc(description="Insert or update an app in the explore list")
|
@console_ns.doc(description="Insert or update an app in the explore list")
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[InsertExploreAppPayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"InsertExploreAppRequest",
|
|
||||||
{
|
|
||||||
"app_id": fields.String(required=True, description="Application ID"),
|
|
||||||
"desc": fields.String(description="App description"),
|
|
||||||
"copyright": fields.String(description="Copyright information"),
|
|
||||||
"privacy_policy": fields.String(description="Privacy policy"),
|
|
||||||
"custom_disclaimer": fields.String(description="Custom disclaimer"),
|
|
||||||
"language": fields.String(required=True, description="Language code"),
|
|
||||||
"category": fields.String(required=True, description="App category"),
|
|
||||||
"position": fields.Integer(required=True, description="Display position"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "App updated successfully")
|
@console_ns.response(200, "App updated successfully")
|
||||||
@console_ns.response(201, "App inserted successfully")
|
@console_ns.response(201, "App inserted successfully")
|
||||||
@console_ns.response(404, "App not found")
|
@console_ns.response(404, "App not found")
|
||||||
@only_edition_cloud
|
@only_edition_cloud
|
||||||
@admin_required
|
@admin_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = (
|
payload = InsertExploreAppPayload.model_validate(console_ns.payload)
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("app_id", type=str, required=True, nullable=False, location="json")
|
|
||||||
.add_argument("desc", type=str, location="json")
|
|
||||||
.add_argument("copyright", type=str, location="json")
|
|
||||||
.add_argument("privacy_policy", type=str, location="json")
|
|
||||||
.add_argument("custom_disclaimer", type=str, location="json")
|
|
||||||
.add_argument("language", type=supported_language, required=True, nullable=False, location="json")
|
|
||||||
.add_argument("category", type=str, required=True, nullable=False, location="json")
|
|
||||||
.add_argument("position", type=int, required=True, nullable=False, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
app = db.session.execute(select(App).where(App.id == args["app_id"])).scalar_one_or_none()
|
app = db.session.execute(select(App).where(App.id == payload.app_id)).scalar_one_or_none()
|
||||||
if not app:
|
if not app:
|
||||||
raise NotFound(f"App '{args['app_id']}' is not found")
|
raise NotFound(f"App '{payload.app_id}' is not found")
|
||||||
|
|
||||||
site = app.site
|
site = app.site
|
||||||
if not site:
|
if not site:
|
||||||
desc = args["desc"] or ""
|
desc = payload.desc or ""
|
||||||
copy_right = args["copyright"] or ""
|
copy_right = payload.copyright or ""
|
||||||
privacy_policy = args["privacy_policy"] or ""
|
privacy_policy = payload.privacy_policy or ""
|
||||||
custom_disclaimer = args["custom_disclaimer"] or ""
|
custom_disclaimer = payload.custom_disclaimer or ""
|
||||||
else:
|
else:
|
||||||
desc = site.description or args["desc"] or ""
|
desc = site.description or payload.desc or ""
|
||||||
copy_right = site.copyright or args["copyright"] or ""
|
copy_right = site.copyright or payload.copyright or ""
|
||||||
privacy_policy = site.privacy_policy or args["privacy_policy"] or ""
|
privacy_policy = site.privacy_policy or payload.privacy_policy or ""
|
||||||
custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or ""
|
custom_disclaimer = site.custom_disclaimer or payload.custom_disclaimer or ""
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with session_factory.create_session() as session:
|
||||||
recommended_app = session.execute(
|
recommended_app = session.execute(
|
||||||
select(RecommendedApp).where(RecommendedApp.app_id == args["app_id"])
|
select(RecommendedApp).where(RecommendedApp.app_id == payload.app_id)
|
||||||
).scalar_one_or_none()
|
).scalar_one_or_none()
|
||||||
|
|
||||||
if not recommended_app:
|
if not recommended_app:
|
||||||
|
|
@ -102,9 +103,9 @@ class InsertExploreAppListApi(Resource):
|
||||||
copyright=copy_right,
|
copyright=copy_right,
|
||||||
privacy_policy=privacy_policy,
|
privacy_policy=privacy_policy,
|
||||||
custom_disclaimer=custom_disclaimer,
|
custom_disclaimer=custom_disclaimer,
|
||||||
language=args["language"],
|
language=payload.language,
|
||||||
category=args["category"],
|
category=payload.category,
|
||||||
position=args["position"],
|
position=payload.position,
|
||||||
)
|
)
|
||||||
|
|
||||||
db.session.add(recommended_app)
|
db.session.add(recommended_app)
|
||||||
|
|
@ -118,9 +119,9 @@ class InsertExploreAppListApi(Resource):
|
||||||
recommended_app.copyright = copy_right
|
recommended_app.copyright = copy_right
|
||||||
recommended_app.privacy_policy = privacy_policy
|
recommended_app.privacy_policy = privacy_policy
|
||||||
recommended_app.custom_disclaimer = custom_disclaimer
|
recommended_app.custom_disclaimer = custom_disclaimer
|
||||||
recommended_app.language = args["language"]
|
recommended_app.language = payload.language
|
||||||
recommended_app.category = args["category"]
|
recommended_app.category = payload.category
|
||||||
recommended_app.position = args["position"]
|
recommended_app.position = payload.position
|
||||||
|
|
||||||
app.is_public = True
|
app.is_public = True
|
||||||
|
|
||||||
|
|
@ -138,7 +139,7 @@ class InsertExploreAppApi(Resource):
|
||||||
@only_edition_cloud
|
@only_edition_cloud
|
||||||
@admin_required
|
@admin_required
|
||||||
def delete(self, app_id):
|
def delete(self, app_id):
|
||||||
with Session(db.engine) as session:
|
with session_factory.create_session() as session:
|
||||||
recommended_app = session.execute(
|
recommended_app = session.execute(
|
||||||
select(RecommendedApp).where(RecommendedApp.app_id == str(app_id))
|
select(RecommendedApp).where(RecommendedApp.app_id == str(app_id))
|
||||||
).scalar_one_or_none()
|
).scalar_one_or_none()
|
||||||
|
|
@ -146,13 +147,13 @@ class InsertExploreAppApi(Resource):
|
||||||
if not recommended_app:
|
if not recommended_app:
|
||||||
return {"result": "success"}, 204
|
return {"result": "success"}, 204
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with session_factory.create_session() as session:
|
||||||
app = session.execute(select(App).where(App.id == recommended_app.app_id)).scalar_one_or_none()
|
app = session.execute(select(App).where(App.id == recommended_app.app_id)).scalar_one_or_none()
|
||||||
|
|
||||||
if app:
|
if app:
|
||||||
app.is_public = False
|
app.is_public = False
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with session_factory.create_session() as session:
|
||||||
installed_apps = (
|
installed_apps = (
|
||||||
session.execute(
|
session.execute(
|
||||||
select(InstalledApp).where(
|
select(InstalledApp).where(
|
||||||
|
|
|
||||||
|
|
@ -1,16 +1,23 @@
|
||||||
from flask_restx import Resource, fields, reqparse
|
from flask import request
|
||||||
|
from flask_restx import Resource, fields
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
|
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
|
||||||
|
|
||||||
parser = (
|
|
||||||
reqparse.RequestParser()
|
class AdvancedPromptTemplateQuery(BaseModel):
|
||||||
.add_argument("app_mode", type=str, required=True, location="args", help="Application mode")
|
app_mode: str = Field(..., description="Application mode")
|
||||||
.add_argument("model_mode", type=str, required=True, location="args", help="Model mode")
|
model_mode: str = Field(..., description="Model mode")
|
||||||
.add_argument("has_context", type=str, required=False, default="true", location="args", help="Whether has context")
|
has_context: str = Field(default="true", description="Whether has context")
|
||||||
.add_argument("model_name", type=str, required=True, location="args", help="Model name")
|
model_name: str = Field(..., description="Model name")
|
||||||
|
|
||||||
|
|
||||||
|
console_ns.schema_model(
|
||||||
|
AdvancedPromptTemplateQuery.__name__,
|
||||||
|
AdvancedPromptTemplateQuery.model_json_schema(ref_template="#/definitions/{model}"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -18,7 +25,7 @@ parser = (
|
||||||
class AdvancedPromptTemplateList(Resource):
|
class AdvancedPromptTemplateList(Resource):
|
||||||
@console_ns.doc("get_advanced_prompt_templates")
|
@console_ns.doc("get_advanced_prompt_templates")
|
||||||
@console_ns.doc(description="Get advanced prompt templates based on app mode and model configuration")
|
@console_ns.doc(description="Get advanced prompt templates based on app mode and model configuration")
|
||||||
@console_ns.expect(parser)
|
@console_ns.expect(console_ns.models[AdvancedPromptTemplateQuery.__name__])
|
||||||
@console_ns.response(
|
@console_ns.response(
|
||||||
200, "Prompt templates retrieved successfully", fields.List(fields.Raw(description="Prompt template data"))
|
200, "Prompt templates retrieved successfully", fields.List(fields.Raw(description="Prompt template data"))
|
||||||
)
|
)
|
||||||
|
|
@ -27,6 +34,6 @@ class AdvancedPromptTemplateList(Resource):
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
args = parser.parse_args()
|
args = AdvancedPromptTemplateQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
|
|
||||||
return AdvancedPromptTemplateService.get_prompt(args)
|
return AdvancedPromptTemplateService.get_prompt(args.model_dump())
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,6 @@
|
||||||
from flask_restx import Resource, fields, reqparse
|
from flask import request
|
||||||
|
from flask_restx import Resource, fields
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
from controllers.console.app.wraps import get_app_model
|
from controllers.console.app.wraps import get_app_model
|
||||||
|
|
@ -8,10 +10,21 @@ from libs.login import login_required
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
from services.agent_service import AgentService
|
from services.agent_service import AgentService
|
||||||
|
|
||||||
parser = (
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("message_id", type=uuid_value, required=True, location="args", help="Message UUID")
|
|
||||||
.add_argument("conversation_id", type=uuid_value, required=True, location="args", help="Conversation UUID")
|
class AgentLogQuery(BaseModel):
|
||||||
|
message_id: str = Field(..., description="Message UUID")
|
||||||
|
conversation_id: str = Field(..., description="Conversation UUID")
|
||||||
|
|
||||||
|
@field_validator("message_id", "conversation_id")
|
||||||
|
@classmethod
|
||||||
|
def validate_uuid(cls, value: str) -> str:
|
||||||
|
return uuid_value(value)
|
||||||
|
|
||||||
|
|
||||||
|
console_ns.schema_model(
|
||||||
|
AgentLogQuery.__name__, AgentLogQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -20,7 +33,7 @@ class AgentLogApi(Resource):
|
||||||
@console_ns.doc("get_agent_logs")
|
@console_ns.doc("get_agent_logs")
|
||||||
@console_ns.doc(description="Get agent execution logs for an application")
|
@console_ns.doc(description="Get agent execution logs for an application")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.expect(parser)
|
@console_ns.expect(console_ns.models[AgentLogQuery.__name__])
|
||||||
@console_ns.response(
|
@console_ns.response(
|
||||||
200, "Agent logs retrieved successfully", fields.List(fields.Raw(description="Agent log entries"))
|
200, "Agent logs retrieved successfully", fields.List(fields.Raw(description="Agent log entries"))
|
||||||
)
|
)
|
||||||
|
|
@ -31,6 +44,6 @@ class AgentLogApi(Resource):
|
||||||
@get_app_model(mode=[AppMode.AGENT_CHAT])
|
@get_app_model(mode=[AppMode.AGENT_CHAT])
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
"""Get agent logs"""
|
"""Get agent logs"""
|
||||||
args = parser.parse_args()
|
args = AgentLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
|
|
||||||
return AgentService.get_agent_logs(app_model, args["conversation_id"], args["message_id"])
|
return AgentService.get_agent_logs(app_model, args.conversation_id, args.message_id)
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,15 @@
|
||||||
from typing import Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
from flask import request
|
from flask import abort, make_response, request
|
||||||
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
from flask_restx import Resource, fields, marshal, marshal_with
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
from controllers.common.errors import NoFileUploadedError, TooManyFilesError
|
from controllers.common.errors import NoFileUploadedError, TooManyFilesError
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
from controllers.console.wraps import (
|
from controllers.console.wraps import (
|
||||||
account_initialization_required,
|
account_initialization_required,
|
||||||
|
annotation_import_concurrency_limit,
|
||||||
|
annotation_import_rate_limit,
|
||||||
cloud_edition_billing_resource_check,
|
cloud_edition_billing_resource_check,
|
||||||
edit_permission_required,
|
edit_permission_required,
|
||||||
setup_required,
|
setup_required,
|
||||||
|
|
@ -21,22 +24,79 @@ from libs.helper import uuid_value
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from services.annotation_service import AppAnnotationService
|
from services.annotation_service import AppAnnotationService
|
||||||
|
|
||||||
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
|
|
||||||
|
class AnnotationReplyPayload(BaseModel):
|
||||||
|
score_threshold: float = Field(..., description="Score threshold for annotation matching")
|
||||||
|
embedding_provider_name: str = Field(..., description="Embedding provider name")
|
||||||
|
embedding_model_name: str = Field(..., description="Embedding model name")
|
||||||
|
|
||||||
|
|
||||||
|
class AnnotationSettingUpdatePayload(BaseModel):
|
||||||
|
score_threshold: float = Field(..., description="Score threshold")
|
||||||
|
|
||||||
|
|
||||||
|
class AnnotationListQuery(BaseModel):
|
||||||
|
page: int = Field(default=1, ge=1, description="Page number")
|
||||||
|
limit: int = Field(default=20, ge=1, description="Page size")
|
||||||
|
keyword: str = Field(default="", description="Search keyword")
|
||||||
|
|
||||||
|
|
||||||
|
class CreateAnnotationPayload(BaseModel):
|
||||||
|
message_id: str | None = Field(default=None, description="Message ID")
|
||||||
|
question: str | None = Field(default=None, description="Question text")
|
||||||
|
answer: str | None = Field(default=None, description="Answer text")
|
||||||
|
content: str | None = Field(default=None, description="Content text")
|
||||||
|
annotation_reply: dict[str, Any] | None = Field(default=None, description="Annotation reply data")
|
||||||
|
|
||||||
|
@field_validator("message_id")
|
||||||
|
@classmethod
|
||||||
|
def validate_message_id(cls, value: str | None) -> str | None:
|
||||||
|
if value is None:
|
||||||
|
return value
|
||||||
|
return uuid_value(value)
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateAnnotationPayload(BaseModel):
|
||||||
|
question: str | None = None
|
||||||
|
answer: str | None = None
|
||||||
|
content: str | None = None
|
||||||
|
annotation_reply: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class AnnotationReplyStatusQuery(BaseModel):
|
||||||
|
action: Literal["enable", "disable"]
|
||||||
|
|
||||||
|
|
||||||
|
class AnnotationFilePayload(BaseModel):
|
||||||
|
message_id: str = Field(..., description="Message ID")
|
||||||
|
|
||||||
|
@field_validator("message_id")
|
||||||
|
@classmethod
|
||||||
|
def validate_message_id(cls, value: str) -> str:
|
||||||
|
return uuid_value(value)
|
||||||
|
|
||||||
|
|
||||||
|
def reg(model: type[BaseModel]) -> None:
|
||||||
|
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||||
|
|
||||||
|
|
||||||
|
reg(AnnotationReplyPayload)
|
||||||
|
reg(AnnotationSettingUpdatePayload)
|
||||||
|
reg(AnnotationListQuery)
|
||||||
|
reg(CreateAnnotationPayload)
|
||||||
|
reg(UpdateAnnotationPayload)
|
||||||
|
reg(AnnotationReplyStatusQuery)
|
||||||
|
reg(AnnotationFilePayload)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:app_id>/annotation-reply/<string:action>")
|
@console_ns.route("/apps/<uuid:app_id>/annotation-reply/<string:action>")
|
||||||
class AnnotationReplyActionApi(Resource):
|
class AnnotationReplyActionApi(Resource):
|
||||||
@console_ns.doc("annotation_reply_action")
|
@console_ns.doc("annotation_reply_action")
|
||||||
@console_ns.doc(description="Enable or disable annotation reply for an app")
|
@console_ns.doc(description="Enable or disable annotation reply for an app")
|
||||||
@console_ns.doc(params={"app_id": "Application ID", "action": "Action to perform (enable/disable)"})
|
@console_ns.doc(params={"app_id": "Application ID", "action": "Action to perform (enable/disable)"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[AnnotationReplyPayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"AnnotationReplyActionRequest",
|
|
||||||
{
|
|
||||||
"score_threshold": fields.Float(required=True, description="Score threshold for annotation matching"),
|
|
||||||
"embedding_provider_name": fields.String(required=True, description="Embedding provider name"),
|
|
||||||
"embedding_model_name": fields.String(required=True, description="Embedding model name"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Action completed successfully")
|
@console_ns.response(200, "Action completed successfully")
|
||||||
@console_ns.response(403, "Insufficient permissions")
|
@console_ns.response(403, "Insufficient permissions")
|
||||||
@setup_required
|
@setup_required
|
||||||
|
|
@ -46,15 +106,9 @@ class AnnotationReplyActionApi(Resource):
|
||||||
@edit_permission_required
|
@edit_permission_required
|
||||||
def post(self, app_id, action: Literal["enable", "disable"]):
|
def post(self, app_id, action: Literal["enable", "disable"]):
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
parser = (
|
args = AnnotationReplyPayload.model_validate(console_ns.payload)
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("score_threshold", required=True, type=float, location="json")
|
|
||||||
.add_argument("embedding_provider_name", required=True, type=str, location="json")
|
|
||||||
.add_argument("embedding_model_name", required=True, type=str, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
if action == "enable":
|
if action == "enable":
|
||||||
result = AppAnnotationService.enable_app_annotation(args, app_id)
|
result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id)
|
||||||
elif action == "disable":
|
elif action == "disable":
|
||||||
result = AppAnnotationService.disable_app_annotation(app_id)
|
result = AppAnnotationService.disable_app_annotation(app_id)
|
||||||
return result, 200
|
return result, 200
|
||||||
|
|
@ -82,16 +136,7 @@ class AppAnnotationSettingUpdateApi(Resource):
|
||||||
@console_ns.doc("update_annotation_setting")
|
@console_ns.doc("update_annotation_setting")
|
||||||
@console_ns.doc(description="Update annotation settings for an app")
|
@console_ns.doc(description="Update annotation settings for an app")
|
||||||
@console_ns.doc(params={"app_id": "Application ID", "annotation_setting_id": "Annotation setting ID"})
|
@console_ns.doc(params={"app_id": "Application ID", "annotation_setting_id": "Annotation setting ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[AnnotationSettingUpdatePayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"AnnotationSettingUpdateRequest",
|
|
||||||
{
|
|
||||||
"score_threshold": fields.Float(required=True, description="Score threshold"),
|
|
||||||
"embedding_provider_name": fields.String(required=True, description="Embedding provider"),
|
|
||||||
"embedding_model_name": fields.String(required=True, description="Embedding model"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Settings updated successfully")
|
@console_ns.response(200, "Settings updated successfully")
|
||||||
@console_ns.response(403, "Insufficient permissions")
|
@console_ns.response(403, "Insufficient permissions")
|
||||||
@setup_required
|
@setup_required
|
||||||
|
|
@ -102,10 +147,9 @@ class AppAnnotationSettingUpdateApi(Resource):
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
annotation_setting_id = str(annotation_setting_id)
|
annotation_setting_id = str(annotation_setting_id)
|
||||||
|
|
||||||
parser = reqparse.RequestParser().add_argument("score_threshold", required=True, type=float, location="json")
|
args = AnnotationSettingUpdatePayload.model_validate(console_ns.payload)
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args)
|
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args.model_dump())
|
||||||
return result, 200
|
return result, 200
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -142,12 +186,7 @@ class AnnotationApi(Resource):
|
||||||
@console_ns.doc("list_annotations")
|
@console_ns.doc("list_annotations")
|
||||||
@console_ns.doc(description="Get annotations for an app with pagination")
|
@console_ns.doc(description="Get annotations for an app with pagination")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[AnnotationListQuery.__name__])
|
||||||
console_ns.parser()
|
|
||||||
.add_argument("page", type=int, location="args", default=1, help="Page number")
|
|
||||||
.add_argument("limit", type=int, location="args", default=20, help="Page size")
|
|
||||||
.add_argument("keyword", type=str, location="args", default="", help="Search keyword")
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Annotations retrieved successfully")
|
@console_ns.response(200, "Annotations retrieved successfully")
|
||||||
@console_ns.response(403, "Insufficient permissions")
|
@console_ns.response(403, "Insufficient permissions")
|
||||||
@setup_required
|
@setup_required
|
||||||
|
|
@ -155,9 +194,10 @@ class AnnotationApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@edit_permission_required
|
@edit_permission_required
|
||||||
def get(self, app_id):
|
def get(self, app_id):
|
||||||
page = request.args.get("page", default=1, type=int)
|
args = AnnotationListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
limit = request.args.get("limit", default=20, type=int)
|
page = args.page
|
||||||
keyword = request.args.get("keyword", default="", type=str)
|
limit = args.limit
|
||||||
|
keyword = args.keyword
|
||||||
|
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)
|
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)
|
||||||
|
|
@ -173,18 +213,7 @@ class AnnotationApi(Resource):
|
||||||
@console_ns.doc("create_annotation")
|
@console_ns.doc("create_annotation")
|
||||||
@console_ns.doc(description="Create a new annotation for an app")
|
@console_ns.doc(description="Create a new annotation for an app")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[CreateAnnotationPayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"CreateAnnotationRequest",
|
|
||||||
{
|
|
||||||
"message_id": fields.String(description="Message ID (optional)"),
|
|
||||||
"question": fields.String(description="Question text (required when message_id not provided)"),
|
|
||||||
"answer": fields.String(description="Answer text (use 'answer' or 'content')"),
|
|
||||||
"content": fields.String(description="Content text (use 'answer' or 'content')"),
|
|
||||||
"annotation_reply": fields.Raw(description="Annotation reply data"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(201, "Annotation created successfully", build_annotation_model(console_ns))
|
@console_ns.response(201, "Annotation created successfully", build_annotation_model(console_ns))
|
||||||
@console_ns.response(403, "Insufficient permissions")
|
@console_ns.response(403, "Insufficient permissions")
|
||||||
@setup_required
|
@setup_required
|
||||||
|
|
@ -195,16 +224,9 @@ class AnnotationApi(Resource):
|
||||||
@edit_permission_required
|
@edit_permission_required
|
||||||
def post(self, app_id):
|
def post(self, app_id):
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
parser = (
|
args = CreateAnnotationPayload.model_validate(console_ns.payload)
|
||||||
reqparse.RequestParser()
|
data = args.model_dump(exclude_none=True)
|
||||||
.add_argument("message_id", required=False, type=uuid_value, location="json")
|
annotation = AppAnnotationService.up_insert_app_annotation_from_message(data, app_id)
|
||||||
.add_argument("question", required=False, type=str, location="json")
|
|
||||||
.add_argument("answer", required=False, type=str, location="json")
|
|
||||||
.add_argument("content", required=False, type=str, location="json")
|
|
||||||
.add_argument("annotation_reply", required=False, type=dict, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id)
|
|
||||||
return annotation
|
return annotation
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
|
|
@ -237,7 +259,7 @@ class AnnotationApi(Resource):
|
||||||
@console_ns.route("/apps/<uuid:app_id>/annotations/export")
|
@console_ns.route("/apps/<uuid:app_id>/annotations/export")
|
||||||
class AnnotationExportApi(Resource):
|
class AnnotationExportApi(Resource):
|
||||||
@console_ns.doc("export_annotations")
|
@console_ns.doc("export_annotations")
|
||||||
@console_ns.doc(description="Export all annotations for an app")
|
@console_ns.doc(description="Export all annotations for an app with CSV injection protection")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.response(
|
@console_ns.response(
|
||||||
200,
|
200,
|
||||||
|
|
@ -252,15 +274,14 @@ class AnnotationExportApi(Resource):
|
||||||
def get(self, app_id):
|
def get(self, app_id):
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
|
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
|
||||||
response = {"data": marshal(annotation_list, annotation_fields)}
|
response_data = {"data": marshal(annotation_list, annotation_fields)}
|
||||||
return response, 200
|
|
||||||
|
|
||||||
|
# Create response with secure headers for CSV export
|
||||||
|
response = make_response(response_data, 200)
|
||||||
|
response.headers["Content-Type"] = "application/json; charset=utf-8"
|
||||||
|
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||||
|
|
||||||
parser = (
|
return response
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("question", required=True, type=str, location="json")
|
|
||||||
.add_argument("answer", required=True, type=str, location="json")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>")
|
@console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>")
|
||||||
|
|
@ -271,7 +292,7 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||||
@console_ns.response(200, "Annotation updated successfully", build_annotation_model(console_ns))
|
@console_ns.response(200, "Annotation updated successfully", build_annotation_model(console_ns))
|
||||||
@console_ns.response(204, "Annotation deleted successfully")
|
@console_ns.response(204, "Annotation deleted successfully")
|
||||||
@console_ns.response(403, "Insufficient permissions")
|
@console_ns.response(403, "Insufficient permissions")
|
||||||
@console_ns.expect(parser)
|
@console_ns.expect(console_ns.models[UpdateAnnotationPayload.__name__])
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
|
@ -281,8 +302,10 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||||
def post(self, app_id, annotation_id):
|
def post(self, app_id, annotation_id):
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
annotation_id = str(annotation_id)
|
annotation_id = str(annotation_id)
|
||||||
args = parser.parse_args()
|
args = UpdateAnnotationPayload.model_validate(console_ns.payload)
|
||||||
annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id)
|
annotation = AppAnnotationService.update_app_annotation_directly(
|
||||||
|
args.model_dump(exclude_none=True), app_id, annotation_id
|
||||||
|
)
|
||||||
return annotation
|
return annotation
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
|
|
@ -299,18 +322,25 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||||
@console_ns.route("/apps/<uuid:app_id>/annotations/batch-import")
|
@console_ns.route("/apps/<uuid:app_id>/annotations/batch-import")
|
||||||
class AnnotationBatchImportApi(Resource):
|
class AnnotationBatchImportApi(Resource):
|
||||||
@console_ns.doc("batch_import_annotations")
|
@console_ns.doc("batch_import_annotations")
|
||||||
@console_ns.doc(description="Batch import annotations from CSV file")
|
@console_ns.doc(description="Batch import annotations from CSV file with rate limiting and security checks")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.response(200, "Batch import started successfully")
|
@console_ns.response(200, "Batch import started successfully")
|
||||||
@console_ns.response(403, "Insufficient permissions")
|
@console_ns.response(403, "Insufficient permissions")
|
||||||
@console_ns.response(400, "No file uploaded or too many files")
|
@console_ns.response(400, "No file uploaded or too many files")
|
||||||
|
@console_ns.response(413, "File too large")
|
||||||
|
@console_ns.response(429, "Too many requests or concurrent imports")
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_resource_check("annotation")
|
@cloud_edition_billing_resource_check("annotation")
|
||||||
|
@annotation_import_rate_limit
|
||||||
|
@annotation_import_concurrency_limit
|
||||||
@edit_permission_required
|
@edit_permission_required
|
||||||
def post(self, app_id):
|
def post(self, app_id):
|
||||||
|
from configs import dify_config
|
||||||
|
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
|
|
||||||
# check file
|
# check file
|
||||||
if "file" not in request.files:
|
if "file" not in request.files:
|
||||||
raise NoFileUploadedError()
|
raise NoFileUploadedError()
|
||||||
|
|
@ -320,9 +350,27 @@ class AnnotationBatchImportApi(Resource):
|
||||||
|
|
||||||
# get file from request
|
# get file from request
|
||||||
file = request.files["file"]
|
file = request.files["file"]
|
||||||
|
|
||||||
# check file type
|
# check file type
|
||||||
if not file.filename or not file.filename.lower().endswith(".csv"):
|
if not file.filename or not file.filename.lower().endswith(".csv"):
|
||||||
raise ValueError("Invalid file type. Only CSV files are allowed")
|
raise ValueError("Invalid file type. Only CSV files are allowed")
|
||||||
|
|
||||||
|
# Check file size before processing
|
||||||
|
file.seek(0, 2) # Seek to end of file
|
||||||
|
file_size = file.tell()
|
||||||
|
file.seek(0) # Reset to beginning
|
||||||
|
|
||||||
|
max_size_bytes = dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT * 1024 * 1024
|
||||||
|
if file_size > max_size_bytes:
|
||||||
|
abort(
|
||||||
|
413,
|
||||||
|
f"File size exceeds maximum limit of {dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT}MB. "
|
||||||
|
f"Please reduce the file size and try again.",
|
||||||
|
)
|
||||||
|
|
||||||
|
if file_size == 0:
|
||||||
|
raise ValueError("The uploaded file is empty")
|
||||||
|
|
||||||
return AppAnnotationService.batch_import_app_annotations(app_id, file)
|
return AppAnnotationService.batch_import_app_annotations(app_id, file)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,12 @@
|
||||||
import uuid
|
import uuid
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
from flask import request
|
||||||
|
from flask_restx import Resource, fields, marshal, marshal_with
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from werkzeug.exceptions import BadRequest, abort
|
from werkzeug.exceptions import BadRequest
|
||||||
|
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
from controllers.console.app.wraps import get_app_model
|
from controllers.console.app.wraps import get_app_model
|
||||||
|
|
@ -28,7 +31,6 @@ from fields.app_fields import (
|
||||||
from fields.workflow_fields import workflow_partial_fields as _workflow_partial_fields_dict
|
from fields.workflow_fields import workflow_partial_fields as _workflow_partial_fields_dict
|
||||||
from libs.helper import AppIconUrlField, TimestampField
|
from libs.helper import AppIconUrlField, TimestampField
|
||||||
from libs.login import current_account_with_tenant, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from libs.validators import validate_description_length
|
|
||||||
from models import App, Workflow
|
from models import App, Workflow
|
||||||
from services.app_dsl_service import AppDslService, ImportMode
|
from services.app_dsl_service import AppDslService, ImportMode
|
||||||
from services.app_service import AppService
|
from services.app_service import AppService
|
||||||
|
|
@ -36,6 +38,116 @@ from services.enterprise.enterprise_service import EnterpriseService
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
|
|
||||||
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
|
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
|
||||||
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
|
|
||||||
|
class AppListQuery(BaseModel):
|
||||||
|
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
|
||||||
|
limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
|
||||||
|
mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"] = Field(
|
||||||
|
default="all", description="App mode filter"
|
||||||
|
)
|
||||||
|
name: str | None = Field(default=None, description="Filter by app name")
|
||||||
|
tag_ids: list[str] | None = Field(default=None, description="Comma-separated tag IDs")
|
||||||
|
is_created_by_me: bool | None = Field(default=None, description="Filter by creator")
|
||||||
|
|
||||||
|
@field_validator("tag_ids", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def validate_tag_ids(cls, value: str | list[str] | None) -> list[str] | None:
|
||||||
|
if not value:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if isinstance(value, str):
|
||||||
|
items = [item.strip() for item in value.split(",") if item.strip()]
|
||||||
|
elif isinstance(value, list):
|
||||||
|
items = [str(item).strip() for item in value if item and str(item).strip()]
|
||||||
|
else:
|
||||||
|
raise TypeError("Unsupported tag_ids type.")
|
||||||
|
|
||||||
|
if not items:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
return [str(uuid.UUID(item)) for item in items]
|
||||||
|
except ValueError as exc:
|
||||||
|
raise ValueError("Invalid UUID format in tag_ids.") from exc
|
||||||
|
|
||||||
|
|
||||||
|
class CreateAppPayload(BaseModel):
|
||||||
|
name: str = Field(..., min_length=1, description="App name")
|
||||||
|
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
|
||||||
|
mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode")
|
||||||
|
icon_type: str | None = Field(default=None, description="Icon type")
|
||||||
|
icon: str | None = Field(default=None, description="Icon")
|
||||||
|
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateAppPayload(BaseModel):
|
||||||
|
name: str = Field(..., min_length=1, description="App name")
|
||||||
|
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
|
||||||
|
icon_type: str | None = Field(default=None, description="Icon type")
|
||||||
|
icon: str | None = Field(default=None, description="Icon")
|
||||||
|
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||||
|
use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon")
|
||||||
|
max_active_requests: int | None = Field(default=None, description="Maximum active requests")
|
||||||
|
|
||||||
|
|
||||||
|
class CopyAppPayload(BaseModel):
|
||||||
|
name: str | None = Field(default=None, description="Name for the copied app")
|
||||||
|
description: str | None = Field(default=None, description="Description for the copied app", max_length=400)
|
||||||
|
icon_type: str | None = Field(default=None, description="Icon type")
|
||||||
|
icon: str | None = Field(default=None, description="Icon")
|
||||||
|
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||||
|
|
||||||
|
|
||||||
|
class AppExportQuery(BaseModel):
|
||||||
|
include_secret: bool = Field(default=False, description="Include secrets in export")
|
||||||
|
workflow_id: str | None = Field(default=None, description="Specific workflow ID to export")
|
||||||
|
|
||||||
|
|
||||||
|
class AppNamePayload(BaseModel):
|
||||||
|
name: str = Field(..., min_length=1, description="Name to check")
|
||||||
|
|
||||||
|
|
||||||
|
class AppIconPayload(BaseModel):
|
||||||
|
icon: str | None = Field(default=None, description="Icon data")
|
||||||
|
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||||
|
|
||||||
|
|
||||||
|
class AppSiteStatusPayload(BaseModel):
|
||||||
|
enable_site: bool = Field(..., description="Enable or disable site")
|
||||||
|
|
||||||
|
|
||||||
|
class AppApiStatusPayload(BaseModel):
|
||||||
|
enable_api: bool = Field(..., description="Enable or disable API")
|
||||||
|
|
||||||
|
|
||||||
|
class AppTracePayload(BaseModel):
|
||||||
|
enabled: bool = Field(..., description="Enable or disable tracing")
|
||||||
|
tracing_provider: str | None = Field(default=None, description="Tracing provider")
|
||||||
|
|
||||||
|
@field_validator("tracing_provider")
|
||||||
|
@classmethod
|
||||||
|
def validate_tracing_provider(cls, value: str | None, info) -> str | None:
|
||||||
|
if info.data.get("enabled") and not value:
|
||||||
|
raise ValueError("tracing_provider is required when enabled is True")
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def reg(cls: type[BaseModel]):
|
||||||
|
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||||
|
|
||||||
|
|
||||||
|
reg(AppListQuery)
|
||||||
|
reg(CreateAppPayload)
|
||||||
|
reg(UpdateAppPayload)
|
||||||
|
reg(CopyAppPayload)
|
||||||
|
reg(AppExportQuery)
|
||||||
|
reg(AppNamePayload)
|
||||||
|
reg(AppIconPayload)
|
||||||
|
reg(AppSiteStatusPayload)
|
||||||
|
reg(AppApiStatusPayload)
|
||||||
|
reg(AppTracePayload)
|
||||||
|
|
||||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||||
# Register base models first
|
# Register base models first
|
||||||
|
|
@ -147,22 +259,7 @@ app_pagination_model = console_ns.model(
|
||||||
class AppListApi(Resource):
|
class AppListApi(Resource):
|
||||||
@console_ns.doc("list_apps")
|
@console_ns.doc("list_apps")
|
||||||
@console_ns.doc(description="Get list of applications with pagination and filtering")
|
@console_ns.doc(description="Get list of applications with pagination and filtering")
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[AppListQuery.__name__])
|
||||||
console_ns.parser()
|
|
||||||
.add_argument("page", type=int, location="args", help="Page number (1-99999)", default=1)
|
|
||||||
.add_argument("limit", type=int, location="args", help="Page size (1-100)", default=20)
|
|
||||||
.add_argument(
|
|
||||||
"mode",
|
|
||||||
type=str,
|
|
||||||
location="args",
|
|
||||||
choices=["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"],
|
|
||||||
default="all",
|
|
||||||
help="App mode filter",
|
|
||||||
)
|
|
||||||
.add_argument("name", type=str, location="args", help="Filter by app name")
|
|
||||||
.add_argument("tag_ids", type=str, location="args", help="Comma-separated tag IDs")
|
|
||||||
.add_argument("is_created_by_me", type=bool, location="args", help="Filter by creator")
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Success", app_pagination_model)
|
@console_ns.response(200, "Success", app_pagination_model)
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
|
@ -172,42 +269,12 @@ class AppListApi(Resource):
|
||||||
"""Get app list"""
|
"""Get app list"""
|
||||||
current_user, current_tenant_id = current_account_with_tenant()
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
def uuid_list(value):
|
args = AppListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
try:
|
args_dict = args.model_dump()
|
||||||
return [str(uuid.UUID(v)) for v in value.split(",")]
|
|
||||||
except ValueError:
|
|
||||||
abort(400, message="Invalid UUID format in tag_ids.")
|
|
||||||
|
|
||||||
parser = (
|
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
|
|
||||||
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
|
||||||
.add_argument(
|
|
||||||
"mode",
|
|
||||||
type=str,
|
|
||||||
choices=[
|
|
||||||
"completion",
|
|
||||||
"chat",
|
|
||||||
"advanced-chat",
|
|
||||||
"workflow",
|
|
||||||
"agent-chat",
|
|
||||||
"channel",
|
|
||||||
"all",
|
|
||||||
],
|
|
||||||
default="all",
|
|
||||||
location="args",
|
|
||||||
required=False,
|
|
||||||
)
|
|
||||||
.add_argument("name", type=str, location="args", required=False)
|
|
||||||
.add_argument("tag_ids", type=uuid_list, location="args", required=False)
|
|
||||||
.add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False)
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# get app list
|
# get app list
|
||||||
app_service = AppService()
|
app_service = AppService()
|
||||||
app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args)
|
app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args_dict)
|
||||||
if not app_pagination:
|
if not app_pagination:
|
||||||
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
|
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
|
||||||
|
|
||||||
|
|
@ -242,10 +309,13 @@ class AppListApi(Resource):
|
||||||
NodeType.TRIGGER_PLUGIN,
|
NodeType.TRIGGER_PLUGIN,
|
||||||
}
|
}
|
||||||
for workflow in draft_workflows:
|
for workflow in draft_workflows:
|
||||||
for _, node_data in workflow.walk_nodes():
|
try:
|
||||||
if node_data.get("type") in trigger_node_types:
|
for _, node_data in workflow.walk_nodes():
|
||||||
draft_trigger_app_ids.add(str(workflow.app_id))
|
if node_data.get("type") in trigger_node_types:
|
||||||
break
|
draft_trigger_app_ids.add(str(workflow.app_id))
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
for app in app_pagination.items:
|
for app in app_pagination.items:
|
||||||
app.has_draft_trigger = str(app.id) in draft_trigger_app_ids
|
app.has_draft_trigger = str(app.id) in draft_trigger_app_ids
|
||||||
|
|
@ -254,19 +324,7 @@ class AppListApi(Resource):
|
||||||
|
|
||||||
@console_ns.doc("create_app")
|
@console_ns.doc("create_app")
|
||||||
@console_ns.doc(description="Create a new application")
|
@console_ns.doc(description="Create a new application")
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[CreateAppPayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"CreateAppRequest",
|
|
||||||
{
|
|
||||||
"name": fields.String(required=True, description="App name"),
|
|
||||||
"description": fields.String(description="App description (max 400 chars)"),
|
|
||||||
"mode": fields.String(required=True, enum=ALLOW_CREATE_APP_MODES, description="App mode"),
|
|
||||||
"icon_type": fields.String(description="Icon type"),
|
|
||||||
"icon": fields.String(description="Icon"),
|
|
||||||
"icon_background": fields.String(description="Icon background color"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(201, "App created successfully", app_detail_model)
|
@console_ns.response(201, "App created successfully", app_detail_model)
|
||||||
@console_ns.response(403, "Insufficient permissions")
|
@console_ns.response(403, "Insufficient permissions")
|
||||||
@console_ns.response(400, "Invalid request parameters")
|
@console_ns.response(400, "Invalid request parameters")
|
||||||
|
|
@ -279,22 +337,10 @@ class AppListApi(Resource):
|
||||||
def post(self):
|
def post(self):
|
||||||
"""Create app"""
|
"""Create app"""
|
||||||
current_user, current_tenant_id = current_account_with_tenant()
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
parser = (
|
args = CreateAppPayload.model_validate(console_ns.payload)
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("name", type=str, required=True, location="json")
|
|
||||||
.add_argument("description", type=validate_description_length, location="json")
|
|
||||||
.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json")
|
|
||||||
.add_argument("icon_type", type=str, location="json")
|
|
||||||
.add_argument("icon", type=str, location="json")
|
|
||||||
.add_argument("icon_background", type=str, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
if "mode" not in args or args["mode"] is None:
|
|
||||||
raise BadRequest("mode is required")
|
|
||||||
|
|
||||||
app_service = AppService()
|
app_service = AppService()
|
||||||
app = app_service.create_app(current_tenant_id, args, current_user)
|
app = app_service.create_app(current_tenant_id, args.model_dump(), current_user)
|
||||||
|
|
||||||
return app, 201
|
return app, 201
|
||||||
|
|
||||||
|
|
@ -326,20 +372,7 @@ class AppApi(Resource):
|
||||||
@console_ns.doc("update_app")
|
@console_ns.doc("update_app")
|
||||||
@console_ns.doc(description="Update application details")
|
@console_ns.doc(description="Update application details")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[UpdateAppPayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"UpdateAppRequest",
|
|
||||||
{
|
|
||||||
"name": fields.String(required=True, description="App name"),
|
|
||||||
"description": fields.String(description="App description (max 400 chars)"),
|
|
||||||
"icon_type": fields.String(description="Icon type"),
|
|
||||||
"icon": fields.String(description="Icon"),
|
|
||||||
"icon_background": fields.String(description="Icon background color"),
|
|
||||||
"use_icon_as_answer_icon": fields.Boolean(description="Use icon as answer icon"),
|
|
||||||
"max_active_requests": fields.Integer(description="Maximum active requests"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "App updated successfully", app_detail_with_site_model)
|
@console_ns.response(200, "App updated successfully", app_detail_with_site_model)
|
||||||
@console_ns.response(403, "Insufficient permissions")
|
@console_ns.response(403, "Insufficient permissions")
|
||||||
@console_ns.response(400, "Invalid request parameters")
|
@console_ns.response(400, "Invalid request parameters")
|
||||||
|
|
@ -351,28 +384,18 @@ class AppApi(Resource):
|
||||||
@marshal_with(app_detail_with_site_model)
|
@marshal_with(app_detail_with_site_model)
|
||||||
def put(self, app_model):
|
def put(self, app_model):
|
||||||
"""Update app"""
|
"""Update app"""
|
||||||
parser = (
|
args = UpdateAppPayload.model_validate(console_ns.payload)
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("name", type=str, required=True, nullable=False, location="json")
|
|
||||||
.add_argument("description", type=validate_description_length, location="json")
|
|
||||||
.add_argument("icon_type", type=str, location="json")
|
|
||||||
.add_argument("icon", type=str, location="json")
|
|
||||||
.add_argument("icon_background", type=str, location="json")
|
|
||||||
.add_argument("use_icon_as_answer_icon", type=bool, location="json")
|
|
||||||
.add_argument("max_active_requests", type=int, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
app_service = AppService()
|
app_service = AppService()
|
||||||
|
|
||||||
args_dict: AppService.ArgsDict = {
|
args_dict: AppService.ArgsDict = {
|
||||||
"name": args["name"],
|
"name": args.name,
|
||||||
"description": args.get("description", ""),
|
"description": args.description or "",
|
||||||
"icon_type": args.get("icon_type", ""),
|
"icon_type": args.icon_type or "",
|
||||||
"icon": args.get("icon", ""),
|
"icon": args.icon or "",
|
||||||
"icon_background": args.get("icon_background", ""),
|
"icon_background": args.icon_background or "",
|
||||||
"use_icon_as_answer_icon": args.get("use_icon_as_answer_icon", False),
|
"use_icon_as_answer_icon": args.use_icon_as_answer_icon or False,
|
||||||
"max_active_requests": args.get("max_active_requests", 0),
|
"max_active_requests": args.max_active_requests or 0,
|
||||||
}
|
}
|
||||||
app_model = app_service.update_app(app_model, args_dict)
|
app_model = app_service.update_app(app_model, args_dict)
|
||||||
|
|
||||||
|
|
@ -401,18 +424,7 @@ class AppCopyApi(Resource):
|
||||||
@console_ns.doc("copy_app")
|
@console_ns.doc("copy_app")
|
||||||
@console_ns.doc(description="Create a copy of an existing application")
|
@console_ns.doc(description="Create a copy of an existing application")
|
||||||
@console_ns.doc(params={"app_id": "Application ID to copy"})
|
@console_ns.doc(params={"app_id": "Application ID to copy"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[CopyAppPayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"CopyAppRequest",
|
|
||||||
{
|
|
||||||
"name": fields.String(description="Name for the copied app"),
|
|
||||||
"description": fields.String(description="Description for the copied app"),
|
|
||||||
"icon_type": fields.String(description="Icon type"),
|
|
||||||
"icon": fields.String(description="Icon"),
|
|
||||||
"icon_background": fields.String(description="Icon background color"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(201, "App copied successfully", app_detail_with_site_model)
|
@console_ns.response(201, "App copied successfully", app_detail_with_site_model)
|
||||||
@console_ns.response(403, "Insufficient permissions")
|
@console_ns.response(403, "Insufficient permissions")
|
||||||
@setup_required
|
@setup_required
|
||||||
|
|
@ -426,15 +438,7 @@ class AppCopyApi(Resource):
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
current_user, _ = current_account_with_tenant()
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
parser = (
|
args = CopyAppPayload.model_validate(console_ns.payload or {})
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("name", type=str, location="json")
|
|
||||||
.add_argument("description", type=validate_description_length, location="json")
|
|
||||||
.add_argument("icon_type", type=str, location="json")
|
|
||||||
.add_argument("icon", type=str, location="json")
|
|
||||||
.add_argument("icon_background", type=str, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
import_service = AppDslService(session)
|
import_service = AppDslService(session)
|
||||||
|
|
@ -443,11 +447,11 @@ class AppCopyApi(Resource):
|
||||||
account=current_user,
|
account=current_user,
|
||||||
import_mode=ImportMode.YAML_CONTENT,
|
import_mode=ImportMode.YAML_CONTENT,
|
||||||
yaml_content=yaml_content,
|
yaml_content=yaml_content,
|
||||||
name=args.get("name"),
|
name=args.name,
|
||||||
description=args.get("description"),
|
description=args.description,
|
||||||
icon_type=args.get("icon_type"),
|
icon_type=args.icon_type,
|
||||||
icon=args.get("icon"),
|
icon=args.icon,
|
||||||
icon_background=args.get("icon_background"),
|
icon_background=args.icon_background,
|
||||||
)
|
)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
|
|
@ -462,11 +466,7 @@ class AppExportApi(Resource):
|
||||||
@console_ns.doc("export_app")
|
@console_ns.doc("export_app")
|
||||||
@console_ns.doc(description="Export application configuration as DSL")
|
@console_ns.doc(description="Export application configuration as DSL")
|
||||||
@console_ns.doc(params={"app_id": "Application ID to export"})
|
@console_ns.doc(params={"app_id": "Application ID to export"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[AppExportQuery.__name__])
|
||||||
console_ns.parser()
|
|
||||||
.add_argument("include_secret", type=bool, location="args", default=False, help="Include secrets in export")
|
|
||||||
.add_argument("workflow_id", type=str, location="args", help="Specific workflow ID to export")
|
|
||||||
)
|
|
||||||
@console_ns.response(
|
@console_ns.response(
|
||||||
200,
|
200,
|
||||||
"App exported successfully",
|
"App exported successfully",
|
||||||
|
|
@ -480,30 +480,23 @@ class AppExportApi(Resource):
|
||||||
@edit_permission_required
|
@edit_permission_required
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
"""Export app"""
|
"""Export app"""
|
||||||
# Add include_secret params
|
args = AppExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
parser = (
|
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("include_secret", type=inputs.boolean, default=False, location="args")
|
|
||||||
.add_argument("workflow_id", type=str, location="args")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"data": AppDslService.export_dsl(
|
"data": AppDslService.export_dsl(
|
||||||
app_model=app_model, include_secret=args["include_secret"], workflow_id=args.get("workflow_id")
|
app_model=app_model,
|
||||||
|
include_secret=args.include_secret,
|
||||||
|
workflow_id=args.workflow_id,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json", help="Name to check")
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:app_id>/name")
|
@console_ns.route("/apps/<uuid:app_id>/name")
|
||||||
class AppNameApi(Resource):
|
class AppNameApi(Resource):
|
||||||
@console_ns.doc("check_app_name")
|
@console_ns.doc("check_app_name")
|
||||||
@console_ns.doc(description="Check if app name is available")
|
@console_ns.doc(description="Check if app name is available")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.expect(parser)
|
@console_ns.expect(console_ns.models[AppNamePayload.__name__])
|
||||||
@console_ns.response(200, "Name availability checked")
|
@console_ns.response(200, "Name availability checked")
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
|
@ -512,10 +505,10 @@ class AppNameApi(Resource):
|
||||||
@marshal_with(app_detail_model)
|
@marshal_with(app_detail_model)
|
||||||
@edit_permission_required
|
@edit_permission_required
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
args = parser.parse_args()
|
args = AppNamePayload.model_validate(console_ns.payload)
|
||||||
|
|
||||||
app_service = AppService()
|
app_service = AppService()
|
||||||
app_model = app_service.update_app_name(app_model, args["name"])
|
app_model = app_service.update_app_name(app_model, args.name)
|
||||||
|
|
||||||
return app_model
|
return app_model
|
||||||
|
|
||||||
|
|
@ -525,16 +518,7 @@ class AppIconApi(Resource):
|
||||||
@console_ns.doc("update_app_icon")
|
@console_ns.doc("update_app_icon")
|
||||||
@console_ns.doc(description="Update application icon")
|
@console_ns.doc(description="Update application icon")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[AppIconPayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"AppIconRequest",
|
|
||||||
{
|
|
||||||
"icon": fields.String(required=True, description="Icon data"),
|
|
||||||
"icon_type": fields.String(description="Icon type"),
|
|
||||||
"icon_background": fields.String(description="Icon background color"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Icon updated successfully")
|
@console_ns.response(200, "Icon updated successfully")
|
||||||
@console_ns.response(403, "Insufficient permissions")
|
@console_ns.response(403, "Insufficient permissions")
|
||||||
@setup_required
|
@setup_required
|
||||||
|
|
@ -544,15 +528,10 @@ class AppIconApi(Resource):
|
||||||
@marshal_with(app_detail_model)
|
@marshal_with(app_detail_model)
|
||||||
@edit_permission_required
|
@edit_permission_required
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
parser = (
|
args = AppIconPayload.model_validate(console_ns.payload or {})
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("icon", type=str, location="json")
|
|
||||||
.add_argument("icon_background", type=str, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
app_service = AppService()
|
app_service = AppService()
|
||||||
app_model = app_service.update_app_icon(app_model, args.get("icon") or "", args.get("icon_background") or "")
|
app_model = app_service.update_app_icon(app_model, args.icon or "", args.icon_background or "")
|
||||||
|
|
||||||
return app_model
|
return app_model
|
||||||
|
|
||||||
|
|
@ -562,11 +541,7 @@ class AppSiteStatus(Resource):
|
||||||
@console_ns.doc("update_app_site_status")
|
@console_ns.doc("update_app_site_status")
|
||||||
@console_ns.doc(description="Enable or disable app site")
|
@console_ns.doc(description="Enable or disable app site")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[AppSiteStatusPayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"AppSiteStatusRequest", {"enable_site": fields.Boolean(required=True, description="Enable or disable site")}
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Site status updated successfully", app_detail_model)
|
@console_ns.response(200, "Site status updated successfully", app_detail_model)
|
||||||
@console_ns.response(403, "Insufficient permissions")
|
@console_ns.response(403, "Insufficient permissions")
|
||||||
@setup_required
|
@setup_required
|
||||||
|
|
@ -576,11 +551,10 @@ class AppSiteStatus(Resource):
|
||||||
@marshal_with(app_detail_model)
|
@marshal_with(app_detail_model)
|
||||||
@edit_permission_required
|
@edit_permission_required
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
parser = reqparse.RequestParser().add_argument("enable_site", type=bool, required=True, location="json")
|
args = AppSiteStatusPayload.model_validate(console_ns.payload)
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
app_service = AppService()
|
app_service = AppService()
|
||||||
app_model = app_service.update_app_site_status(app_model, args["enable_site"])
|
app_model = app_service.update_app_site_status(app_model, args.enable_site)
|
||||||
|
|
||||||
return app_model
|
return app_model
|
||||||
|
|
||||||
|
|
@ -590,11 +564,7 @@ class AppApiStatus(Resource):
|
||||||
@console_ns.doc("update_app_api_status")
|
@console_ns.doc("update_app_api_status")
|
||||||
@console_ns.doc(description="Enable or disable app API")
|
@console_ns.doc(description="Enable or disable app API")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[AppApiStatusPayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"AppApiStatusRequest", {"enable_api": fields.Boolean(required=True, description="Enable or disable API")}
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "API status updated successfully", app_detail_model)
|
@console_ns.response(200, "API status updated successfully", app_detail_model)
|
||||||
@console_ns.response(403, "Insufficient permissions")
|
@console_ns.response(403, "Insufficient permissions")
|
||||||
@setup_required
|
@setup_required
|
||||||
|
|
@ -604,11 +574,10 @@ class AppApiStatus(Resource):
|
||||||
@get_app_model
|
@get_app_model
|
||||||
@marshal_with(app_detail_model)
|
@marshal_with(app_detail_model)
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
parser = reqparse.RequestParser().add_argument("enable_api", type=bool, required=True, location="json")
|
args = AppApiStatusPayload.model_validate(console_ns.payload)
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
app_service = AppService()
|
app_service = AppService()
|
||||||
app_model = app_service.update_app_api_status(app_model, args["enable_api"])
|
app_model = app_service.update_app_api_status(app_model, args.enable_api)
|
||||||
|
|
||||||
return app_model
|
return app_model
|
||||||
|
|
||||||
|
|
@ -631,15 +600,7 @@ class AppTraceApi(Resource):
|
||||||
@console_ns.doc("update_app_trace")
|
@console_ns.doc("update_app_trace")
|
||||||
@console_ns.doc(description="Update app tracing configuration")
|
@console_ns.doc(description="Update app tracing configuration")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[AppTracePayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"AppTraceRequest",
|
|
||||||
{
|
|
||||||
"enabled": fields.Boolean(required=True, description="Enable or disable tracing"),
|
|
||||||
"tracing_provider": fields.String(required=True, description="Tracing provider"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Trace configuration updated successfully")
|
@console_ns.response(200, "Trace configuration updated successfully")
|
||||||
@console_ns.response(403, "Insufficient permissions")
|
@console_ns.response(403, "Insufficient permissions")
|
||||||
@setup_required
|
@setup_required
|
||||||
|
|
@ -648,17 +609,12 @@ class AppTraceApi(Resource):
|
||||||
@edit_permission_required
|
@edit_permission_required
|
||||||
def post(self, app_id):
|
def post(self, app_id):
|
||||||
# add app trace
|
# add app trace
|
||||||
parser = (
|
args = AppTracePayload.model_validate(console_ns.payload)
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("enabled", type=bool, required=True, location="json")
|
|
||||||
.add_argument("tracing_provider", type=str, required=True, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
OpsTraceManager.update_app_tracing_config(
|
OpsTraceManager.update_app_tracing_config(
|
||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
enabled=args["enabled"],
|
enabled=args.enabled,
|
||||||
tracing_provider=args["tracing_provider"],
|
tracing_provider=args.tracing_provider,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
from flask_restx import Resource, fields, marshal_with
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from controllers.console.app.wraps import get_app_model
|
from controllers.console.app.wraps import get_app_model
|
||||||
|
|
@ -35,23 +36,29 @@ app_import_check_dependencies_model = console_ns.model(
|
||||||
"AppImportCheckDependencies", app_import_check_dependencies_fields_copy
|
"AppImportCheckDependencies", app_import_check_dependencies_fields_copy
|
||||||
)
|
)
|
||||||
|
|
||||||
parser = (
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("mode", type=str, required=True, location="json")
|
|
||||||
.add_argument("yaml_content", type=str, location="json")
|
class AppImportPayload(BaseModel):
|
||||||
.add_argument("yaml_url", type=str, location="json")
|
mode: str = Field(..., description="Import mode")
|
||||||
.add_argument("name", type=str, location="json")
|
yaml_content: str | None = None
|
||||||
.add_argument("description", type=str, location="json")
|
yaml_url: str | None = None
|
||||||
.add_argument("icon_type", type=str, location="json")
|
name: str | None = None
|
||||||
.add_argument("icon", type=str, location="json")
|
description: str | None = None
|
||||||
.add_argument("icon_background", type=str, location="json")
|
icon_type: str | None = None
|
||||||
.add_argument("app_id", type=str, location="json")
|
icon: str | None = None
|
||||||
|
icon_background: str | None = None
|
||||||
|
app_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
console_ns.schema_model(
|
||||||
|
AppImportPayload.__name__, AppImportPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/imports")
|
@console_ns.route("/apps/imports")
|
||||||
class AppImportApi(Resource):
|
class AppImportApi(Resource):
|
||||||
@console_ns.expect(parser)
|
@console_ns.expect(console_ns.models[AppImportPayload.__name__])
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
|
@ -61,7 +68,7 @@ class AppImportApi(Resource):
|
||||||
def post(self):
|
def post(self):
|
||||||
# Check user role first
|
# Check user role first
|
||||||
current_user, _ = current_account_with_tenant()
|
current_user, _ = current_account_with_tenant()
|
||||||
args = parser.parse_args()
|
args = AppImportPayload.model_validate(console_ns.payload)
|
||||||
|
|
||||||
# Create service with session
|
# Create service with session
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
|
|
@ -70,15 +77,15 @@ class AppImportApi(Resource):
|
||||||
account = current_user
|
account = current_user
|
||||||
result = import_service.import_app(
|
result = import_service.import_app(
|
||||||
account=account,
|
account=account,
|
||||||
import_mode=args["mode"],
|
import_mode=args.mode,
|
||||||
yaml_content=args.get("yaml_content"),
|
yaml_content=args.yaml_content,
|
||||||
yaml_url=args.get("yaml_url"),
|
yaml_url=args.yaml_url,
|
||||||
name=args.get("name"),
|
name=args.name,
|
||||||
description=args.get("description"),
|
description=args.description,
|
||||||
icon_type=args.get("icon_type"),
|
icon_type=args.icon_type,
|
||||||
icon=args.get("icon"),
|
icon=args.icon,
|
||||||
icon_background=args.get("icon_background"),
|
icon_background=args.icon_background,
|
||||||
app_id=args.get("app_id"),
|
app_id=args.app_id,
|
||||||
)
|
)
|
||||||
session.commit()
|
session.commit()
|
||||||
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:
|
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,8 @@
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource, fields, reqparse
|
from flask_restx import Resource, fields
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
from werkzeug.exceptions import InternalServerError
|
from werkzeug.exceptions import InternalServerError
|
||||||
|
|
||||||
import services
|
import services
|
||||||
|
|
@ -32,6 +33,27 @@ from services.errors.audio import (
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
|
|
||||||
|
class TextToSpeechPayload(BaseModel):
|
||||||
|
message_id: str | None = Field(default=None, description="Message ID")
|
||||||
|
text: str = Field(..., description="Text to convert")
|
||||||
|
voice: str | None = Field(default=None, description="Voice name")
|
||||||
|
streaming: bool | None = Field(default=None, description="Whether to stream audio")
|
||||||
|
|
||||||
|
|
||||||
|
class TextToSpeechVoiceQuery(BaseModel):
|
||||||
|
language: str = Field(..., description="Language code")
|
||||||
|
|
||||||
|
|
||||||
|
console_ns.schema_model(
|
||||||
|
TextToSpeechPayload.__name__, TextToSpeechPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||||
|
)
|
||||||
|
console_ns.schema_model(
|
||||||
|
TextToSpeechVoiceQuery.__name__,
|
||||||
|
TextToSpeechVoiceQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:app_id>/audio-to-text")
|
@console_ns.route("/apps/<uuid:app_id>/audio-to-text")
|
||||||
|
|
@ -92,17 +114,7 @@ class ChatMessageTextApi(Resource):
|
||||||
@console_ns.doc("chat_message_text_to_speech")
|
@console_ns.doc("chat_message_text_to_speech")
|
||||||
@console_ns.doc(description="Convert text to speech for chat messages")
|
@console_ns.doc(description="Convert text to speech for chat messages")
|
||||||
@console_ns.doc(params={"app_id": "App ID"})
|
@console_ns.doc(params={"app_id": "App ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[TextToSpeechPayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"TextToSpeechRequest",
|
|
||||||
{
|
|
||||||
"message_id": fields.String(description="Message ID"),
|
|
||||||
"text": fields.String(required=True, description="Text to convert to speech"),
|
|
||||||
"voice": fields.String(description="Voice to use for TTS"),
|
|
||||||
"streaming": fields.Boolean(description="Whether to stream the audio"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Text to speech conversion successful")
|
@console_ns.response(200, "Text to speech conversion successful")
|
||||||
@console_ns.response(400, "Bad request - Invalid parameters")
|
@console_ns.response(400, "Bad request - Invalid parameters")
|
||||||
@get_app_model
|
@get_app_model
|
||||||
|
|
@ -111,21 +123,14 @@ class ChatMessageTextApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, app_model: App):
|
def post(self, app_model: App):
|
||||||
try:
|
try:
|
||||||
parser = (
|
payload = TextToSpeechPayload.model_validate(console_ns.payload)
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("message_id", type=str, location="json")
|
|
||||||
.add_argument("text", type=str, location="json")
|
|
||||||
.add_argument("voice", type=str, location="json")
|
|
||||||
.add_argument("streaming", type=bool, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
message_id = args.get("message_id", None)
|
|
||||||
text = args.get("text", None)
|
|
||||||
voice = args.get("voice", None)
|
|
||||||
|
|
||||||
response = AudioService.transcript_tts(
|
response = AudioService.transcript_tts(
|
||||||
app_model=app_model, text=text, voice=voice, message_id=message_id, is_draft=True
|
app_model=app_model,
|
||||||
|
text=payload.text,
|
||||||
|
voice=payload.voice,
|
||||||
|
message_id=payload.message_id,
|
||||||
|
is_draft=True,
|
||||||
)
|
)
|
||||||
return response
|
return response
|
||||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||||
|
|
@ -159,9 +164,7 @@ class TextModesApi(Resource):
|
||||||
@console_ns.doc("get_text_to_speech_voices")
|
@console_ns.doc("get_text_to_speech_voices")
|
||||||
@console_ns.doc(description="Get available TTS voices for a specific language")
|
@console_ns.doc(description="Get available TTS voices for a specific language")
|
||||||
@console_ns.doc(params={"app_id": "App ID"})
|
@console_ns.doc(params={"app_id": "App ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[TextToSpeechVoiceQuery.__name__])
|
||||||
console_ns.parser().add_argument("language", type=str, required=True, location="args", help="Language code")
|
|
||||||
)
|
|
||||||
@console_ns.response(
|
@console_ns.response(
|
||||||
200, "TTS voices retrieved successfully", fields.List(fields.Raw(description="Available voices"))
|
200, "TTS voices retrieved successfully", fields.List(fields.Raw(description="Available voices"))
|
||||||
)
|
)
|
||||||
|
|
@ -172,12 +175,11 @@ class TextModesApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
try:
|
try:
|
||||||
parser = reqparse.RequestParser().add_argument("language", type=str, required=True, location="args")
|
args = TextToSpeechVoiceQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
response = AudioService.transcript_tts_voices(
|
response = AudioService.transcript_tts_voices(
|
||||||
tenant_id=app_model.tenant_id,
|
tenant_id=app_model.tenant_id,
|
||||||
language=args["language"],
|
language=args.language,
|
||||||
)
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,9 @@
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource, fields, reqparse
|
from flask_restx import Resource
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
from werkzeug.exceptions import InternalServerError, NotFound
|
from werkzeug.exceptions import InternalServerError, NotFound
|
||||||
|
|
||||||
import services
|
import services
|
||||||
|
|
@ -35,6 +37,41 @@ from services.app_task_service import AppTaskService
|
||||||
from services.errors.llm import InvokeRateLimitError
|
from services.errors.llm import InvokeRateLimitError
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
|
|
||||||
|
class BaseMessagePayload(BaseModel):
|
||||||
|
inputs: dict[str, Any]
|
||||||
|
model_config_data: dict[str, Any] = Field(..., alias="model_config")
|
||||||
|
files: list[Any] | None = Field(default=None, description="Uploaded files")
|
||||||
|
response_mode: Literal["blocking", "streaming"] = Field(default="blocking", description="Response mode")
|
||||||
|
retriever_from: str = Field(default="dev", description="Retriever source")
|
||||||
|
|
||||||
|
|
||||||
|
class CompletionMessagePayload(BaseMessagePayload):
|
||||||
|
query: str = Field(default="", description="Query text")
|
||||||
|
|
||||||
|
|
||||||
|
class ChatMessagePayload(BaseMessagePayload):
|
||||||
|
query: str = Field(..., description="User query")
|
||||||
|
conversation_id: str | None = Field(default=None, description="Conversation ID")
|
||||||
|
parent_message_id: str | None = Field(default=None, description="Parent message ID")
|
||||||
|
|
||||||
|
@field_validator("conversation_id", "parent_message_id")
|
||||||
|
@classmethod
|
||||||
|
def validate_uuid(cls, value: str | None) -> str | None:
|
||||||
|
if value is None:
|
||||||
|
return value
|
||||||
|
return uuid_value(value)
|
||||||
|
|
||||||
|
|
||||||
|
console_ns.schema_model(
|
||||||
|
CompletionMessagePayload.__name__,
|
||||||
|
CompletionMessagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||||
|
)
|
||||||
|
console_ns.schema_model(
|
||||||
|
ChatMessagePayload.__name__, ChatMessagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# define completion message api for user
|
# define completion message api for user
|
||||||
|
|
@ -43,19 +80,7 @@ class CompletionMessageApi(Resource):
|
||||||
@console_ns.doc("create_completion_message")
|
@console_ns.doc("create_completion_message")
|
||||||
@console_ns.doc(description="Generate completion message for debugging")
|
@console_ns.doc(description="Generate completion message for debugging")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[CompletionMessagePayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"CompletionMessageRequest",
|
|
||||||
{
|
|
||||||
"inputs": fields.Raw(required=True, description="Input variables"),
|
|
||||||
"query": fields.String(description="Query text", default=""),
|
|
||||||
"files": fields.List(fields.Raw(), description="Uploaded files"),
|
|
||||||
"model_config": fields.Raw(required=True, description="Model configuration"),
|
|
||||||
"response_mode": fields.String(enum=["blocking", "streaming"], description="Response mode"),
|
|
||||||
"retriever_from": fields.String(default="dev", description="Retriever source"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Completion generated successfully")
|
@console_ns.response(200, "Completion generated successfully")
|
||||||
@console_ns.response(400, "Invalid request parameters")
|
@console_ns.response(400, "Invalid request parameters")
|
||||||
@console_ns.response(404, "App not found")
|
@console_ns.response(404, "App not found")
|
||||||
|
|
@ -64,18 +89,10 @@ class CompletionMessageApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=AppMode.COMPLETION)
|
@get_app_model(mode=AppMode.COMPLETION)
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
parser = (
|
args_model = CompletionMessagePayload.model_validate(console_ns.payload)
|
||||||
reqparse.RequestParser()
|
args = args_model.model_dump(exclude_none=True, by_alias=True)
|
||||||
.add_argument("inputs", type=dict, required=True, location="json")
|
|
||||||
.add_argument("query", type=str, location="json", default="")
|
|
||||||
.add_argument("files", type=list, required=False, location="json")
|
|
||||||
.add_argument("model_config", type=dict, required=True, location="json")
|
|
||||||
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
|
||||||
.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
streaming = args["response_mode"] != "blocking"
|
streaming = args_model.response_mode != "blocking"
|
||||||
args["auto_generate_name"] = False
|
args["auto_generate_name"] = False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -137,21 +154,7 @@ class ChatMessageApi(Resource):
|
||||||
@console_ns.doc("create_chat_message")
|
@console_ns.doc("create_chat_message")
|
||||||
@console_ns.doc(description="Generate chat message for debugging")
|
@console_ns.doc(description="Generate chat message for debugging")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[ChatMessagePayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"ChatMessageRequest",
|
|
||||||
{
|
|
||||||
"inputs": fields.Raw(required=True, description="Input variables"),
|
|
||||||
"query": fields.String(required=True, description="User query"),
|
|
||||||
"files": fields.List(fields.Raw(), description="Uploaded files"),
|
|
||||||
"model_config": fields.Raw(required=True, description="Model configuration"),
|
|
||||||
"conversation_id": fields.String(description="Conversation ID"),
|
|
||||||
"parent_message_id": fields.String(description="Parent message ID"),
|
|
||||||
"response_mode": fields.String(enum=["blocking", "streaming"], description="Response mode"),
|
|
||||||
"retriever_from": fields.String(default="dev", description="Retriever source"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Chat message generated successfully")
|
@console_ns.response(200, "Chat message generated successfully")
|
||||||
@console_ns.response(400, "Invalid request parameters")
|
@console_ns.response(400, "Invalid request parameters")
|
||||||
@console_ns.response(404, "App or conversation not found")
|
@console_ns.response(404, "App or conversation not found")
|
||||||
|
|
@ -161,20 +164,10 @@ class ChatMessageApi(Resource):
|
||||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
|
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
|
||||||
@edit_permission_required
|
@edit_permission_required
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
parser = (
|
args_model = ChatMessagePayload.model_validate(console_ns.payload)
|
||||||
reqparse.RequestParser()
|
args = args_model.model_dump(exclude_none=True, by_alias=True)
|
||||||
.add_argument("inputs", type=dict, required=True, location="json")
|
|
||||||
.add_argument("query", type=str, required=True, location="json")
|
|
||||||
.add_argument("files", type=list, required=False, location="json")
|
|
||||||
.add_argument("model_config", type=dict, required=True, location="json")
|
|
||||||
.add_argument("conversation_id", type=uuid_value, location="json")
|
|
||||||
.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
|
||||||
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
|
||||||
.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
streaming = args["response_mode"] != "blocking"
|
streaming = args_model.response_mode != "blocking"
|
||||||
args["auto_generate_name"] = False
|
args["auto_generate_name"] = False
|
||||||
|
|
||||||
external_trace_id = get_external_trace_id(request)
|
external_trace_id = get_external_trace_id(request)
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,9 @@
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from flask import abort
|
from flask import abort, request
|
||||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
from flask_restx import Resource, fields, marshal_with
|
||||||
from flask_restx.inputs import int_range
|
from pydantic import BaseModel, Field, field_validator
|
||||||
from sqlalchemy import func, or_
|
from sqlalchemy import func, or_
|
||||||
from sqlalchemy.orm import joinedload
|
from sqlalchemy.orm import joinedload
|
||||||
from werkzeug.exceptions import NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
@ -14,13 +16,53 @@ from extensions.ext_database import db
|
||||||
from fields.conversation_fields import MessageTextField
|
from fields.conversation_fields import MessageTextField
|
||||||
from fields.raws import FilesContainedField
|
from fields.raws import FilesContainedField
|
||||||
from libs.datetime_utils import naive_utc_now, parse_time_range
|
from libs.datetime_utils import naive_utc_now, parse_time_range
|
||||||
from libs.helper import DatetimeString, TimestampField
|
from libs.helper import TimestampField
|
||||||
from libs.login import current_account_with_tenant, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models import Conversation, EndUser, Message, MessageAnnotation
|
from models import Conversation, EndUser, Message, MessageAnnotation
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
from services.conversation_service import ConversationService
|
from services.conversation_service import ConversationService
|
||||||
from services.errors.conversation import ConversationNotExistsError
|
from services.errors.conversation import ConversationNotExistsError
|
||||||
|
|
||||||
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
|
|
||||||
|
class BaseConversationQuery(BaseModel):
|
||||||
|
keyword: str | None = Field(default=None, description="Search keyword")
|
||||||
|
start: str | None = Field(default=None, description="Start date (YYYY-MM-DD HH:MM)")
|
||||||
|
end: str | None = Field(default=None, description="End date (YYYY-MM-DD HH:MM)")
|
||||||
|
annotation_status: Literal["annotated", "not_annotated", "all"] = Field(
|
||||||
|
default="all", description="Annotation status filter"
|
||||||
|
)
|
||||||
|
page: int = Field(default=1, ge=1, le=99999, description="Page number")
|
||||||
|
limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
|
||||||
|
|
||||||
|
@field_validator("start", "end", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def blank_to_none(cls, value: str | None) -> str | None:
|
||||||
|
if value == "":
|
||||||
|
return None
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
class CompletionConversationQuery(BaseConversationQuery):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ChatConversationQuery(BaseConversationQuery):
|
||||||
|
sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field(
|
||||||
|
default="-updated_at", description="Sort field and direction"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
console_ns.schema_model(
|
||||||
|
CompletionConversationQuery.__name__,
|
||||||
|
CompletionConversationQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||||
|
)
|
||||||
|
console_ns.schema_model(
|
||||||
|
ChatConversationQuery.__name__,
|
||||||
|
ChatConversationQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||||
|
)
|
||||||
|
|
||||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||||
# Register in dependency order: base models first, then dependent models
|
# Register in dependency order: base models first, then dependent models
|
||||||
|
|
||||||
|
|
@ -283,22 +325,7 @@ class CompletionConversationApi(Resource):
|
||||||
@console_ns.doc("list_completion_conversations")
|
@console_ns.doc("list_completion_conversations")
|
||||||
@console_ns.doc(description="Get completion conversations with pagination and filtering")
|
@console_ns.doc(description="Get completion conversations with pagination and filtering")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[CompletionConversationQuery.__name__])
|
||||||
console_ns.parser()
|
|
||||||
.add_argument("keyword", type=str, location="args", help="Search keyword")
|
|
||||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
|
||||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
|
||||||
.add_argument(
|
|
||||||
"annotation_status",
|
|
||||||
type=str,
|
|
||||||
location="args",
|
|
||||||
choices=["annotated", "not_annotated", "all"],
|
|
||||||
default="all",
|
|
||||||
help="Annotation status filter",
|
|
||||||
)
|
|
||||||
.add_argument("page", type=int, location="args", default=1, help="Page number")
|
|
||||||
.add_argument("limit", type=int, location="args", default=20, help="Page size (1-100)")
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Success", conversation_pagination_model)
|
@console_ns.response(200, "Success", conversation_pagination_model)
|
||||||
@console_ns.response(403, "Insufficient permissions")
|
@console_ns.response(403, "Insufficient permissions")
|
||||||
@setup_required
|
@setup_required
|
||||||
|
|
@ -309,32 +336,17 @@ class CompletionConversationApi(Resource):
|
||||||
@edit_permission_required
|
@edit_permission_required
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
current_user, _ = current_account_with_tenant()
|
current_user, _ = current_account_with_tenant()
|
||||||
parser = (
|
args = CompletionConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("keyword", type=str, location="args")
|
|
||||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
|
||||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
|
||||||
.add_argument(
|
|
||||||
"annotation_status",
|
|
||||||
type=str,
|
|
||||||
choices=["annotated", "not_annotated", "all"],
|
|
||||||
default="all",
|
|
||||||
location="args",
|
|
||||||
)
|
|
||||||
.add_argument("page", type=int_range(1, 99999), default=1, location="args")
|
|
||||||
.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
query = sa.select(Conversation).where(
|
query = sa.select(Conversation).where(
|
||||||
Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False)
|
Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False)
|
||||||
)
|
)
|
||||||
|
|
||||||
if args["keyword"]:
|
if args.keyword:
|
||||||
query = query.join(Message, Message.conversation_id == Conversation.id).where(
|
query = query.join(Message, Message.conversation_id == Conversation.id).where(
|
||||||
or_(
|
or_(
|
||||||
Message.query.ilike(f"%{args['keyword']}%"),
|
Message.query.ilike(f"%{args.keyword}%"),
|
||||||
Message.answer.ilike(f"%{args['keyword']}%"),
|
Message.answer.ilike(f"%{args.keyword}%"),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -342,7 +354,7 @@ class CompletionConversationApi(Resource):
|
||||||
assert account.timezone is not None
|
assert account.timezone is not None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
abort(400, description=str(e))
|
abort(400, description=str(e))
|
||||||
|
|
||||||
|
|
@ -354,11 +366,11 @@ class CompletionConversationApi(Resource):
|
||||||
query = query.where(Conversation.created_at < end_datetime_utc)
|
query = query.where(Conversation.created_at < end_datetime_utc)
|
||||||
|
|
||||||
# FIXME, the type ignore in this file
|
# FIXME, the type ignore in this file
|
||||||
if args["annotation_status"] == "annotated":
|
if args.annotation_status == "annotated":
|
||||||
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
|
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
|
||||||
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
|
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
|
||||||
)
|
)
|
||||||
elif args["annotation_status"] == "not_annotated":
|
elif args.annotation_status == "not_annotated":
|
||||||
query = (
|
query = (
|
||||||
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
|
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
|
||||||
.group_by(Conversation.id)
|
.group_by(Conversation.id)
|
||||||
|
|
@ -367,7 +379,7 @@ class CompletionConversationApi(Resource):
|
||||||
|
|
||||||
query = query.order_by(Conversation.created_at.desc())
|
query = query.order_by(Conversation.created_at.desc())
|
||||||
|
|
||||||
conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False)
|
conversations = db.paginate(query, page=args.page, per_page=args.limit, error_out=False)
|
||||||
|
|
||||||
return conversations
|
return conversations
|
||||||
|
|
||||||
|
|
@ -419,31 +431,7 @@ class ChatConversationApi(Resource):
|
||||||
@console_ns.doc("list_chat_conversations")
|
@console_ns.doc("list_chat_conversations")
|
||||||
@console_ns.doc(description="Get chat conversations with pagination, filtering and summary")
|
@console_ns.doc(description="Get chat conversations with pagination, filtering and summary")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[ChatConversationQuery.__name__])
|
||||||
console_ns.parser()
|
|
||||||
.add_argument("keyword", type=str, location="args", help="Search keyword")
|
|
||||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
|
||||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
|
||||||
.add_argument(
|
|
||||||
"annotation_status",
|
|
||||||
type=str,
|
|
||||||
location="args",
|
|
||||||
choices=["annotated", "not_annotated", "all"],
|
|
||||||
default="all",
|
|
||||||
help="Annotation status filter",
|
|
||||||
)
|
|
||||||
.add_argument("message_count_gte", type=int, location="args", help="Minimum message count")
|
|
||||||
.add_argument("page", type=int, location="args", default=1, help="Page number")
|
|
||||||
.add_argument("limit", type=int, location="args", default=20, help="Page size (1-100)")
|
|
||||||
.add_argument(
|
|
||||||
"sort_by",
|
|
||||||
type=str,
|
|
||||||
location="args",
|
|
||||||
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
|
|
||||||
default="-updated_at",
|
|
||||||
help="Sort field and direction",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Success", conversation_with_summary_pagination_model)
|
@console_ns.response(200, "Success", conversation_with_summary_pagination_model)
|
||||||
@console_ns.response(403, "Insufficient permissions")
|
@console_ns.response(403, "Insufficient permissions")
|
||||||
@setup_required
|
@setup_required
|
||||||
|
|
@ -454,31 +442,7 @@ class ChatConversationApi(Resource):
|
||||||
@edit_permission_required
|
@edit_permission_required
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
current_user, _ = current_account_with_tenant()
|
current_user, _ = current_account_with_tenant()
|
||||||
parser = (
|
args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("keyword", type=str, location="args")
|
|
||||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
|
||||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
|
||||||
.add_argument(
|
|
||||||
"annotation_status",
|
|
||||||
type=str,
|
|
||||||
choices=["annotated", "not_annotated", "all"],
|
|
||||||
default="all",
|
|
||||||
location="args",
|
|
||||||
)
|
|
||||||
.add_argument("message_count_gte", type=int_range(1, 99999), required=False, location="args")
|
|
||||||
.add_argument("page", type=int_range(1, 99999), required=False, default=1, location="args")
|
|
||||||
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
|
||||||
.add_argument(
|
|
||||||
"sort_by",
|
|
||||||
type=str,
|
|
||||||
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
|
|
||||||
required=False,
|
|
||||||
default="-updated_at",
|
|
||||||
location="args",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
subquery = (
|
subquery = (
|
||||||
db.session.query(
|
db.session.query(
|
||||||
|
|
@ -490,8 +454,8 @@ class ChatConversationApi(Resource):
|
||||||
|
|
||||||
query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))
|
query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))
|
||||||
|
|
||||||
if args["keyword"]:
|
if args.keyword:
|
||||||
keyword_filter = f"%{args['keyword']}%"
|
keyword_filter = f"%{args.keyword}%"
|
||||||
query = (
|
query = (
|
||||||
query.join(
|
query.join(
|
||||||
Message,
|
Message,
|
||||||
|
|
@ -514,12 +478,12 @@ class ChatConversationApi(Resource):
|
||||||
assert account.timezone is not None
|
assert account.timezone is not None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
abort(400, description=str(e))
|
abort(400, description=str(e))
|
||||||
|
|
||||||
if start_datetime_utc:
|
if start_datetime_utc:
|
||||||
match args["sort_by"]:
|
match args.sort_by:
|
||||||
case "updated_at" | "-updated_at":
|
case "updated_at" | "-updated_at":
|
||||||
query = query.where(Conversation.updated_at >= start_datetime_utc)
|
query = query.where(Conversation.updated_at >= start_datetime_utc)
|
||||||
case "created_at" | "-created_at" | _:
|
case "created_at" | "-created_at" | _:
|
||||||
|
|
@ -527,35 +491,27 @@ class ChatConversationApi(Resource):
|
||||||
|
|
||||||
if end_datetime_utc:
|
if end_datetime_utc:
|
||||||
end_datetime_utc = end_datetime_utc.replace(second=59)
|
end_datetime_utc = end_datetime_utc.replace(second=59)
|
||||||
match args["sort_by"]:
|
match args.sort_by:
|
||||||
case "updated_at" | "-updated_at":
|
case "updated_at" | "-updated_at":
|
||||||
query = query.where(Conversation.updated_at <= end_datetime_utc)
|
query = query.where(Conversation.updated_at <= end_datetime_utc)
|
||||||
case "created_at" | "-created_at" | _:
|
case "created_at" | "-created_at" | _:
|
||||||
query = query.where(Conversation.created_at <= end_datetime_utc)
|
query = query.where(Conversation.created_at <= end_datetime_utc)
|
||||||
|
|
||||||
if args["annotation_status"] == "annotated":
|
if args.annotation_status == "annotated":
|
||||||
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
|
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
|
||||||
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
|
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
|
||||||
)
|
)
|
||||||
elif args["annotation_status"] == "not_annotated":
|
elif args.annotation_status == "not_annotated":
|
||||||
query = (
|
query = (
|
||||||
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
|
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
|
||||||
.group_by(Conversation.id)
|
.group_by(Conversation.id)
|
||||||
.having(func.count(MessageAnnotation.id) == 0)
|
.having(func.count(MessageAnnotation.id) == 0)
|
||||||
)
|
)
|
||||||
|
|
||||||
if args["message_count_gte"] and args["message_count_gte"] >= 1:
|
|
||||||
query = (
|
|
||||||
query.options(joinedload(Conversation.messages)) # type: ignore
|
|
||||||
.join(Message, Message.conversation_id == Conversation.id)
|
|
||||||
.group_by(Conversation.id)
|
|
||||||
.having(func.count(Message.id) >= args["message_count_gte"])
|
|
||||||
)
|
|
||||||
|
|
||||||
if app_model.mode == AppMode.ADVANCED_CHAT:
|
if app_model.mode == AppMode.ADVANCED_CHAT:
|
||||||
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER)
|
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER)
|
||||||
|
|
||||||
match args["sort_by"]:
|
match args.sort_by:
|
||||||
case "created_at":
|
case "created_at":
|
||||||
query = query.order_by(Conversation.created_at.asc())
|
query = query.order_by(Conversation.created_at.asc())
|
||||||
case "-created_at":
|
case "-created_at":
|
||||||
|
|
@ -567,7 +523,7 @@ class ChatConversationApi(Resource):
|
||||||
case _:
|
case _:
|
||||||
query = query.order_by(Conversation.created_at.desc())
|
query = query.order_by(Conversation.created_at.desc())
|
||||||
|
|
||||||
conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False)
|
conversations = db.paginate(query, page=args.page, per_page=args.limit, error_out=False)
|
||||||
|
|
||||||
return conversations
|
return conversations
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,6 @@
|
||||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
from flask import request
|
||||||
|
from flask_restx import Resource, fields, marshal_with
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
|
@ -14,6 +16,18 @@ from libs.login import login_required
|
||||||
from models import ConversationVariable
|
from models import ConversationVariable
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
|
|
||||||
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationVariablesQuery(BaseModel):
|
||||||
|
conversation_id: str = Field(..., description="Conversation ID to filter variables")
|
||||||
|
|
||||||
|
|
||||||
|
console_ns.schema_model(
|
||||||
|
ConversationVariablesQuery.__name__,
|
||||||
|
ConversationVariablesQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||||
|
)
|
||||||
|
|
||||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||||
# Register base model first
|
# Register base model first
|
||||||
conversation_variable_model = console_ns.model("ConversationVariable", conversation_variable_fields)
|
conversation_variable_model = console_ns.model("ConversationVariable", conversation_variable_fields)
|
||||||
|
|
@ -33,11 +47,7 @@ class ConversationVariablesApi(Resource):
|
||||||
@console_ns.doc("get_conversation_variables")
|
@console_ns.doc("get_conversation_variables")
|
||||||
@console_ns.doc(description="Get conversation variables for an application")
|
@console_ns.doc(description="Get conversation variables for an application")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[ConversationVariablesQuery.__name__])
|
||||||
console_ns.parser().add_argument(
|
|
||||||
"conversation_id", type=str, location="args", help="Conversation ID to filter variables"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Conversation variables retrieved successfully", paginated_conversation_variable_model)
|
@console_ns.response(200, "Conversation variables retrieved successfully", paginated_conversation_variable_model)
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
|
@ -45,18 +55,14 @@ class ConversationVariablesApi(Resource):
|
||||||
@get_app_model(mode=AppMode.ADVANCED_CHAT)
|
@get_app_model(mode=AppMode.ADVANCED_CHAT)
|
||||||
@marshal_with(paginated_conversation_variable_model)
|
@marshal_with(paginated_conversation_variable_model)
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
parser = reqparse.RequestParser().add_argument("conversation_id", type=str, location="args")
|
args = ConversationVariablesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
stmt = (
|
stmt = (
|
||||||
select(ConversationVariable)
|
select(ConversationVariable)
|
||||||
.where(ConversationVariable.app_id == app_model.id)
|
.where(ConversationVariable.app_id == app_model.id)
|
||||||
.order_by(ConversationVariable.created_at)
|
.order_by(ConversationVariable.created_at)
|
||||||
)
|
)
|
||||||
if args["conversation_id"]:
|
stmt = stmt.where(ConversationVariable.conversation_id == args.conversation_id)
|
||||||
stmt = stmt.where(ConversationVariable.conversation_id == args["conversation_id"])
|
|
||||||
else:
|
|
||||||
raise ValueError("conversation_id is required")
|
|
||||||
|
|
||||||
# NOTE: This is a temporary solution to avoid performance issues.
|
# NOTE: This is a temporary solution to avoid performance issues.
|
||||||
page = 1
|
page = 1
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,8 @@
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from flask_restx import Resource, fields, reqparse
|
from flask_restx import Resource
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
from controllers.console.app.error import (
|
from controllers.console.app.error import (
|
||||||
|
|
@ -21,21 +23,54 @@ from libs.login import current_account_with_tenant, login_required
|
||||||
from models import App
|
from models import App
|
||||||
from services.workflow_service import WorkflowService
|
from services.workflow_service import WorkflowService
|
||||||
|
|
||||||
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
|
|
||||||
|
class RuleGeneratePayload(BaseModel):
|
||||||
|
instruction: str = Field(..., description="Rule generation instruction")
|
||||||
|
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
|
||||||
|
no_variable: bool = Field(default=False, description="Whether to exclude variables")
|
||||||
|
|
||||||
|
|
||||||
|
class RuleCodeGeneratePayload(RuleGeneratePayload):
|
||||||
|
code_language: str = Field(default="javascript", description="Programming language for code generation")
|
||||||
|
|
||||||
|
|
||||||
|
class RuleStructuredOutputPayload(BaseModel):
|
||||||
|
instruction: str = Field(..., description="Structured output generation instruction")
|
||||||
|
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
|
||||||
|
|
||||||
|
|
||||||
|
class InstructionGeneratePayload(BaseModel):
|
||||||
|
flow_id: str = Field(..., description="Workflow/Flow ID")
|
||||||
|
node_id: str = Field(default="", description="Node ID for workflow context")
|
||||||
|
current: str = Field(default="", description="Current instruction text")
|
||||||
|
language: str = Field(default="javascript", description="Programming language (javascript/python)")
|
||||||
|
instruction: str = Field(..., description="Instruction for generation")
|
||||||
|
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
|
||||||
|
ideal_output: str = Field(default="", description="Expected ideal output")
|
||||||
|
|
||||||
|
|
||||||
|
class InstructionTemplatePayload(BaseModel):
|
||||||
|
type: str = Field(..., description="Instruction template type")
|
||||||
|
|
||||||
|
|
||||||
|
def reg(cls: type[BaseModel]):
|
||||||
|
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||||
|
|
||||||
|
|
||||||
|
reg(RuleGeneratePayload)
|
||||||
|
reg(RuleCodeGeneratePayload)
|
||||||
|
reg(RuleStructuredOutputPayload)
|
||||||
|
reg(InstructionGeneratePayload)
|
||||||
|
reg(InstructionTemplatePayload)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/rule-generate")
|
@console_ns.route("/rule-generate")
|
||||||
class RuleGenerateApi(Resource):
|
class RuleGenerateApi(Resource):
|
||||||
@console_ns.doc("generate_rule_config")
|
@console_ns.doc("generate_rule_config")
|
||||||
@console_ns.doc(description="Generate rule configuration using LLM")
|
@console_ns.doc(description="Generate rule configuration using LLM")
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[RuleGeneratePayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"RuleGenerateRequest",
|
|
||||||
{
|
|
||||||
"instruction": fields.String(required=True, description="Rule generation instruction"),
|
|
||||||
"model_config": fields.Raw(required=True, description="Model configuration"),
|
|
||||||
"no_variable": fields.Boolean(required=True, default=False, description="Whether to exclude variables"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Rule configuration generated successfully")
|
@console_ns.response(200, "Rule configuration generated successfully")
|
||||||
@console_ns.response(400, "Invalid request parameters")
|
@console_ns.response(400, "Invalid request parameters")
|
||||||
@console_ns.response(402, "Provider quota exceeded")
|
@console_ns.response(402, "Provider quota exceeded")
|
||||||
|
|
@ -43,21 +78,15 @@ class RuleGenerateApi(Resource):
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = (
|
args = RuleGeneratePayload.model_validate(console_ns.payload)
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
|
||||||
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
|
||||||
.add_argument("no_variable", type=bool, required=True, default=False, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
_, current_tenant_id = current_account_with_tenant()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
rules = LLMGenerator.generate_rule_config(
|
rules = LLMGenerator.generate_rule_config(
|
||||||
tenant_id=current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
instruction=args["instruction"],
|
instruction=args.instruction,
|
||||||
model_config=args["model_config"],
|
model_config=args.model_config_data,
|
||||||
no_variable=args["no_variable"],
|
no_variable=args.no_variable,
|
||||||
)
|
)
|
||||||
except ProviderTokenNotInitError as ex:
|
except ProviderTokenNotInitError as ex:
|
||||||
raise ProviderNotInitializeError(ex.description)
|
raise ProviderNotInitializeError(ex.description)
|
||||||
|
|
@ -75,19 +104,7 @@ class RuleGenerateApi(Resource):
|
||||||
class RuleCodeGenerateApi(Resource):
|
class RuleCodeGenerateApi(Resource):
|
||||||
@console_ns.doc("generate_rule_code")
|
@console_ns.doc("generate_rule_code")
|
||||||
@console_ns.doc(description="Generate code rules using LLM")
|
@console_ns.doc(description="Generate code rules using LLM")
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[RuleCodeGeneratePayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"RuleCodeGenerateRequest",
|
|
||||||
{
|
|
||||||
"instruction": fields.String(required=True, description="Code generation instruction"),
|
|
||||||
"model_config": fields.Raw(required=True, description="Model configuration"),
|
|
||||||
"no_variable": fields.Boolean(required=True, default=False, description="Whether to exclude variables"),
|
|
||||||
"code_language": fields.String(
|
|
||||||
default="javascript", description="Programming language for code generation"
|
|
||||||
),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Code rules generated successfully")
|
@console_ns.response(200, "Code rules generated successfully")
|
||||||
@console_ns.response(400, "Invalid request parameters")
|
@console_ns.response(400, "Invalid request parameters")
|
||||||
@console_ns.response(402, "Provider quota exceeded")
|
@console_ns.response(402, "Provider quota exceeded")
|
||||||
|
|
@ -95,22 +112,15 @@ class RuleCodeGenerateApi(Resource):
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = (
|
args = RuleCodeGeneratePayload.model_validate(console_ns.payload)
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
|
||||||
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
|
||||||
.add_argument("no_variable", type=bool, required=True, default=False, location="json")
|
|
||||||
.add_argument("code_language", type=str, required=False, default="javascript", location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
_, current_tenant_id = current_account_with_tenant()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
code_result = LLMGenerator.generate_code(
|
code_result = LLMGenerator.generate_code(
|
||||||
tenant_id=current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
instruction=args["instruction"],
|
instruction=args.instruction,
|
||||||
model_config=args["model_config"],
|
model_config=args.model_config_data,
|
||||||
code_language=args["code_language"],
|
code_language=args.code_language,
|
||||||
)
|
)
|
||||||
except ProviderTokenNotInitError as ex:
|
except ProviderTokenNotInitError as ex:
|
||||||
raise ProviderNotInitializeError(ex.description)
|
raise ProviderNotInitializeError(ex.description)
|
||||||
|
|
@ -128,15 +138,7 @@ class RuleCodeGenerateApi(Resource):
|
||||||
class RuleStructuredOutputGenerateApi(Resource):
|
class RuleStructuredOutputGenerateApi(Resource):
|
||||||
@console_ns.doc("generate_structured_output")
|
@console_ns.doc("generate_structured_output")
|
||||||
@console_ns.doc(description="Generate structured output rules using LLM")
|
@console_ns.doc(description="Generate structured output rules using LLM")
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[RuleStructuredOutputPayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"StructuredOutputGenerateRequest",
|
|
||||||
{
|
|
||||||
"instruction": fields.String(required=True, description="Structured output generation instruction"),
|
|
||||||
"model_config": fields.Raw(required=True, description="Model configuration"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Structured output generated successfully")
|
@console_ns.response(200, "Structured output generated successfully")
|
||||||
@console_ns.response(400, "Invalid request parameters")
|
@console_ns.response(400, "Invalid request parameters")
|
||||||
@console_ns.response(402, "Provider quota exceeded")
|
@console_ns.response(402, "Provider quota exceeded")
|
||||||
|
|
@ -144,19 +146,14 @@ class RuleStructuredOutputGenerateApi(Resource):
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = (
|
args = RuleStructuredOutputPayload.model_validate(console_ns.payload)
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
|
||||||
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
_, current_tenant_id = current_account_with_tenant()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
structured_output = LLMGenerator.generate_structured_output(
|
structured_output = LLMGenerator.generate_structured_output(
|
||||||
tenant_id=current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
instruction=args["instruction"],
|
instruction=args.instruction,
|
||||||
model_config=args["model_config"],
|
model_config=args.model_config_data,
|
||||||
)
|
)
|
||||||
except ProviderTokenNotInitError as ex:
|
except ProviderTokenNotInitError as ex:
|
||||||
raise ProviderNotInitializeError(ex.description)
|
raise ProviderNotInitializeError(ex.description)
|
||||||
|
|
@ -174,20 +171,7 @@ class RuleStructuredOutputGenerateApi(Resource):
|
||||||
class InstructionGenerateApi(Resource):
|
class InstructionGenerateApi(Resource):
|
||||||
@console_ns.doc("generate_instruction")
|
@console_ns.doc("generate_instruction")
|
||||||
@console_ns.doc(description="Generate instruction for workflow nodes or general use")
|
@console_ns.doc(description="Generate instruction for workflow nodes or general use")
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[InstructionGeneratePayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"InstructionGenerateRequest",
|
|
||||||
{
|
|
||||||
"flow_id": fields.String(required=True, description="Workflow/Flow ID"),
|
|
||||||
"node_id": fields.String(description="Node ID for workflow context"),
|
|
||||||
"current": fields.String(description="Current instruction text"),
|
|
||||||
"language": fields.String(default="javascript", description="Programming language (javascript/python)"),
|
|
||||||
"instruction": fields.String(required=True, description="Instruction for generation"),
|
|
||||||
"model_config": fields.Raw(required=True, description="Model configuration"),
|
|
||||||
"ideal_output": fields.String(description="Expected ideal output"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Instruction generated successfully")
|
@console_ns.response(200, "Instruction generated successfully")
|
||||||
@console_ns.response(400, "Invalid request parameters or flow/workflow not found")
|
@console_ns.response(400, "Invalid request parameters or flow/workflow not found")
|
||||||
@console_ns.response(402, "Provider quota exceeded")
|
@console_ns.response(402, "Provider quota exceeded")
|
||||||
|
|
@ -195,79 +179,69 @@ class InstructionGenerateApi(Resource):
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = (
|
args = InstructionGeneratePayload.model_validate(console_ns.payload)
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("flow_id", type=str, required=True, default="", location="json")
|
|
||||||
.add_argument("node_id", type=str, required=False, default="", location="json")
|
|
||||||
.add_argument("current", type=str, required=False, default="", location="json")
|
|
||||||
.add_argument("language", type=str, required=False, default="javascript", location="json")
|
|
||||||
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
|
||||||
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
|
||||||
.add_argument("ideal_output", type=str, required=False, default="", location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
_, current_tenant_id = current_account_with_tenant()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
|
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
|
||||||
code_provider: type[CodeNodeProvider] | None = next(
|
code_provider: type[CodeNodeProvider] | None = next(
|
||||||
(p for p in providers if p.is_accept_language(args["language"])), None
|
(p for p in providers if p.is_accept_language(args.language)), None
|
||||||
)
|
)
|
||||||
code_template = code_provider.get_default_code() if code_provider else ""
|
code_template = code_provider.get_default_code() if code_provider else ""
|
||||||
try:
|
try:
|
||||||
# Generate from nothing for a workflow node
|
# Generate from nothing for a workflow node
|
||||||
if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "":
|
if (args.current in (code_template, "")) and args.node_id != "":
|
||||||
app = db.session.query(App).where(App.id == args["flow_id"]).first()
|
app = db.session.query(App).where(App.id == args.flow_id).first()
|
||||||
if not app:
|
if not app:
|
||||||
return {"error": f"app {args['flow_id']} not found"}, 400
|
return {"error": f"app {args.flow_id} not found"}, 400
|
||||||
workflow = WorkflowService().get_draft_workflow(app_model=app)
|
workflow = WorkflowService().get_draft_workflow(app_model=app)
|
||||||
if not workflow:
|
if not workflow:
|
||||||
return {"error": f"workflow {args['flow_id']} not found"}, 400
|
return {"error": f"workflow {args.flow_id} not found"}, 400
|
||||||
nodes: Sequence = workflow.graph_dict["nodes"]
|
nodes: Sequence = workflow.graph_dict["nodes"]
|
||||||
node = [node for node in nodes if node["id"] == args["node_id"]]
|
node = [node for node in nodes if node["id"] == args.node_id]
|
||||||
if len(node) == 0:
|
if len(node) == 0:
|
||||||
return {"error": f"node {args['node_id']} not found"}, 400
|
return {"error": f"node {args.node_id} not found"}, 400
|
||||||
node_type = node[0]["data"]["type"]
|
node_type = node[0]["data"]["type"]
|
||||||
match node_type:
|
match node_type:
|
||||||
case "llm":
|
case "llm":
|
||||||
return LLMGenerator.generate_rule_config(
|
return LLMGenerator.generate_rule_config(
|
||||||
current_tenant_id,
|
current_tenant_id,
|
||||||
instruction=args["instruction"],
|
instruction=args.instruction,
|
||||||
model_config=args["model_config"],
|
model_config=args.model_config_data,
|
||||||
no_variable=True,
|
no_variable=True,
|
||||||
)
|
)
|
||||||
case "agent":
|
case "agent":
|
||||||
return LLMGenerator.generate_rule_config(
|
return LLMGenerator.generate_rule_config(
|
||||||
current_tenant_id,
|
current_tenant_id,
|
||||||
instruction=args["instruction"],
|
instruction=args.instruction,
|
||||||
model_config=args["model_config"],
|
model_config=args.model_config_data,
|
||||||
no_variable=True,
|
no_variable=True,
|
||||||
)
|
)
|
||||||
case "code":
|
case "code":
|
||||||
return LLMGenerator.generate_code(
|
return LLMGenerator.generate_code(
|
||||||
tenant_id=current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
instruction=args["instruction"],
|
instruction=args.instruction,
|
||||||
model_config=args["model_config"],
|
model_config=args.model_config_data,
|
||||||
code_language=args["language"],
|
code_language=args.language,
|
||||||
)
|
)
|
||||||
case _:
|
case _:
|
||||||
return {"error": f"invalid node type: {node_type}"}
|
return {"error": f"invalid node type: {node_type}"}
|
||||||
if args["node_id"] == "" and args["current"] != "": # For legacy app without a workflow
|
if args.node_id == "" and args.current != "": # For legacy app without a workflow
|
||||||
return LLMGenerator.instruction_modify_legacy(
|
return LLMGenerator.instruction_modify_legacy(
|
||||||
tenant_id=current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
flow_id=args["flow_id"],
|
flow_id=args.flow_id,
|
||||||
current=args["current"],
|
current=args.current,
|
||||||
instruction=args["instruction"],
|
instruction=args.instruction,
|
||||||
model_config=args["model_config"],
|
model_config=args.model_config_data,
|
||||||
ideal_output=args["ideal_output"],
|
ideal_output=args.ideal_output,
|
||||||
)
|
)
|
||||||
if args["node_id"] != "" and args["current"] != "": # For workflow node
|
if args.node_id != "" and args.current != "": # For workflow node
|
||||||
return LLMGenerator.instruction_modify_workflow(
|
return LLMGenerator.instruction_modify_workflow(
|
||||||
tenant_id=current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
flow_id=args["flow_id"],
|
flow_id=args.flow_id,
|
||||||
node_id=args["node_id"],
|
node_id=args.node_id,
|
||||||
current=args["current"],
|
current=args.current,
|
||||||
instruction=args["instruction"],
|
instruction=args.instruction,
|
||||||
model_config=args["model_config"],
|
model_config=args.model_config_data,
|
||||||
ideal_output=args["ideal_output"],
|
ideal_output=args.ideal_output,
|
||||||
workflow_service=WorkflowService(),
|
workflow_service=WorkflowService(),
|
||||||
)
|
)
|
||||||
return {"error": "incompatible parameters"}, 400
|
return {"error": "incompatible parameters"}, 400
|
||||||
|
|
@ -285,24 +259,15 @@ class InstructionGenerateApi(Resource):
|
||||||
class InstructionGenerationTemplateApi(Resource):
|
class InstructionGenerationTemplateApi(Resource):
|
||||||
@console_ns.doc("get_instruction_template")
|
@console_ns.doc("get_instruction_template")
|
||||||
@console_ns.doc(description="Get instruction generation template")
|
@console_ns.doc(description="Get instruction generation template")
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[InstructionTemplatePayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"InstructionTemplateRequest",
|
|
||||||
{
|
|
||||||
"instruction": fields.String(required=True, description="Template instruction"),
|
|
||||||
"ideal_output": fields.String(description="Expected ideal output"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Template retrieved successfully")
|
@console_ns.response(200, "Template retrieved successfully")
|
||||||
@console_ns.response(400, "Invalid request parameters")
|
@console_ns.response(400, "Invalid request parameters")
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser().add_argument("type", type=str, required=True, default=False, location="json")
|
args = InstructionTemplatePayload.model_validate(console_ns.payload)
|
||||||
args = parser.parse_args()
|
match args.type:
|
||||||
match args["type"]:
|
|
||||||
case "prompt":
|
case "prompt":
|
||||||
from core.llm_generator.prompts import INSTRUCTION_GENERATE_TEMPLATE_PROMPT
|
from core.llm_generator.prompts import INSTRUCTION_GENERATE_TEMPLATE_PROMPT
|
||||||
|
|
||||||
|
|
@ -312,4 +277,4 @@ class InstructionGenerationTemplateApi(Resource):
|
||||||
|
|
||||||
return {"data": INSTRUCTION_GENERATE_TEMPLATE_CODE}
|
return {"data": INSTRUCTION_GENERATE_TEMPLATE_CODE}
|
||||||
case _:
|
case _:
|
||||||
raise ValueError(f"Invalid type: {args['type']}")
|
raise ValueError(f"Invalid type: {args.type}")
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,8 @@
|
||||||
import json
|
import json
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
|
|
||||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
from flask_restx import Resource, marshal_with
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
from werkzeug.exceptions import NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
|
|
@ -12,6 +13,8 @@ from fields.app_fields import app_server_fields
|
||||||
from libs.login import current_account_with_tenant, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.model import AppMCPServer
|
from models.model import AppMCPServer
|
||||||
|
|
||||||
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
# Register model for flask_restx to avoid dict type issues in Swagger
|
# Register model for flask_restx to avoid dict type issues in Swagger
|
||||||
app_server_model = console_ns.model("AppServer", app_server_fields)
|
app_server_model = console_ns.model("AppServer", app_server_fields)
|
||||||
|
|
||||||
|
|
@ -21,6 +24,22 @@ class AppMCPServerStatus(StrEnum):
|
||||||
INACTIVE = "inactive"
|
INACTIVE = "inactive"
|
||||||
|
|
||||||
|
|
||||||
|
class MCPServerCreatePayload(BaseModel):
|
||||||
|
description: str | None = Field(default=None, description="Server description")
|
||||||
|
parameters: dict = Field(..., description="Server parameters configuration")
|
||||||
|
|
||||||
|
|
||||||
|
class MCPServerUpdatePayload(BaseModel):
|
||||||
|
id: str = Field(..., description="Server ID")
|
||||||
|
description: str | None = Field(default=None, description="Server description")
|
||||||
|
parameters: dict = Field(..., description="Server parameters configuration")
|
||||||
|
status: str | None = Field(default=None, description="Server status")
|
||||||
|
|
||||||
|
|
||||||
|
for model in (MCPServerCreatePayload, MCPServerUpdatePayload):
|
||||||
|
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:app_id>/server")
|
@console_ns.route("/apps/<uuid:app_id>/server")
|
||||||
class AppMCPServerController(Resource):
|
class AppMCPServerController(Resource):
|
||||||
@console_ns.doc("get_app_mcp_server")
|
@console_ns.doc("get_app_mcp_server")
|
||||||
|
|
@ -39,15 +58,7 @@ class AppMCPServerController(Resource):
|
||||||
@console_ns.doc("create_app_mcp_server")
|
@console_ns.doc("create_app_mcp_server")
|
||||||
@console_ns.doc(description="Create MCP server configuration for an application")
|
@console_ns.doc(description="Create MCP server configuration for an application")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[MCPServerCreatePayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"MCPServerCreateRequest",
|
|
||||||
{
|
|
||||||
"description": fields.String(description="Server description"),
|
|
||||||
"parameters": fields.Raw(required=True, description="Server parameters configuration"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(201, "MCP server configuration created successfully", app_server_model)
|
@console_ns.response(201, "MCP server configuration created successfully", app_server_model)
|
||||||
@console_ns.response(403, "Insufficient permissions")
|
@console_ns.response(403, "Insufficient permissions")
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
|
@ -58,21 +69,16 @@ class AppMCPServerController(Resource):
|
||||||
@edit_permission_required
|
@edit_permission_required
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
_, current_tenant_id = current_account_with_tenant()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
parser = (
|
payload = MCPServerCreatePayload.model_validate(console_ns.payload or {})
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("description", type=str, required=False, location="json")
|
|
||||||
.add_argument("parameters", type=dict, required=True, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
description = args.get("description")
|
description = payload.description
|
||||||
if not description:
|
if not description:
|
||||||
description = app_model.description or ""
|
description = app_model.description or ""
|
||||||
|
|
||||||
server = AppMCPServer(
|
server = AppMCPServer(
|
||||||
name=app_model.name,
|
name=app_model.name,
|
||||||
description=description,
|
description=description,
|
||||||
parameters=json.dumps(args["parameters"], ensure_ascii=False),
|
parameters=json.dumps(payload.parameters, ensure_ascii=False),
|
||||||
status=AppMCPServerStatus.ACTIVE,
|
status=AppMCPServerStatus.ACTIVE,
|
||||||
app_id=app_model.id,
|
app_id=app_model.id,
|
||||||
tenant_id=current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
|
|
@ -85,17 +91,7 @@ class AppMCPServerController(Resource):
|
||||||
@console_ns.doc("update_app_mcp_server")
|
@console_ns.doc("update_app_mcp_server")
|
||||||
@console_ns.doc(description="Update MCP server configuration for an application")
|
@console_ns.doc(description="Update MCP server configuration for an application")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[MCPServerUpdatePayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"MCPServerUpdateRequest",
|
|
||||||
{
|
|
||||||
"id": fields.String(required=True, description="Server ID"),
|
|
||||||
"description": fields.String(description="Server description"),
|
|
||||||
"parameters": fields.Raw(required=True, description="Server parameters configuration"),
|
|
||||||
"status": fields.String(description="Server status"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "MCP server configuration updated successfully", app_server_model)
|
@console_ns.response(200, "MCP server configuration updated successfully", app_server_model)
|
||||||
@console_ns.response(403, "Insufficient permissions")
|
@console_ns.response(403, "Insufficient permissions")
|
||||||
@console_ns.response(404, "Server not found")
|
@console_ns.response(404, "Server not found")
|
||||||
|
|
@ -106,19 +102,12 @@ class AppMCPServerController(Resource):
|
||||||
@marshal_with(app_server_model)
|
@marshal_with(app_server_model)
|
||||||
@edit_permission_required
|
@edit_permission_required
|
||||||
def put(self, app_model):
|
def put(self, app_model):
|
||||||
parser = (
|
payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {})
|
||||||
reqparse.RequestParser()
|
server = db.session.query(AppMCPServer).where(AppMCPServer.id == payload.id).first()
|
||||||
.add_argument("id", type=str, required=True, location="json")
|
|
||||||
.add_argument("description", type=str, required=False, location="json")
|
|
||||||
.add_argument("parameters", type=dict, required=True, location="json")
|
|
||||||
.add_argument("status", type=str, required=False, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
server = db.session.query(AppMCPServer).where(AppMCPServer.id == args["id"]).first()
|
|
||||||
if not server:
|
if not server:
|
||||||
raise NotFound()
|
raise NotFound()
|
||||||
|
|
||||||
description = args.get("description")
|
description = payload.description
|
||||||
if description is None:
|
if description is None:
|
||||||
pass
|
pass
|
||||||
elif not description:
|
elif not description:
|
||||||
|
|
@ -126,11 +115,11 @@ class AppMCPServerController(Resource):
|
||||||
else:
|
else:
|
||||||
server.description = description
|
server.description = description
|
||||||
|
|
||||||
server.parameters = json.dumps(args["parameters"], ensure_ascii=False)
|
server.parameters = json.dumps(payload.parameters, ensure_ascii=False)
|
||||||
if args["status"]:
|
if payload.status:
|
||||||
if args["status"] not in [status.value for status in AppMCPServerStatus]:
|
if payload.status not in [status.value for status in AppMCPServerStatus]:
|
||||||
raise ValueError("Invalid status")
|
raise ValueError("Invalid status")
|
||||||
server.status = args["status"]
|
server.status = payload.status
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
return server
|
return server
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,9 @@
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
from flask import request
|
||||||
from flask_restx.inputs import int_range
|
from flask_restx import Resource, fields, marshal_with
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
from sqlalchemy import exists, select
|
from sqlalchemy import exists, select
|
||||||
from werkzeug.exceptions import InternalServerError, NotFound
|
from werkzeug.exceptions import InternalServerError, NotFound
|
||||||
|
|
||||||
|
|
@ -33,6 +35,68 @@ from services.errors.message import MessageNotExistsError, SuggestedQuestionsAft
|
||||||
from services.message_service import MessageService
|
from services.message_service import MessageService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
|
|
||||||
|
class ChatMessagesQuery(BaseModel):
|
||||||
|
conversation_id: str = Field(..., description="Conversation ID")
|
||||||
|
first_id: str | None = Field(default=None, description="First message ID for pagination")
|
||||||
|
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)")
|
||||||
|
|
||||||
|
@field_validator("first_id", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def empty_to_none(cls, value: str | None) -> str | None:
|
||||||
|
if value == "":
|
||||||
|
return None
|
||||||
|
return value
|
||||||
|
|
||||||
|
@field_validator("conversation_id", "first_id")
|
||||||
|
@classmethod
|
||||||
|
def validate_uuid(cls, value: str | None) -> str | None:
|
||||||
|
if value is None:
|
||||||
|
return value
|
||||||
|
return uuid_value(value)
|
||||||
|
|
||||||
|
|
||||||
|
class MessageFeedbackPayload(BaseModel):
|
||||||
|
message_id: str = Field(..., description="Message ID")
|
||||||
|
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
|
||||||
|
content: str | None = Field(default=None, description="Feedback content")
|
||||||
|
|
||||||
|
@field_validator("message_id")
|
||||||
|
@classmethod
|
||||||
|
def validate_message_id(cls, value: str) -> str:
|
||||||
|
return uuid_value(value)
|
||||||
|
|
||||||
|
|
||||||
|
class FeedbackExportQuery(BaseModel):
|
||||||
|
from_source: Literal["user", "admin"] | None = Field(default=None, description="Filter by feedback source")
|
||||||
|
rating: Literal["like", "dislike"] | None = Field(default=None, description="Filter by rating")
|
||||||
|
has_comment: bool | None = Field(default=None, description="Only include feedback with comments")
|
||||||
|
start_date: str | None = Field(default=None, description="Start date (YYYY-MM-DD)")
|
||||||
|
end_date: str | None = Field(default=None, description="End date (YYYY-MM-DD)")
|
||||||
|
format: Literal["csv", "json"] = Field(default="csv", description="Export format")
|
||||||
|
|
||||||
|
@field_validator("has_comment", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def parse_bool(cls, value: bool | str | None) -> bool | None:
|
||||||
|
if isinstance(value, bool) or value is None:
|
||||||
|
return value
|
||||||
|
lowered = value.lower()
|
||||||
|
if lowered in {"true", "1", "yes", "on"}:
|
||||||
|
return True
|
||||||
|
if lowered in {"false", "0", "no", "off"}:
|
||||||
|
return False
|
||||||
|
raise ValueError("has_comment must be a boolean value")
|
||||||
|
|
||||||
|
|
||||||
|
def reg(cls: type[BaseModel]):
|
||||||
|
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||||
|
|
||||||
|
|
||||||
|
reg(ChatMessagesQuery)
|
||||||
|
reg(MessageFeedbackPayload)
|
||||||
|
reg(FeedbackExportQuery)
|
||||||
|
|
||||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||||
# Register in dependency order: base models first, then dependent models
|
# Register in dependency order: base models first, then dependent models
|
||||||
|
|
@ -157,12 +221,7 @@ class ChatMessageListApi(Resource):
|
||||||
@console_ns.doc("list_chat_messages")
|
@console_ns.doc("list_chat_messages")
|
||||||
@console_ns.doc(description="Get chat messages for a conversation with pagination")
|
@console_ns.doc(description="Get chat messages for a conversation with pagination")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[ChatMessagesQuery.__name__])
|
||||||
console_ns.parser()
|
|
||||||
.add_argument("conversation_id", type=str, required=True, location="args", help="Conversation ID")
|
|
||||||
.add_argument("first_id", type=str, location="args", help="First message ID for pagination")
|
|
||||||
.add_argument("limit", type=int, location="args", default=20, help="Number of messages to return (1-100)")
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Success", message_infinite_scroll_pagination_model)
|
@console_ns.response(200, "Success", message_infinite_scroll_pagination_model)
|
||||||
@console_ns.response(404, "Conversation not found")
|
@console_ns.response(404, "Conversation not found")
|
||||||
@login_required
|
@login_required
|
||||||
|
|
@ -172,27 +231,21 @@ class ChatMessageListApi(Resource):
|
||||||
@marshal_with(message_infinite_scroll_pagination_model)
|
@marshal_with(message_infinite_scroll_pagination_model)
|
||||||
@edit_permission_required
|
@edit_permission_required
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
parser = (
|
args = ChatMessagesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("conversation_id", required=True, type=uuid_value, location="args")
|
|
||||||
.add_argument("first_id", type=uuid_value, location="args")
|
|
||||||
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
conversation = (
|
conversation = (
|
||||||
db.session.query(Conversation)
|
db.session.query(Conversation)
|
||||||
.where(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id)
|
.where(Conversation.id == args.conversation_id, Conversation.app_id == app_model.id)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
|
||||||
if not conversation:
|
if not conversation:
|
||||||
raise NotFound("Conversation Not Exists.")
|
raise NotFound("Conversation Not Exists.")
|
||||||
|
|
||||||
if args["first_id"]:
|
if args.first_id:
|
||||||
first_message = (
|
first_message = (
|
||||||
db.session.query(Message)
|
db.session.query(Message)
|
||||||
.where(Message.conversation_id == conversation.id, Message.id == args["first_id"])
|
.where(Message.conversation_id == conversation.id, Message.id == args.first_id)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -207,7 +260,7 @@ class ChatMessageListApi(Resource):
|
||||||
Message.id != first_message.id,
|
Message.id != first_message.id,
|
||||||
)
|
)
|
||||||
.order_by(Message.created_at.desc())
|
.order_by(Message.created_at.desc())
|
||||||
.limit(args["limit"])
|
.limit(args.limit)
|
||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|
@ -215,12 +268,12 @@ class ChatMessageListApi(Resource):
|
||||||
db.session.query(Message)
|
db.session.query(Message)
|
||||||
.where(Message.conversation_id == conversation.id)
|
.where(Message.conversation_id == conversation.id)
|
||||||
.order_by(Message.created_at.desc())
|
.order_by(Message.created_at.desc())
|
||||||
.limit(args["limit"])
|
.limit(args.limit)
|
||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize has_more based on whether we have a full page
|
# Initialize has_more based on whether we have a full page
|
||||||
if len(history_messages) == args["limit"]:
|
if len(history_messages) == args.limit:
|
||||||
current_page_first_message = history_messages[-1]
|
current_page_first_message = history_messages[-1]
|
||||||
# Check if there are more messages before the current page
|
# Check if there are more messages before the current page
|
||||||
has_more = db.session.scalar(
|
has_more = db.session.scalar(
|
||||||
|
|
@ -238,7 +291,7 @@ class ChatMessageListApi(Resource):
|
||||||
|
|
||||||
history_messages = list(reversed(history_messages))
|
history_messages = list(reversed(history_messages))
|
||||||
|
|
||||||
return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more)
|
return InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:app_id>/feedbacks")
|
@console_ns.route("/apps/<uuid:app_id>/feedbacks")
|
||||||
|
|
@ -246,15 +299,7 @@ class MessageFeedbackApi(Resource):
|
||||||
@console_ns.doc("create_message_feedback")
|
@console_ns.doc("create_message_feedback")
|
||||||
@console_ns.doc(description="Create or update message feedback (like/dislike)")
|
@console_ns.doc(description="Create or update message feedback (like/dislike)")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[MessageFeedbackPayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"MessageFeedbackRequest",
|
|
||||||
{
|
|
||||||
"message_id": fields.String(required=True, description="Message ID"),
|
|
||||||
"rating": fields.String(enum=["like", "dislike"], description="Feedback rating"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Feedback updated successfully")
|
@console_ns.response(200, "Feedback updated successfully")
|
||||||
@console_ns.response(404, "Message not found")
|
@console_ns.response(404, "Message not found")
|
||||||
@console_ns.response(403, "Insufficient permissions")
|
@console_ns.response(403, "Insufficient permissions")
|
||||||
|
|
@ -265,14 +310,9 @@ class MessageFeedbackApi(Resource):
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
current_user, _ = current_account_with_tenant()
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
parser = (
|
args = MessageFeedbackPayload.model_validate(console_ns.payload)
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("message_id", required=True, type=uuid_value, location="json")
|
|
||||||
.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
message_id = str(args["message_id"])
|
message_id = str(args.message_id)
|
||||||
|
|
||||||
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
|
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
|
||||||
|
|
||||||
|
|
@ -281,18 +321,23 @@ class MessageFeedbackApi(Resource):
|
||||||
|
|
||||||
feedback = message.admin_feedback
|
feedback = message.admin_feedback
|
||||||
|
|
||||||
if not args["rating"] and feedback:
|
if not args.rating and feedback:
|
||||||
db.session.delete(feedback)
|
db.session.delete(feedback)
|
||||||
elif args["rating"] and feedback:
|
elif args.rating and feedback:
|
||||||
feedback.rating = args["rating"]
|
feedback.rating = args.rating
|
||||||
elif not args["rating"] and not feedback:
|
feedback.content = args.content
|
||||||
|
elif not args.rating and not feedback:
|
||||||
raise ValueError("rating cannot be None when feedback not exists")
|
raise ValueError("rating cannot be None when feedback not exists")
|
||||||
else:
|
else:
|
||||||
|
rating_value = args.rating
|
||||||
|
if rating_value is None:
|
||||||
|
raise ValueError("rating is required to create feedback")
|
||||||
feedback = MessageFeedback(
|
feedback = MessageFeedback(
|
||||||
app_id=app_model.id,
|
app_id=app_model.id,
|
||||||
conversation_id=message.conversation_id,
|
conversation_id=message.conversation_id,
|
||||||
message_id=message.id,
|
message_id=message.id,
|
||||||
rating=args["rating"],
|
rating=rating_value,
|
||||||
|
content=args.content,
|
||||||
from_source="admin",
|
from_source="admin",
|
||||||
from_account_id=current_user.id,
|
from_account_id=current_user.id,
|
||||||
)
|
)
|
||||||
|
|
@ -369,24 +414,12 @@ class MessageSuggestedQuestionApi(Resource):
|
||||||
return {"data": questions}
|
return {"data": questions}
|
||||||
|
|
||||||
|
|
||||||
# Shared parser for feedback export (used for both documentation and runtime parsing)
|
|
||||||
feedback_export_parser = (
|
|
||||||
console_ns.parser()
|
|
||||||
.add_argument("from_source", type=str, choices=["user", "admin"], location="args", help="Filter by feedback source")
|
|
||||||
.add_argument("rating", type=str, choices=["like", "dislike"], location="args", help="Filter by rating")
|
|
||||||
.add_argument("has_comment", type=bool, location="args", help="Only include feedback with comments")
|
|
||||||
.add_argument("start_date", type=str, location="args", help="Start date (YYYY-MM-DD)")
|
|
||||||
.add_argument("end_date", type=str, location="args", help="End date (YYYY-MM-DD)")
|
|
||||||
.add_argument("format", type=str, choices=["csv", "json"], default="csv", location="args", help="Export format")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:app_id>/feedbacks/export")
|
@console_ns.route("/apps/<uuid:app_id>/feedbacks/export")
|
||||||
class MessageFeedbackExportApi(Resource):
|
class MessageFeedbackExportApi(Resource):
|
||||||
@console_ns.doc("export_feedbacks")
|
@console_ns.doc("export_feedbacks")
|
||||||
@console_ns.doc(description="Export user feedback data for Google Sheets")
|
@console_ns.doc(description="Export user feedback data for Google Sheets")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.expect(feedback_export_parser)
|
@console_ns.expect(console_ns.models[FeedbackExportQuery.__name__])
|
||||||
@console_ns.response(200, "Feedback data exported successfully")
|
@console_ns.response(200, "Feedback data exported successfully")
|
||||||
@console_ns.response(400, "Invalid parameters")
|
@console_ns.response(400, "Invalid parameters")
|
||||||
@console_ns.response(500, "Internal server error")
|
@console_ns.response(500, "Internal server error")
|
||||||
|
|
@ -395,7 +428,7 @@ class MessageFeedbackExportApi(Resource):
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
args = feedback_export_parser.parse_args()
|
args = FeedbackExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
|
|
||||||
# Import the service function
|
# Import the service function
|
||||||
from services.feedback_service import FeedbackService
|
from services.feedback_service import FeedbackService
|
||||||
|
|
@ -403,12 +436,12 @@ class MessageFeedbackExportApi(Resource):
|
||||||
try:
|
try:
|
||||||
export_data = FeedbackService.export_feedbacks(
|
export_data = FeedbackService.export_feedbacks(
|
||||||
app_id=app_model.id,
|
app_id=app_model.id,
|
||||||
from_source=args.get("from_source"),
|
from_source=args.from_source,
|
||||||
rating=args.get("rating"),
|
rating=args.rating,
|
||||||
has_comment=args.get("has_comment"),
|
has_comment=args.has_comment,
|
||||||
start_date=args.get("start_date"),
|
start_date=args.start_date,
|
||||||
end_date=args.get("end_date"),
|
end_date=args.end_date,
|
||||||
format_type=args.get("format", "csv"),
|
format_type=args.format,
|
||||||
)
|
)
|
||||||
|
|
||||||
return export_data
|
return export_data
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,8 @@
|
||||||
from flask_restx import Resource, fields, reqparse
|
from typing import Any
|
||||||
|
|
||||||
|
from flask import request
|
||||||
|
from flask_restx import Resource, fields
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
from werkzeug.exceptions import BadRequest
|
from werkzeug.exceptions import BadRequest
|
||||||
|
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
|
|
@ -7,6 +11,26 @@ from controllers.console.wraps import account_initialization_required, setup_req
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from services.ops_service import OpsService
|
from services.ops_service import OpsService
|
||||||
|
|
||||||
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
|
|
||||||
|
class TraceProviderQuery(BaseModel):
|
||||||
|
tracing_provider: str = Field(..., description="Tracing provider name")
|
||||||
|
|
||||||
|
|
||||||
|
class TraceConfigPayload(BaseModel):
|
||||||
|
tracing_provider: str = Field(..., description="Tracing provider name")
|
||||||
|
tracing_config: dict[str, Any] = Field(..., description="Tracing configuration data")
|
||||||
|
|
||||||
|
|
||||||
|
console_ns.schema_model(
|
||||||
|
TraceProviderQuery.__name__,
|
||||||
|
TraceProviderQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||||
|
)
|
||||||
|
console_ns.schema_model(
|
||||||
|
TraceConfigPayload.__name__, TraceConfigPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:app_id>/trace-config")
|
@console_ns.route("/apps/<uuid:app_id>/trace-config")
|
||||||
class TraceAppConfigApi(Resource):
|
class TraceAppConfigApi(Resource):
|
||||||
|
|
@ -17,11 +41,7 @@ class TraceAppConfigApi(Resource):
|
||||||
@console_ns.doc("get_trace_app_config")
|
@console_ns.doc("get_trace_app_config")
|
||||||
@console_ns.doc(description="Get tracing configuration for an application")
|
@console_ns.doc(description="Get tracing configuration for an application")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[TraceProviderQuery.__name__])
|
||||||
console_ns.parser().add_argument(
|
|
||||||
"tracing_provider", type=str, required=True, location="args", help="Tracing provider name"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(
|
@console_ns.response(
|
||||||
200, "Tracing configuration retrieved successfully", fields.Raw(description="Tracing configuration data")
|
200, "Tracing configuration retrieved successfully", fields.Raw(description="Tracing configuration data")
|
||||||
)
|
)
|
||||||
|
|
@ -30,11 +50,10 @@ class TraceAppConfigApi(Resource):
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, app_id):
|
def get(self, app_id):
|
||||||
parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args")
|
args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"])
|
trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider)
|
||||||
if not trace_config:
|
if not trace_config:
|
||||||
return {"has_not_configured": True}
|
return {"has_not_configured": True}
|
||||||
return trace_config
|
return trace_config
|
||||||
|
|
@ -44,15 +63,7 @@ class TraceAppConfigApi(Resource):
|
||||||
@console_ns.doc("create_trace_app_config")
|
@console_ns.doc("create_trace_app_config")
|
||||||
@console_ns.doc(description="Create a new tracing configuration for an application")
|
@console_ns.doc(description="Create a new tracing configuration for an application")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[TraceConfigPayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"TraceConfigCreateRequest",
|
|
||||||
{
|
|
||||||
"tracing_provider": fields.String(required=True, description="Tracing provider name"),
|
|
||||||
"tracing_config": fields.Raw(required=True, description="Tracing configuration data"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(
|
@console_ns.response(
|
||||||
201, "Tracing configuration created successfully", fields.Raw(description="Created configuration data")
|
201, "Tracing configuration created successfully", fields.Raw(description="Created configuration data")
|
||||||
)
|
)
|
||||||
|
|
@ -62,16 +73,11 @@ class TraceAppConfigApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, app_id):
|
def post(self, app_id):
|
||||||
"""Create a new trace app configuration"""
|
"""Create a new trace app configuration"""
|
||||||
parser = (
|
args = TraceConfigPayload.model_validate(console_ns.payload)
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("tracing_provider", type=str, required=True, location="json")
|
|
||||||
.add_argument("tracing_config", type=dict, required=True, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = OpsService.create_tracing_app_config(
|
result = OpsService.create_tracing_app_config(
|
||||||
app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"]
|
app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
|
||||||
)
|
)
|
||||||
if not result:
|
if not result:
|
||||||
raise TracingConfigIsExist()
|
raise TracingConfigIsExist()
|
||||||
|
|
@ -84,15 +90,7 @@ class TraceAppConfigApi(Resource):
|
||||||
@console_ns.doc("update_trace_app_config")
|
@console_ns.doc("update_trace_app_config")
|
||||||
@console_ns.doc(description="Update an existing tracing configuration for an application")
|
@console_ns.doc(description="Update an existing tracing configuration for an application")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[TraceConfigPayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"TraceConfigUpdateRequest",
|
|
||||||
{
|
|
||||||
"tracing_provider": fields.String(required=True, description="Tracing provider name"),
|
|
||||||
"tracing_config": fields.Raw(required=True, description="Updated tracing configuration data"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Tracing configuration updated successfully", fields.Raw(description="Success response"))
|
@console_ns.response(200, "Tracing configuration updated successfully", fields.Raw(description="Success response"))
|
||||||
@console_ns.response(400, "Invalid request parameters or configuration not found")
|
@console_ns.response(400, "Invalid request parameters or configuration not found")
|
||||||
@setup_required
|
@setup_required
|
||||||
|
|
@ -100,16 +98,11 @@ class TraceAppConfigApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def patch(self, app_id):
|
def patch(self, app_id):
|
||||||
"""Update an existing trace app configuration"""
|
"""Update an existing trace app configuration"""
|
||||||
parser = (
|
args = TraceConfigPayload.model_validate(console_ns.payload)
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("tracing_provider", type=str, required=True, location="json")
|
|
||||||
.add_argument("tracing_config", type=dict, required=True, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = OpsService.update_tracing_app_config(
|
result = OpsService.update_tracing_app_config(
|
||||||
app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"]
|
app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
|
||||||
)
|
)
|
||||||
if not result:
|
if not result:
|
||||||
raise TracingConfigNotExist()
|
raise TracingConfigNotExist()
|
||||||
|
|
@ -120,11 +113,7 @@ class TraceAppConfigApi(Resource):
|
||||||
@console_ns.doc("delete_trace_app_config")
|
@console_ns.doc("delete_trace_app_config")
|
||||||
@console_ns.doc(description="Delete an existing tracing configuration for an application")
|
@console_ns.doc(description="Delete an existing tracing configuration for an application")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[TraceProviderQuery.__name__])
|
||||||
console_ns.parser().add_argument(
|
|
||||||
"tracing_provider", type=str, required=True, location="args", help="Tracing provider name"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(204, "Tracing configuration deleted successfully")
|
@console_ns.response(204, "Tracing configuration deleted successfully")
|
||||||
@console_ns.response(400, "Invalid request parameters or configuration not found")
|
@console_ns.response(400, "Invalid request parameters or configuration not found")
|
||||||
@setup_required
|
@setup_required
|
||||||
|
|
@ -132,11 +121,10 @@ class TraceAppConfigApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def delete(self, app_id):
|
def delete(self, app_id):
|
||||||
"""Delete an existing trace app configuration"""
|
"""Delete an existing trace app configuration"""
|
||||||
parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args")
|
args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"])
|
result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider)
|
||||||
if not result:
|
if not result:
|
||||||
raise TracingConfigNotExist()
|
raise TracingConfigNotExist()
|
||||||
return {"result": "success"}, 204
|
return {"result": "success"}, 204
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,7 @@
|
||||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
from typing import Literal
|
||||||
|
|
||||||
|
from flask_restx import Resource, marshal_with
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
from werkzeug.exceptions import NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
from constants.languages import supported_language
|
from constants.languages import supported_language
|
||||||
|
|
@ -16,69 +19,50 @@ from libs.datetime_utils import naive_utc_now
|
||||||
from libs.login import current_account_with_tenant, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models import Site
|
from models import Site
|
||||||
|
|
||||||
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
|
|
||||||
|
class AppSiteUpdatePayload(BaseModel):
|
||||||
|
title: str | None = Field(default=None)
|
||||||
|
icon_type: str | None = Field(default=None)
|
||||||
|
icon: str | None = Field(default=None)
|
||||||
|
icon_background: str | None = Field(default=None)
|
||||||
|
description: str | None = Field(default=None)
|
||||||
|
default_language: str | None = Field(default=None)
|
||||||
|
chat_color_theme: str | None = Field(default=None)
|
||||||
|
chat_color_theme_inverted: bool | None = Field(default=None)
|
||||||
|
customize_domain: str | None = Field(default=None)
|
||||||
|
copyright: str | None = Field(default=None)
|
||||||
|
privacy_policy: str | None = Field(default=None)
|
||||||
|
custom_disclaimer: str | None = Field(default=None)
|
||||||
|
customize_token_strategy: Literal["must", "allow", "not_allow"] | None = Field(default=None)
|
||||||
|
prompt_public: bool | None = Field(default=None)
|
||||||
|
show_workflow_steps: bool | None = Field(default=None)
|
||||||
|
use_icon_as_answer_icon: bool | None = Field(default=None)
|
||||||
|
|
||||||
|
@field_validator("default_language")
|
||||||
|
@classmethod
|
||||||
|
def validate_language(cls, value: str | None) -> str | None:
|
||||||
|
if value is None:
|
||||||
|
return value
|
||||||
|
return supported_language(value)
|
||||||
|
|
||||||
|
|
||||||
|
console_ns.schema_model(
|
||||||
|
AppSiteUpdatePayload.__name__,
|
||||||
|
AppSiteUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||||
|
)
|
||||||
|
|
||||||
# Register model for flask_restx to avoid dict type issues in Swagger
|
# Register model for flask_restx to avoid dict type issues in Swagger
|
||||||
app_site_model = console_ns.model("AppSite", app_site_fields)
|
app_site_model = console_ns.model("AppSite", app_site_fields)
|
||||||
|
|
||||||
|
|
||||||
def parse_app_site_args():
|
|
||||||
parser = (
|
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("title", type=str, required=False, location="json")
|
|
||||||
.add_argument("icon_type", type=str, required=False, location="json")
|
|
||||||
.add_argument("icon", type=str, required=False, location="json")
|
|
||||||
.add_argument("icon_background", type=str, required=False, location="json")
|
|
||||||
.add_argument("description", type=str, required=False, location="json")
|
|
||||||
.add_argument("default_language", type=supported_language, required=False, location="json")
|
|
||||||
.add_argument("chat_color_theme", type=str, required=False, location="json")
|
|
||||||
.add_argument("chat_color_theme_inverted", type=bool, required=False, location="json")
|
|
||||||
.add_argument("customize_domain", type=str, required=False, location="json")
|
|
||||||
.add_argument("copyright", type=str, required=False, location="json")
|
|
||||||
.add_argument("privacy_policy", type=str, required=False, location="json")
|
|
||||||
.add_argument("custom_disclaimer", type=str, required=False, location="json")
|
|
||||||
.add_argument(
|
|
||||||
"customize_token_strategy",
|
|
||||||
type=str,
|
|
||||||
choices=["must", "allow", "not_allow"],
|
|
||||||
required=False,
|
|
||||||
location="json",
|
|
||||||
)
|
|
||||||
.add_argument("prompt_public", type=bool, required=False, location="json")
|
|
||||||
.add_argument("show_workflow_steps", type=bool, required=False, location="json")
|
|
||||||
.add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json")
|
|
||||||
)
|
|
||||||
return parser.parse_args()
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:app_id>/site")
|
@console_ns.route("/apps/<uuid:app_id>/site")
|
||||||
class AppSite(Resource):
|
class AppSite(Resource):
|
||||||
@console_ns.doc("update_app_site")
|
@console_ns.doc("update_app_site")
|
||||||
@console_ns.doc(description="Update application site configuration")
|
@console_ns.doc(description="Update application site configuration")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[AppSiteUpdatePayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"AppSiteRequest",
|
|
||||||
{
|
|
||||||
"title": fields.String(description="Site title"),
|
|
||||||
"icon_type": fields.String(description="Icon type"),
|
|
||||||
"icon": fields.String(description="Icon"),
|
|
||||||
"icon_background": fields.String(description="Icon background color"),
|
|
||||||
"description": fields.String(description="Site description"),
|
|
||||||
"default_language": fields.String(description="Default language"),
|
|
||||||
"chat_color_theme": fields.String(description="Chat color theme"),
|
|
||||||
"chat_color_theme_inverted": fields.Boolean(description="Inverted chat color theme"),
|
|
||||||
"customize_domain": fields.String(description="Custom domain"),
|
|
||||||
"copyright": fields.String(description="Copyright text"),
|
|
||||||
"privacy_policy": fields.String(description="Privacy policy"),
|
|
||||||
"custom_disclaimer": fields.String(description="Custom disclaimer"),
|
|
||||||
"customize_token_strategy": fields.String(
|
|
||||||
enum=["must", "allow", "not_allow"], description="Token strategy"
|
|
||||||
),
|
|
||||||
"prompt_public": fields.Boolean(description="Make prompt public"),
|
|
||||||
"show_workflow_steps": fields.Boolean(description="Show workflow steps"),
|
|
||||||
"use_icon_as_answer_icon": fields.Boolean(description="Use icon as answer icon"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Site configuration updated successfully", app_site_model)
|
@console_ns.response(200, "Site configuration updated successfully", app_site_model)
|
||||||
@console_ns.response(403, "Insufficient permissions")
|
@console_ns.response(403, "Insufficient permissions")
|
||||||
@console_ns.response(404, "App not found")
|
@console_ns.response(404, "App not found")
|
||||||
|
|
@ -89,7 +73,7 @@ class AppSite(Resource):
|
||||||
@get_app_model
|
@get_app_model
|
||||||
@marshal_with(app_site_model)
|
@marshal_with(app_site_model)
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
args = parse_app_site_args()
|
args = AppSiteUpdatePayload.model_validate(console_ns.payload or {})
|
||||||
current_user, _ = current_account_with_tenant()
|
current_user, _ = current_account_with_tenant()
|
||||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||||
if not site:
|
if not site:
|
||||||
|
|
@ -113,7 +97,7 @@ class AppSite(Resource):
|
||||||
"show_workflow_steps",
|
"show_workflow_steps",
|
||||||
"use_icon_as_answer_icon",
|
"use_icon_as_answer_icon",
|
||||||
]:
|
]:
|
||||||
value = args.get(attr_name)
|
value = getattr(args, attr_name)
|
||||||
if value is not None:
|
if value is not None:
|
||||||
setattr(site, attr_name, value)
|
setattr(site, attr_name, value)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,9 @@
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from flask import abort, jsonify
|
from flask import abort, jsonify, request
|
||||||
from flask_restx import Resource, fields, reqparse
|
from flask_restx import Resource, fields
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
from controllers.console.app.wraps import get_app_model
|
from controllers.console.app.wraps import get_app_model
|
||||||
|
|
@ -10,21 +11,37 @@ from controllers.console.wraps import account_initialization_required, setup_req
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.datetime_utils import parse_time_range
|
from libs.datetime_utils import parse_time_range
|
||||||
from libs.helper import DatetimeString, convert_datetime_to_date
|
from libs.helper import convert_datetime_to_date
|
||||||
from libs.login import current_account_with_tenant, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models import AppMode
|
from models import AppMode
|
||||||
|
|
||||||
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
|
|
||||||
|
class StatisticTimeRangeQuery(BaseModel):
|
||||||
|
start: str | None = Field(default=None, description="Start date (YYYY-MM-DD HH:MM)")
|
||||||
|
end: str | None = Field(default=None, description="End date (YYYY-MM-DD HH:MM)")
|
||||||
|
|
||||||
|
@field_validator("start", "end", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def empty_string_to_none(cls, value: str | None) -> str | None:
|
||||||
|
if value == "":
|
||||||
|
return None
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
console_ns.schema_model(
|
||||||
|
StatisticTimeRangeQuery.__name__,
|
||||||
|
StatisticTimeRangeQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-messages")
|
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-messages")
|
||||||
class DailyMessageStatistic(Resource):
|
class DailyMessageStatistic(Resource):
|
||||||
@console_ns.doc("get_daily_message_statistics")
|
@console_ns.doc("get_daily_message_statistics")
|
||||||
@console_ns.doc(description="Get daily message statistics for an application")
|
@console_ns.doc(description="Get daily message statistics for an application")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||||
console_ns.parser()
|
|
||||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
|
||||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
|
||||||
)
|
|
||||||
@console_ns.response(
|
@console_ns.response(
|
||||||
200,
|
200,
|
||||||
"Daily message statistics retrieved successfully",
|
"Daily message statistics retrieved successfully",
|
||||||
|
|
@ -37,12 +54,7 @@ class DailyMessageStatistic(Resource):
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account, _ = current_account_with_tenant()
|
account, _ = current_account_with_tenant()
|
||||||
|
|
||||||
parser = (
|
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
|
||||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
converted_created_at = convert_datetime_to_date("created_at")
|
converted_created_at = convert_datetime_to_date("created_at")
|
||||||
sql_query = f"""SELECT
|
sql_query = f"""SELECT
|
||||||
|
|
@ -57,7 +69,7 @@ WHERE
|
||||||
assert account.timezone is not None
|
assert account.timezone is not None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
abort(400, description=str(e))
|
abort(400, description=str(e))
|
||||||
|
|
||||||
|
|
@ -81,19 +93,12 @@ WHERE
|
||||||
return jsonify({"data": response_data})
|
return jsonify({"data": response_data})
|
||||||
|
|
||||||
|
|
||||||
parser = (
|
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
|
||||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args", help="End date (YYYY-MM-DD HH:MM)")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-conversations")
|
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-conversations")
|
||||||
class DailyConversationStatistic(Resource):
|
class DailyConversationStatistic(Resource):
|
||||||
@console_ns.doc("get_daily_conversation_statistics")
|
@console_ns.doc("get_daily_conversation_statistics")
|
||||||
@console_ns.doc(description="Get daily conversation statistics for an application")
|
@console_ns.doc(description="Get daily conversation statistics for an application")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.expect(parser)
|
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||||
@console_ns.response(
|
@console_ns.response(
|
||||||
200,
|
200,
|
||||||
"Daily conversation statistics retrieved successfully",
|
"Daily conversation statistics retrieved successfully",
|
||||||
|
|
@ -106,7 +111,7 @@ class DailyConversationStatistic(Resource):
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account, _ = current_account_with_tenant()
|
account, _ = current_account_with_tenant()
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
|
|
||||||
converted_created_at = convert_datetime_to_date("created_at")
|
converted_created_at = convert_datetime_to_date("created_at")
|
||||||
sql_query = f"""SELECT
|
sql_query = f"""SELECT
|
||||||
|
|
@ -121,7 +126,7 @@ WHERE
|
||||||
assert account.timezone is not None
|
assert account.timezone is not None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
abort(400, description=str(e))
|
abort(400, description=str(e))
|
||||||
|
|
||||||
|
|
@ -149,7 +154,7 @@ class DailyTerminalsStatistic(Resource):
|
||||||
@console_ns.doc("get_daily_terminals_statistics")
|
@console_ns.doc("get_daily_terminals_statistics")
|
||||||
@console_ns.doc(description="Get daily terminal/end-user statistics for an application")
|
@console_ns.doc(description="Get daily terminal/end-user statistics for an application")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.expect(parser)
|
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||||
@console_ns.response(
|
@console_ns.response(
|
||||||
200,
|
200,
|
||||||
"Daily terminal statistics retrieved successfully",
|
"Daily terminal statistics retrieved successfully",
|
||||||
|
|
@ -162,7 +167,7 @@ class DailyTerminalsStatistic(Resource):
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account, _ = current_account_with_tenant()
|
account, _ = current_account_with_tenant()
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
|
|
||||||
converted_created_at = convert_datetime_to_date("created_at")
|
converted_created_at = convert_datetime_to_date("created_at")
|
||||||
sql_query = f"""SELECT
|
sql_query = f"""SELECT
|
||||||
|
|
@ -177,7 +182,7 @@ WHERE
|
||||||
assert account.timezone is not None
|
assert account.timezone is not None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
abort(400, description=str(e))
|
abort(400, description=str(e))
|
||||||
|
|
||||||
|
|
@ -206,7 +211,7 @@ class DailyTokenCostStatistic(Resource):
|
||||||
@console_ns.doc("get_daily_token_cost_statistics")
|
@console_ns.doc("get_daily_token_cost_statistics")
|
||||||
@console_ns.doc(description="Get daily token cost statistics for an application")
|
@console_ns.doc(description="Get daily token cost statistics for an application")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.expect(parser)
|
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||||
@console_ns.response(
|
@console_ns.response(
|
||||||
200,
|
200,
|
||||||
"Daily token cost statistics retrieved successfully",
|
"Daily token cost statistics retrieved successfully",
|
||||||
|
|
@ -219,7 +224,7 @@ class DailyTokenCostStatistic(Resource):
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account, _ = current_account_with_tenant()
|
account, _ = current_account_with_tenant()
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
|
|
||||||
converted_created_at = convert_datetime_to_date("created_at")
|
converted_created_at = convert_datetime_to_date("created_at")
|
||||||
sql_query = f"""SELECT
|
sql_query = f"""SELECT
|
||||||
|
|
@ -235,7 +240,7 @@ WHERE
|
||||||
assert account.timezone is not None
|
assert account.timezone is not None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
abort(400, description=str(e))
|
abort(400, description=str(e))
|
||||||
|
|
||||||
|
|
@ -266,7 +271,7 @@ class AverageSessionInteractionStatistic(Resource):
|
||||||
@console_ns.doc("get_average_session_interaction_statistics")
|
@console_ns.doc("get_average_session_interaction_statistics")
|
||||||
@console_ns.doc(description="Get average session interaction statistics for an application")
|
@console_ns.doc(description="Get average session interaction statistics for an application")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.expect(parser)
|
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||||
@console_ns.response(
|
@console_ns.response(
|
||||||
200,
|
200,
|
||||||
"Average session interaction statistics retrieved successfully",
|
"Average session interaction statistics retrieved successfully",
|
||||||
|
|
@ -279,7 +284,7 @@ class AverageSessionInteractionStatistic(Resource):
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account, _ = current_account_with_tenant()
|
account, _ = current_account_with_tenant()
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
|
|
||||||
converted_created_at = convert_datetime_to_date("c.created_at")
|
converted_created_at = convert_datetime_to_date("c.created_at")
|
||||||
sql_query = f"""SELECT
|
sql_query = f"""SELECT
|
||||||
|
|
@ -302,7 +307,7 @@ FROM
|
||||||
assert account.timezone is not None
|
assert account.timezone is not None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
abort(400, description=str(e))
|
abort(400, description=str(e))
|
||||||
|
|
||||||
|
|
@ -342,7 +347,7 @@ class UserSatisfactionRateStatistic(Resource):
|
||||||
@console_ns.doc("get_user_satisfaction_rate_statistics")
|
@console_ns.doc("get_user_satisfaction_rate_statistics")
|
||||||
@console_ns.doc(description="Get user satisfaction rate statistics for an application")
|
@console_ns.doc(description="Get user satisfaction rate statistics for an application")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.expect(parser)
|
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||||
@console_ns.response(
|
@console_ns.response(
|
||||||
200,
|
200,
|
||||||
"User satisfaction rate statistics retrieved successfully",
|
"User satisfaction rate statistics retrieved successfully",
|
||||||
|
|
@ -355,7 +360,7 @@ class UserSatisfactionRateStatistic(Resource):
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account, _ = current_account_with_tenant()
|
account, _ = current_account_with_tenant()
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
|
|
||||||
converted_created_at = convert_datetime_to_date("m.created_at")
|
converted_created_at = convert_datetime_to_date("m.created_at")
|
||||||
sql_query = f"""SELECT
|
sql_query = f"""SELECT
|
||||||
|
|
@ -374,7 +379,7 @@ WHERE
|
||||||
assert account.timezone is not None
|
assert account.timezone is not None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
abort(400, description=str(e))
|
abort(400, description=str(e))
|
||||||
|
|
||||||
|
|
@ -408,7 +413,7 @@ class AverageResponseTimeStatistic(Resource):
|
||||||
@console_ns.doc("get_average_response_time_statistics")
|
@console_ns.doc("get_average_response_time_statistics")
|
||||||
@console_ns.doc(description="Get average response time statistics for an application")
|
@console_ns.doc(description="Get average response time statistics for an application")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.expect(parser)
|
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||||
@console_ns.response(
|
@console_ns.response(
|
||||||
200,
|
200,
|
||||||
"Average response time statistics retrieved successfully",
|
"Average response time statistics retrieved successfully",
|
||||||
|
|
@ -421,7 +426,7 @@ class AverageResponseTimeStatistic(Resource):
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account, _ = current_account_with_tenant()
|
account, _ = current_account_with_tenant()
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
|
|
||||||
converted_created_at = convert_datetime_to_date("created_at")
|
converted_created_at = convert_datetime_to_date("created_at")
|
||||||
sql_query = f"""SELECT
|
sql_query = f"""SELECT
|
||||||
|
|
@ -436,7 +441,7 @@ WHERE
|
||||||
assert account.timezone is not None
|
assert account.timezone is not None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
abort(400, description=str(e))
|
abort(400, description=str(e))
|
||||||
|
|
||||||
|
|
@ -465,7 +470,7 @@ class TokensPerSecondStatistic(Resource):
|
||||||
@console_ns.doc("get_tokens_per_second_statistics")
|
@console_ns.doc("get_tokens_per_second_statistics")
|
||||||
@console_ns.doc(description="Get tokens per second statistics for an application")
|
@console_ns.doc(description="Get tokens per second statistics for an application")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.expect(parser)
|
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||||
@console_ns.response(
|
@console_ns.response(
|
||||||
200,
|
200,
|
||||||
"Tokens per second statistics retrieved successfully",
|
"Tokens per second statistics retrieved successfully",
|
||||||
|
|
@ -477,7 +482,7 @@ class TokensPerSecondStatistic(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account, _ = current_account_with_tenant()
|
account, _ = current_account_with_tenant()
|
||||||
args = parser.parse_args()
|
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
|
|
||||||
converted_created_at = convert_datetime_to_date("created_at")
|
converted_created_at = convert_datetime_to_date("created_at")
|
||||||
sql_query = f"""SELECT
|
sql_query = f"""SELECT
|
||||||
|
|
@ -495,7 +500,7 @@ WHERE
|
||||||
assert account.timezone is not None
|
assert account.timezone is not None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
abort(400, description=str(e))
|
abort(400, description=str(e))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,11 @@
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import cast
|
from typing import Any
|
||||||
|
|
||||||
from flask import abort, request
|
from flask import abort, request
|
||||||
from flask_restx import Resource, fields, inputs, marshal_with, reqparse
|
from flask_restx import Resource, fields, marshal_with
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||||
|
|
||||||
|
|
@ -49,6 +50,7 @@ from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseE
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
LISTENING_RETRY_IN = 2000
|
LISTENING_RETRY_IN = 2000
|
||||||
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||||
# Register in dependency order: base models first, then dependent models
|
# Register in dependency order: base models first, then dependent models
|
||||||
|
|
@ -107,6 +109,104 @@ if workflow_run_node_execution_model is None:
|
||||||
workflow_run_node_execution_model = console_ns.model("WorkflowRunNodeExecution", workflow_run_node_execution_fields)
|
workflow_run_node_execution_model = console_ns.model("WorkflowRunNodeExecution", workflow_run_node_execution_fields)
|
||||||
|
|
||||||
|
|
||||||
|
class SyncDraftWorkflowPayload(BaseModel):
|
||||||
|
graph: dict[str, Any]
|
||||||
|
features: dict[str, Any]
|
||||||
|
hash: str | None = None
|
||||||
|
environment_variables: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
|
conversation_variables: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseWorkflowRunPayload(BaseModel):
|
||||||
|
files: list[dict[str, Any]] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class AdvancedChatWorkflowRunPayload(BaseWorkflowRunPayload):
|
||||||
|
inputs: dict[str, Any] | None = None
|
||||||
|
query: str = ""
|
||||||
|
conversation_id: str | None = None
|
||||||
|
parent_message_id: str | None = None
|
||||||
|
|
||||||
|
@field_validator("conversation_id", "parent_message_id")
|
||||||
|
@classmethod
|
||||||
|
def validate_uuid(cls, value: str | None) -> str | None:
|
||||||
|
if value is None:
|
||||||
|
return value
|
||||||
|
return uuid_value(value)
|
||||||
|
|
||||||
|
|
||||||
|
class IterationNodeRunPayload(BaseModel):
|
||||||
|
inputs: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class LoopNodeRunPayload(BaseModel):
|
||||||
|
inputs: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class DraftWorkflowRunPayload(BaseWorkflowRunPayload):
|
||||||
|
inputs: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class DraftWorkflowNodeRunPayload(BaseWorkflowRunPayload):
|
||||||
|
inputs: dict[str, Any]
|
||||||
|
query: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class PublishWorkflowPayload(BaseModel):
|
||||||
|
marked_name: str | None = Field(default=None, max_length=20)
|
||||||
|
marked_comment: str | None = Field(default=None, max_length=100)
|
||||||
|
|
||||||
|
|
||||||
|
class DefaultBlockConfigQuery(BaseModel):
|
||||||
|
q: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ConvertToWorkflowPayload(BaseModel):
|
||||||
|
name: str | None = None
|
||||||
|
icon_type: str | None = None
|
||||||
|
icon: str | None = None
|
||||||
|
icon_background: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowListQuery(BaseModel):
|
||||||
|
page: int = Field(default=1, ge=1, le=99999)
|
||||||
|
limit: int = Field(default=10, ge=1, le=100)
|
||||||
|
user_id: str | None = None
|
||||||
|
named_only: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowUpdatePayload(BaseModel):
|
||||||
|
marked_name: str | None = Field(default=None, max_length=20)
|
||||||
|
marked_comment: str | None = Field(default=None, max_length=100)
|
||||||
|
|
||||||
|
|
||||||
|
class DraftWorkflowTriggerRunPayload(BaseModel):
|
||||||
|
node_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class DraftWorkflowTriggerRunAllPayload(BaseModel):
|
||||||
|
node_ids: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
def reg(cls: type[BaseModel]):
|
||||||
|
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||||
|
|
||||||
|
|
||||||
|
reg(SyncDraftWorkflowPayload)
|
||||||
|
reg(AdvancedChatWorkflowRunPayload)
|
||||||
|
reg(IterationNodeRunPayload)
|
||||||
|
reg(LoopNodeRunPayload)
|
||||||
|
reg(DraftWorkflowRunPayload)
|
||||||
|
reg(DraftWorkflowNodeRunPayload)
|
||||||
|
reg(PublishWorkflowPayload)
|
||||||
|
reg(DefaultBlockConfigQuery)
|
||||||
|
reg(ConvertToWorkflowPayload)
|
||||||
|
reg(WorkflowListQuery)
|
||||||
|
reg(WorkflowUpdatePayload)
|
||||||
|
reg(DraftWorkflowTriggerRunPayload)
|
||||||
|
reg(DraftWorkflowTriggerRunAllPayload)
|
||||||
|
|
||||||
|
|
||||||
# TODO(QuantumGhost): Refactor existing node run API to handle file parameter parsing
|
# TODO(QuantumGhost): Refactor existing node run API to handle file parameter parsing
|
||||||
# at the controller level rather than in the workflow logic. This would improve separation
|
# at the controller level rather than in the workflow logic. This would improve separation
|
||||||
# of concerns and make the code more maintainable.
|
# of concerns and make the code more maintainable.
|
||||||
|
|
@ -158,18 +258,7 @@ class DraftWorkflowApi(Resource):
|
||||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||||
@console_ns.doc("sync_draft_workflow")
|
@console_ns.doc("sync_draft_workflow")
|
||||||
@console_ns.doc(description="Sync draft workflow configuration")
|
@console_ns.doc(description="Sync draft workflow configuration")
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[SyncDraftWorkflowPayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"SyncDraftWorkflowRequest",
|
|
||||||
{
|
|
||||||
"graph": fields.Raw(required=True, description="Workflow graph configuration"),
|
|
||||||
"features": fields.Raw(required=True, description="Workflow features configuration"),
|
|
||||||
"hash": fields.String(description="Workflow hash for validation"),
|
|
||||||
"environment_variables": fields.List(fields.Raw, required=True, description="Environment variables"),
|
|
||||||
"conversation_variables": fields.List(fields.Raw, description="Conversation variables"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(
|
@console_ns.response(
|
||||||
200,
|
200,
|
||||||
"Draft workflow synced successfully",
|
"Draft workflow synced successfully",
|
||||||
|
|
@ -193,36 +282,23 @@ class DraftWorkflowApi(Resource):
|
||||||
|
|
||||||
content_type = request.headers.get("Content-Type", "")
|
content_type = request.headers.get("Content-Type", "")
|
||||||
|
|
||||||
|
payload_data: dict[str, Any] | None = None
|
||||||
if "application/json" in content_type:
|
if "application/json" in content_type:
|
||||||
parser = (
|
payload_data = request.get_json(silent=True)
|
||||||
reqparse.RequestParser()
|
if not isinstance(payload_data, dict):
|
||||||
.add_argument("graph", type=dict, required=True, nullable=False, location="json")
|
return {"message": "Invalid JSON data"}, 400
|
||||||
.add_argument("features", type=dict, required=True, nullable=False, location="json")
|
|
||||||
.add_argument("hash", type=str, required=False, location="json")
|
|
||||||
.add_argument("environment_variables", type=list, required=True, location="json")
|
|
||||||
.add_argument("conversation_variables", type=list, required=False, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
elif "text/plain" in content_type:
|
elif "text/plain" in content_type:
|
||||||
try:
|
try:
|
||||||
data = json.loads(request.data.decode("utf-8"))
|
payload_data = json.loads(request.data.decode("utf-8"))
|
||||||
if "graph" not in data or "features" not in data:
|
|
||||||
raise ValueError("graph or features not found in data")
|
|
||||||
|
|
||||||
if not isinstance(data.get("graph"), dict) or not isinstance(data.get("features"), dict):
|
|
||||||
raise ValueError("graph or features is not a dict")
|
|
||||||
|
|
||||||
args = {
|
|
||||||
"graph": data.get("graph"),
|
|
||||||
"features": data.get("features"),
|
|
||||||
"hash": data.get("hash"),
|
|
||||||
"environment_variables": data.get("environment_variables"),
|
|
||||||
"conversation_variables": data.get("conversation_variables"),
|
|
||||||
}
|
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
return {"message": "Invalid JSON data"}, 400
|
return {"message": "Invalid JSON data"}, 400
|
||||||
|
if not isinstance(payload_data, dict):
|
||||||
|
return {"message": "Invalid JSON data"}, 400
|
||||||
else:
|
else:
|
||||||
abort(415)
|
abort(415)
|
||||||
|
|
||||||
|
args_model = SyncDraftWorkflowPayload.model_validate(payload_data)
|
||||||
|
args = args_model.model_dump()
|
||||||
workflow_service = WorkflowService()
|
workflow_service = WorkflowService()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -258,17 +334,7 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
|
||||||
@console_ns.doc("run_advanced_chat_draft_workflow")
|
@console_ns.doc("run_advanced_chat_draft_workflow")
|
||||||
@console_ns.doc(description="Run draft workflow for advanced chat application")
|
@console_ns.doc(description="Run draft workflow for advanced chat application")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[AdvancedChatWorkflowRunPayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"AdvancedChatWorkflowRunRequest",
|
|
||||||
{
|
|
||||||
"query": fields.String(required=True, description="User query"),
|
|
||||||
"inputs": fields.Raw(description="Input variables"),
|
|
||||||
"files": fields.List(fields.Raw, description="File uploads"),
|
|
||||||
"conversation_id": fields.String(description="Conversation ID"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Workflow run started successfully")
|
@console_ns.response(200, "Workflow run started successfully")
|
||||||
@console_ns.response(400, "Invalid request parameters")
|
@console_ns.response(400, "Invalid request parameters")
|
||||||
@console_ns.response(403, "Permission denied")
|
@console_ns.response(403, "Permission denied")
|
||||||
|
|
@ -283,16 +349,8 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
|
||||||
"""
|
"""
|
||||||
current_user, _ = current_account_with_tenant()
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
parser = (
|
args_model = AdvancedChatWorkflowRunPayload.model_validate(console_ns.payload or {})
|
||||||
reqparse.RequestParser()
|
args = args_model.model_dump(exclude_none=True)
|
||||||
.add_argument("inputs", type=dict, location="json")
|
|
||||||
.add_argument("query", type=str, required=True, location="json", default="")
|
|
||||||
.add_argument("files", type=list, location="json")
|
|
||||||
.add_argument("conversation_id", type=uuid_value, location="json")
|
|
||||||
.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
external_trace_id = get_external_trace_id(request)
|
external_trace_id = get_external_trace_id(request)
|
||||||
if external_trace_id:
|
if external_trace_id:
|
||||||
|
|
@ -322,15 +380,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
|
||||||
@console_ns.doc("run_advanced_chat_draft_iteration_node")
|
@console_ns.doc("run_advanced_chat_draft_iteration_node")
|
||||||
@console_ns.doc(description="Run draft workflow iteration node for advanced chat")
|
@console_ns.doc(description="Run draft workflow iteration node for advanced chat")
|
||||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[IterationNodeRunPayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"IterationNodeRunRequest",
|
|
||||||
{
|
|
||||||
"task_id": fields.String(required=True, description="Task ID"),
|
|
||||||
"inputs": fields.Raw(description="Input variables"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Iteration node run started successfully")
|
@console_ns.response(200, "Iteration node run started successfully")
|
||||||
@console_ns.response(403, "Permission denied")
|
@console_ns.response(403, "Permission denied")
|
||||||
@console_ns.response(404, "Node not found")
|
@console_ns.response(404, "Node not found")
|
||||||
|
|
@ -344,8 +394,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
|
||||||
Run draft workflow iteration node
|
Run draft workflow iteration node
|
||||||
"""
|
"""
|
||||||
current_user, _ = current_account_with_tenant()
|
current_user, _ = current_account_with_tenant()
|
||||||
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
|
args = IterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = AppGenerateService.generate_single_iteration(
|
response = AppGenerateService.generate_single_iteration(
|
||||||
|
|
@ -369,15 +418,7 @@ class WorkflowDraftRunIterationNodeApi(Resource):
|
||||||
@console_ns.doc("run_workflow_draft_iteration_node")
|
@console_ns.doc("run_workflow_draft_iteration_node")
|
||||||
@console_ns.doc(description="Run draft workflow iteration node")
|
@console_ns.doc(description="Run draft workflow iteration node")
|
||||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[IterationNodeRunPayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"WorkflowIterationNodeRunRequest",
|
|
||||||
{
|
|
||||||
"task_id": fields.String(required=True, description="Task ID"),
|
|
||||||
"inputs": fields.Raw(description="Input variables"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Workflow iteration node run started successfully")
|
@console_ns.response(200, "Workflow iteration node run started successfully")
|
||||||
@console_ns.response(403, "Permission denied")
|
@console_ns.response(403, "Permission denied")
|
||||||
@console_ns.response(404, "Node not found")
|
@console_ns.response(404, "Node not found")
|
||||||
|
|
@ -391,8 +432,7 @@ class WorkflowDraftRunIterationNodeApi(Resource):
|
||||||
Run draft workflow iteration node
|
Run draft workflow iteration node
|
||||||
"""
|
"""
|
||||||
current_user, _ = current_account_with_tenant()
|
current_user, _ = current_account_with_tenant()
|
||||||
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
|
args = IterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = AppGenerateService.generate_single_iteration(
|
response = AppGenerateService.generate_single_iteration(
|
||||||
|
|
@ -416,15 +456,7 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
|
||||||
@console_ns.doc("run_advanced_chat_draft_loop_node")
|
@console_ns.doc("run_advanced_chat_draft_loop_node")
|
||||||
@console_ns.doc(description="Run draft workflow loop node for advanced chat")
|
@console_ns.doc(description="Run draft workflow loop node for advanced chat")
|
||||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[LoopNodeRunPayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"LoopNodeRunRequest",
|
|
||||||
{
|
|
||||||
"task_id": fields.String(required=True, description="Task ID"),
|
|
||||||
"inputs": fields.Raw(description="Input variables"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Loop node run started successfully")
|
@console_ns.response(200, "Loop node run started successfully")
|
||||||
@console_ns.response(403, "Permission denied")
|
@console_ns.response(403, "Permission denied")
|
||||||
@console_ns.response(404, "Node not found")
|
@console_ns.response(404, "Node not found")
|
||||||
|
|
@ -438,8 +470,7 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
|
||||||
Run draft workflow loop node
|
Run draft workflow loop node
|
||||||
"""
|
"""
|
||||||
current_user, _ = current_account_with_tenant()
|
current_user, _ = current_account_with_tenant()
|
||||||
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
|
args = LoopNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = AppGenerateService.generate_single_loop(
|
response = AppGenerateService.generate_single_loop(
|
||||||
|
|
@ -463,15 +494,7 @@ class WorkflowDraftRunLoopNodeApi(Resource):
|
||||||
@console_ns.doc("run_workflow_draft_loop_node")
|
@console_ns.doc("run_workflow_draft_loop_node")
|
||||||
@console_ns.doc(description="Run draft workflow loop node")
|
@console_ns.doc(description="Run draft workflow loop node")
|
||||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[LoopNodeRunPayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"WorkflowLoopNodeRunRequest",
|
|
||||||
{
|
|
||||||
"task_id": fields.String(required=True, description="Task ID"),
|
|
||||||
"inputs": fields.Raw(description="Input variables"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Workflow loop node run started successfully")
|
@console_ns.response(200, "Workflow loop node run started successfully")
|
||||||
@console_ns.response(403, "Permission denied")
|
@console_ns.response(403, "Permission denied")
|
||||||
@console_ns.response(404, "Node not found")
|
@console_ns.response(404, "Node not found")
|
||||||
|
|
@ -485,8 +508,7 @@ class WorkflowDraftRunLoopNodeApi(Resource):
|
||||||
Run draft workflow loop node
|
Run draft workflow loop node
|
||||||
"""
|
"""
|
||||||
current_user, _ = current_account_with_tenant()
|
current_user, _ = current_account_with_tenant()
|
||||||
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
|
args = LoopNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = AppGenerateService.generate_single_loop(
|
response = AppGenerateService.generate_single_loop(
|
||||||
|
|
@ -510,15 +532,7 @@ class DraftWorkflowRunApi(Resource):
|
||||||
@console_ns.doc("run_draft_workflow")
|
@console_ns.doc("run_draft_workflow")
|
||||||
@console_ns.doc(description="Run draft workflow")
|
@console_ns.doc(description="Run draft workflow")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[DraftWorkflowRunPayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"DraftWorkflowRunRequest",
|
|
||||||
{
|
|
||||||
"inputs": fields.Raw(required=True, description="Input variables"),
|
|
||||||
"files": fields.List(fields.Raw, description="File uploads"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Draft workflow run started successfully")
|
@console_ns.response(200, "Draft workflow run started successfully")
|
||||||
@console_ns.response(403, "Permission denied")
|
@console_ns.response(403, "Permission denied")
|
||||||
@setup_required
|
@setup_required
|
||||||
|
|
@ -531,12 +545,7 @@ class DraftWorkflowRunApi(Resource):
|
||||||
Run draft workflow
|
Run draft workflow
|
||||||
"""
|
"""
|
||||||
current_user, _ = current_account_with_tenant()
|
current_user, _ = current_account_with_tenant()
|
||||||
parser = (
|
args = DraftWorkflowRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
|
||||||
.add_argument("files", type=list, required=False, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
external_trace_id = get_external_trace_id(request)
|
external_trace_id = get_external_trace_id(request)
|
||||||
if external_trace_id:
|
if external_trace_id:
|
||||||
|
|
@ -588,14 +597,7 @@ class DraftWorkflowNodeRunApi(Resource):
|
||||||
@console_ns.doc("run_draft_workflow_node")
|
@console_ns.doc("run_draft_workflow_node")
|
||||||
@console_ns.doc(description="Run draft workflow node")
|
@console_ns.doc(description="Run draft workflow node")
|
||||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[DraftWorkflowNodeRunPayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"DraftWorkflowNodeRunRequest",
|
|
||||||
{
|
|
||||||
"inputs": fields.Raw(description="Input variables"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Node run started successfully", workflow_run_node_execution_model)
|
@console_ns.response(200, "Node run started successfully", workflow_run_node_execution_model)
|
||||||
@console_ns.response(403, "Permission denied")
|
@console_ns.response(403, "Permission denied")
|
||||||
@console_ns.response(404, "Node not found")
|
@console_ns.response(404, "Node not found")
|
||||||
|
|
@ -610,15 +612,10 @@ class DraftWorkflowNodeRunApi(Resource):
|
||||||
Run draft workflow node
|
Run draft workflow node
|
||||||
"""
|
"""
|
||||||
current_user, _ = current_account_with_tenant()
|
current_user, _ = current_account_with_tenant()
|
||||||
parser = (
|
args_model = DraftWorkflowNodeRunPayload.model_validate(console_ns.payload or {})
|
||||||
reqparse.RequestParser()
|
args = args_model.model_dump(exclude_none=True)
|
||||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
|
||||||
.add_argument("query", type=str, required=False, location="json", default="")
|
|
||||||
.add_argument("files", type=list, location="json", default=[])
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
user_inputs = args.get("inputs")
|
user_inputs = args_model.inputs
|
||||||
if user_inputs is None:
|
if user_inputs is None:
|
||||||
raise ValueError("missing inputs")
|
raise ValueError("missing inputs")
|
||||||
|
|
||||||
|
|
@ -643,13 +640,6 @@ class DraftWorkflowNodeRunApi(Resource):
|
||||||
return workflow_node_execution
|
return workflow_node_execution
|
||||||
|
|
||||||
|
|
||||||
parser_publish = (
|
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("marked_name", type=str, required=False, default="", location="json")
|
|
||||||
.add_argument("marked_comment", type=str, required=False, default="", location="json")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:app_id>/workflows/publish")
|
@console_ns.route("/apps/<uuid:app_id>/workflows/publish")
|
||||||
class PublishedWorkflowApi(Resource):
|
class PublishedWorkflowApi(Resource):
|
||||||
@console_ns.doc("get_published_workflow")
|
@console_ns.doc("get_published_workflow")
|
||||||
|
|
@ -674,7 +664,7 @@ class PublishedWorkflowApi(Resource):
|
||||||
# return workflow, if not found, return None
|
# return workflow, if not found, return None
|
||||||
return workflow
|
return workflow
|
||||||
|
|
||||||
@console_ns.expect(parser_publish)
|
@console_ns.expect(console_ns.models[PublishWorkflowPayload.__name__])
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
|
@ -686,13 +676,7 @@ class PublishedWorkflowApi(Resource):
|
||||||
"""
|
"""
|
||||||
current_user, _ = current_account_with_tenant()
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
args = parser_publish.parse_args()
|
args = PublishWorkflowPayload.model_validate(console_ns.payload or {})
|
||||||
|
|
||||||
# Validate name and comment length
|
|
||||||
if args.marked_name and len(args.marked_name) > 20:
|
|
||||||
raise ValueError("Marked name cannot exceed 20 characters")
|
|
||||||
if args.marked_comment and len(args.marked_comment) > 100:
|
|
||||||
raise ValueError("Marked comment cannot exceed 100 characters")
|
|
||||||
|
|
||||||
workflow_service = WorkflowService()
|
workflow_service = WorkflowService()
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
|
|
@ -741,9 +725,6 @@ class DefaultBlockConfigsApi(Resource):
|
||||||
return workflow_service.get_default_block_configs()
|
return workflow_service.get_default_block_configs()
|
||||||
|
|
||||||
|
|
||||||
parser_block = reqparse.RequestParser().add_argument("q", type=str, location="args")
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>")
|
@console_ns.route("/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>")
|
||||||
class DefaultBlockConfigApi(Resource):
|
class DefaultBlockConfigApi(Resource):
|
||||||
@console_ns.doc("get_default_block_config")
|
@console_ns.doc("get_default_block_config")
|
||||||
|
|
@ -751,7 +732,7 @@ class DefaultBlockConfigApi(Resource):
|
||||||
@console_ns.doc(params={"app_id": "Application ID", "block_type": "Block type"})
|
@console_ns.doc(params={"app_id": "Application ID", "block_type": "Block type"})
|
||||||
@console_ns.response(200, "Default block configuration retrieved successfully")
|
@console_ns.response(200, "Default block configuration retrieved successfully")
|
||||||
@console_ns.response(404, "Block type not found")
|
@console_ns.response(404, "Block type not found")
|
||||||
@console_ns.expect(parser_block)
|
@console_ns.expect(console_ns.models[DefaultBlockConfigQuery.__name__])
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
|
@ -761,14 +742,12 @@ class DefaultBlockConfigApi(Resource):
|
||||||
"""
|
"""
|
||||||
Get default block config
|
Get default block config
|
||||||
"""
|
"""
|
||||||
args = parser_block.parse_args()
|
args = DefaultBlockConfigQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
|
|
||||||
q = args.get("q")
|
|
||||||
|
|
||||||
filters = None
|
filters = None
|
||||||
if q:
|
if args.q:
|
||||||
try:
|
try:
|
||||||
filters = json.loads(args.get("q", ""))
|
filters = json.loads(args.q)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
raise ValueError("Invalid filters")
|
raise ValueError("Invalid filters")
|
||||||
|
|
||||||
|
|
@ -777,18 +756,9 @@ class DefaultBlockConfigApi(Resource):
|
||||||
return workflow_service.get_default_block_config(node_type=block_type, filters=filters)
|
return workflow_service.get_default_block_config(node_type=block_type, filters=filters)
|
||||||
|
|
||||||
|
|
||||||
parser_convert = (
|
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("name", type=str, required=False, nullable=True, location="json")
|
|
||||||
.add_argument("icon_type", type=str, required=False, nullable=True, location="json")
|
|
||||||
.add_argument("icon", type=str, required=False, nullable=True, location="json")
|
|
||||||
.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:app_id>/convert-to-workflow")
|
@console_ns.route("/apps/<uuid:app_id>/convert-to-workflow")
|
||||||
class ConvertToWorkflowApi(Resource):
|
class ConvertToWorkflowApi(Resource):
|
||||||
@console_ns.expect(parser_convert)
|
@console_ns.expect(console_ns.models[ConvertToWorkflowPayload.__name__])
|
||||||
@console_ns.doc("convert_to_workflow")
|
@console_ns.doc("convert_to_workflow")
|
||||||
@console_ns.doc(description="Convert application to workflow mode")
|
@console_ns.doc(description="Convert application to workflow mode")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
|
|
@ -808,10 +778,8 @@ class ConvertToWorkflowApi(Resource):
|
||||||
"""
|
"""
|
||||||
current_user, _ = current_account_with_tenant()
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
if request.data:
|
payload = console_ns.payload or {}
|
||||||
args = parser_convert.parse_args()
|
args = ConvertToWorkflowPayload.model_validate(payload).model_dump(exclude_none=True)
|
||||||
else:
|
|
||||||
args = {}
|
|
||||||
|
|
||||||
# convert to workflow mode
|
# convert to workflow mode
|
||||||
workflow_service = WorkflowService()
|
workflow_service = WorkflowService()
|
||||||
|
|
@ -823,18 +791,9 @@ class ConvertToWorkflowApi(Resource):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
parser_workflows = (
|
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
|
|
||||||
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=10, location="args")
|
|
||||||
.add_argument("user_id", type=str, required=False, location="args")
|
|
||||||
.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:app_id>/workflows")
|
@console_ns.route("/apps/<uuid:app_id>/workflows")
|
||||||
class PublishedAllWorkflowApi(Resource):
|
class PublishedAllWorkflowApi(Resource):
|
||||||
@console_ns.expect(parser_workflows)
|
@console_ns.expect(console_ns.models[WorkflowListQuery.__name__])
|
||||||
@console_ns.doc("get_all_published_workflows")
|
@console_ns.doc("get_all_published_workflows")
|
||||||
@console_ns.doc(description="Get all published workflows for an application")
|
@console_ns.doc(description="Get all published workflows for an application")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
|
|
@ -851,16 +810,15 @@ class PublishedAllWorkflowApi(Resource):
|
||||||
"""
|
"""
|
||||||
current_user, _ = current_account_with_tenant()
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
args = parser_workflows.parse_args()
|
args = WorkflowListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
page = args["page"]
|
page = args.page
|
||||||
limit = args["limit"]
|
limit = args.limit
|
||||||
user_id = args.get("user_id")
|
user_id = args.user_id
|
||||||
named_only = args.get("named_only", False)
|
named_only = args.named_only
|
||||||
|
|
||||||
if user_id:
|
if user_id:
|
||||||
if user_id != current_user.id:
|
if user_id != current_user.id:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
user_id = cast(str, user_id)
|
|
||||||
|
|
||||||
workflow_service = WorkflowService()
|
workflow_service = WorkflowService()
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
|
|
@ -886,15 +844,7 @@ class WorkflowByIdApi(Resource):
|
||||||
@console_ns.doc("update_workflow_by_id")
|
@console_ns.doc("update_workflow_by_id")
|
||||||
@console_ns.doc(description="Update workflow by ID")
|
@console_ns.doc(description="Update workflow by ID")
|
||||||
@console_ns.doc(params={"app_id": "Application ID", "workflow_id": "Workflow ID"})
|
@console_ns.doc(params={"app_id": "Application ID", "workflow_id": "Workflow ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[WorkflowUpdatePayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"UpdateWorkflowRequest",
|
|
||||||
{
|
|
||||||
"environment_variables": fields.List(fields.Raw, description="Environment variables"),
|
|
||||||
"conversation_variables": fields.List(fields.Raw, description="Conversation variables"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Workflow updated successfully", workflow_model)
|
@console_ns.response(200, "Workflow updated successfully", workflow_model)
|
||||||
@console_ns.response(404, "Workflow not found")
|
@console_ns.response(404, "Workflow not found")
|
||||||
@console_ns.response(403, "Permission denied")
|
@console_ns.response(403, "Permission denied")
|
||||||
|
|
@ -909,25 +859,14 @@ class WorkflowByIdApi(Resource):
|
||||||
Update workflow attributes
|
Update workflow attributes
|
||||||
"""
|
"""
|
||||||
current_user, _ = current_account_with_tenant()
|
current_user, _ = current_account_with_tenant()
|
||||||
parser = (
|
args = WorkflowUpdatePayload.model_validate(console_ns.payload or {})
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("marked_name", type=str, required=False, location="json")
|
|
||||||
.add_argument("marked_comment", type=str, required=False, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# Validate name and comment length
|
|
||||||
if args.marked_name and len(args.marked_name) > 20:
|
|
||||||
raise ValueError("Marked name cannot exceed 20 characters")
|
|
||||||
if args.marked_comment and len(args.marked_comment) > 100:
|
|
||||||
raise ValueError("Marked comment cannot exceed 100 characters")
|
|
||||||
|
|
||||||
# Prepare update data
|
# Prepare update data
|
||||||
update_data = {}
|
update_data = {}
|
||||||
if args.get("marked_name") is not None:
|
if args.marked_name is not None:
|
||||||
update_data["marked_name"] = args["marked_name"]
|
update_data["marked_name"] = args.marked_name
|
||||||
if args.get("marked_comment") is not None:
|
if args.marked_comment is not None:
|
||||||
update_data["marked_comment"] = args["marked_comment"]
|
update_data["marked_comment"] = args.marked_comment
|
||||||
|
|
||||||
if not update_data:
|
if not update_data:
|
||||||
return {"message": "No valid fields to update"}, 400
|
return {"message": "No valid fields to update"}, 400
|
||||||
|
|
@ -1040,11 +979,8 @@ class DraftWorkflowTriggerRunApi(Resource):
|
||||||
Poll for trigger events and execute full workflow when event arrives
|
Poll for trigger events and execute full workflow when event arrives
|
||||||
"""
|
"""
|
||||||
current_user, _ = current_account_with_tenant()
|
current_user, _ = current_account_with_tenant()
|
||||||
parser = reqparse.RequestParser().add_argument(
|
args = DraftWorkflowTriggerRunPayload.model_validate(console_ns.payload or {})
|
||||||
"node_id", type=str, required=True, location="json", nullable=False
|
node_id = args.node_id
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
node_id = args["node_id"]
|
|
||||||
workflow_service = WorkflowService()
|
workflow_service = WorkflowService()
|
||||||
draft_workflow = workflow_service.get_draft_workflow(app_model)
|
draft_workflow = workflow_service.get_draft_workflow(app_model)
|
||||||
if not draft_workflow:
|
if not draft_workflow:
|
||||||
|
|
@ -1172,14 +1108,7 @@ class DraftWorkflowTriggerRunAllApi(Resource):
|
||||||
@console_ns.doc("draft_workflow_trigger_run_all")
|
@console_ns.doc("draft_workflow_trigger_run_all")
|
||||||
@console_ns.doc(description="Full workflow debug when the start node is a trigger")
|
@console_ns.doc(description="Full workflow debug when the start node is a trigger")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[DraftWorkflowTriggerRunAllPayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"DraftWorkflowTriggerRunAllRequest",
|
|
||||||
{
|
|
||||||
"node_ids": fields.List(fields.String, required=True, description="Node IDs"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Workflow executed successfully")
|
@console_ns.response(200, "Workflow executed successfully")
|
||||||
@console_ns.response(403, "Permission denied")
|
@console_ns.response(403, "Permission denied")
|
||||||
@console_ns.response(500, "Internal server error")
|
@console_ns.response(500, "Internal server error")
|
||||||
|
|
@ -1194,11 +1123,8 @@ class DraftWorkflowTriggerRunAllApi(Resource):
|
||||||
"""
|
"""
|
||||||
current_user, _ = current_account_with_tenant()
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser().add_argument(
|
args = DraftWorkflowTriggerRunAllPayload.model_validate(console_ns.payload or {})
|
||||||
"node_ids", type=list, required=True, location="json", nullable=False
|
node_ids = args.node_ids
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
node_ids = args["node_ids"]
|
|
||||||
workflow_service = WorkflowService()
|
workflow_service = WorkflowService()
|
||||||
draft_workflow = workflow_service.get_draft_workflow(app_model)
|
draft_workflow = workflow_service.get_draft_workflow(app_model)
|
||||||
if not draft_workflow:
|
if not draft_workflow:
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,9 @@
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
from dateutil.parser import isoparse
|
from dateutil.parser import isoparse
|
||||||
from flask_restx import Resource, marshal_with, reqparse
|
from flask import request
|
||||||
from flask_restx.inputs import int_range
|
from flask_restx import Resource, marshal_with
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
|
|
@ -14,6 +17,48 @@ from models import App
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
from services.workflow_app_service import WorkflowAppService
|
from services.workflow_app_service import WorkflowAppService
|
||||||
|
|
||||||
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowAppLogQuery(BaseModel):
|
||||||
|
keyword: str | None = Field(default=None, description="Search keyword for filtering logs")
|
||||||
|
status: WorkflowExecutionStatus | None = Field(
|
||||||
|
default=None, description="Execution status filter (succeeded, failed, stopped, partial-succeeded)"
|
||||||
|
)
|
||||||
|
created_at__before: datetime | None = Field(default=None, description="Filter logs created before this timestamp")
|
||||||
|
created_at__after: datetime | None = Field(default=None, description="Filter logs created after this timestamp")
|
||||||
|
created_by_end_user_session_id: str | None = Field(default=None, description="Filter by end user session ID")
|
||||||
|
created_by_account: str | None = Field(default=None, description="Filter by account")
|
||||||
|
detail: bool = Field(default=False, description="Whether to return detailed logs")
|
||||||
|
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
|
||||||
|
limit: int = Field(default=20, ge=1, le=100, description="Number of items per page (1-100)")
|
||||||
|
|
||||||
|
@field_validator("created_at__before", "created_at__after", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def parse_datetime(cls, value: str | None) -> datetime | None:
|
||||||
|
if value in (None, ""):
|
||||||
|
return None
|
||||||
|
return isoparse(value) # type: ignore
|
||||||
|
|
||||||
|
@field_validator("detail", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def parse_bool(cls, value: bool | str | None) -> bool:
|
||||||
|
if isinstance(value, bool):
|
||||||
|
return value
|
||||||
|
if value is None:
|
||||||
|
return False
|
||||||
|
lowered = value.lower()
|
||||||
|
if lowered in {"1", "true", "yes", "on"}:
|
||||||
|
return True
|
||||||
|
if lowered in {"0", "false", "no", "off"}:
|
||||||
|
return False
|
||||||
|
raise ValueError("Invalid boolean value for detail")
|
||||||
|
|
||||||
|
|
||||||
|
console_ns.schema_model(
|
||||||
|
WorkflowAppLogQuery.__name__, WorkflowAppLogQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||||
|
)
|
||||||
|
|
||||||
# Register model for flask_restx to avoid dict type issues in Swagger
|
# Register model for flask_restx to avoid dict type issues in Swagger
|
||||||
workflow_app_log_pagination_model = build_workflow_app_log_pagination_model(console_ns)
|
workflow_app_log_pagination_model = build_workflow_app_log_pagination_model(console_ns)
|
||||||
|
|
||||||
|
|
@ -23,19 +68,7 @@ class WorkflowAppLogApi(Resource):
|
||||||
@console_ns.doc("get_workflow_app_logs")
|
@console_ns.doc("get_workflow_app_logs")
|
||||||
@console_ns.doc(description="Get workflow application execution logs")
|
@console_ns.doc(description="Get workflow application execution logs")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.doc(
|
@console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__])
|
||||||
params={
|
|
||||||
"keyword": "Search keyword for filtering logs",
|
|
||||||
"status": "Filter by execution status (succeeded, failed, stopped, partial-succeeded)",
|
|
||||||
"created_at__before": "Filter logs created before this timestamp",
|
|
||||||
"created_at__after": "Filter logs created after this timestamp",
|
|
||||||
"created_by_end_user_session_id": "Filter by end user session ID",
|
|
||||||
"created_by_account": "Filter by account",
|
|
||||||
"detail": "Whether to return detailed logs",
|
|
||||||
"page": "Page number (1-99999)",
|
|
||||||
"limit": "Number of items per page (1-100)",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Workflow app logs retrieved successfully", workflow_app_log_pagination_model)
|
@console_ns.response(200, "Workflow app logs retrieved successfully", workflow_app_log_pagination_model)
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
|
@ -46,44 +79,7 @@ class WorkflowAppLogApi(Resource):
|
||||||
"""
|
"""
|
||||||
Get workflow app logs
|
Get workflow app logs
|
||||||
"""
|
"""
|
||||||
parser = (
|
args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("keyword", type=str, location="args")
|
|
||||||
.add_argument(
|
|
||||||
"status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args"
|
|
||||||
)
|
|
||||||
.add_argument(
|
|
||||||
"created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
|
|
||||||
)
|
|
||||||
.add_argument(
|
|
||||||
"created_at__after", type=str, location="args", help="Filter logs created after this timestamp"
|
|
||||||
)
|
|
||||||
.add_argument(
|
|
||||||
"created_by_end_user_session_id",
|
|
||||||
type=str,
|
|
||||||
location="args",
|
|
||||||
required=False,
|
|
||||||
default=None,
|
|
||||||
)
|
|
||||||
.add_argument(
|
|
||||||
"created_by_account",
|
|
||||||
type=str,
|
|
||||||
location="args",
|
|
||||||
required=False,
|
|
||||||
default=None,
|
|
||||||
)
|
|
||||||
.add_argument("detail", type=bool, location="args", required=False, default=False)
|
|
||||||
.add_argument("page", type=int_range(1, 99999), default=1, location="args")
|
|
||||||
.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
args.status = WorkflowExecutionStatus(args.status) if args.status else None
|
|
||||||
if args.created_at__before:
|
|
||||||
args.created_at__before = isoparse(args.created_at__before)
|
|
||||||
|
|
||||||
if args.created_at__after:
|
|
||||||
args.created_at__after = isoparse(args.created_at__after)
|
|
||||||
|
|
||||||
# get paginate workflow app logs
|
# get paginate workflow app logs
|
||||||
workflow_app_service = WorkflowAppService()
|
workflow_app_service = WorkflowAppService()
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,11 @@
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import NoReturn, ParamSpec, TypeVar
|
from typing import Any, NoReturn, ParamSpec, TypeVar
|
||||||
|
|
||||||
from flask import Response
|
from flask import Response, request
|
||||||
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
from flask_restx import Resource, fields, marshal, marshal_with
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
|
|
@ -29,6 +30,27 @@ from services.workflow_draft_variable_service import WorkflowDraftVariableList,
|
||||||
from services.workflow_service import WorkflowService
|
from services.workflow_service import WorkflowService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowDraftVariableListQuery(BaseModel):
|
||||||
|
page: int = Field(default=1, ge=1, le=100_000, description="Page number")
|
||||||
|
limit: int = Field(default=20, ge=1, le=100, description="Items per page")
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowDraftVariableUpdatePayload(BaseModel):
|
||||||
|
name: str | None = Field(default=None, description="Variable name")
|
||||||
|
value: Any | None = Field(default=None, description="Variable value")
|
||||||
|
|
||||||
|
|
||||||
|
console_ns.schema_model(
|
||||||
|
WorkflowDraftVariableListQuery.__name__,
|
||||||
|
WorkflowDraftVariableListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||||
|
)
|
||||||
|
console_ns.schema_model(
|
||||||
|
WorkflowDraftVariableUpdatePayload.__name__,
|
||||||
|
WorkflowDraftVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _convert_values_to_json_serializable_object(value: Segment):
|
def _convert_values_to_json_serializable_object(value: Segment):
|
||||||
|
|
@ -57,22 +79,6 @@ def _serialize_var_value(variable: WorkflowDraftVariable):
|
||||||
return _convert_values_to_json_serializable_object(value)
|
return _convert_values_to_json_serializable_object(value)
|
||||||
|
|
||||||
|
|
||||||
def _create_pagination_parser():
|
|
||||||
parser = (
|
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument(
|
|
||||||
"page",
|
|
||||||
type=inputs.int_range(1, 100_000),
|
|
||||||
required=False,
|
|
||||||
default=1,
|
|
||||||
location="args",
|
|
||||||
help="the page of data requested",
|
|
||||||
)
|
|
||||||
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
|
||||||
)
|
|
||||||
return parser
|
|
||||||
|
|
||||||
|
|
||||||
def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
|
def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
|
||||||
value_type = workflow_draft_var.value_type
|
value_type = workflow_draft_var.value_type
|
||||||
return value_type.exposed_type().value
|
return value_type.exposed_type().value
|
||||||
|
|
@ -201,7 +207,7 @@ def _api_prerequisite(f: Callable[P, R]):
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/variables")
|
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/variables")
|
||||||
class WorkflowVariableCollectionApi(Resource):
|
class WorkflowVariableCollectionApi(Resource):
|
||||||
@console_ns.expect(_create_pagination_parser())
|
@console_ns.expect(console_ns.models[WorkflowDraftVariableListQuery.__name__])
|
||||||
@console_ns.doc("get_workflow_variables")
|
@console_ns.doc("get_workflow_variables")
|
||||||
@console_ns.doc(description="Get draft workflow variables")
|
@console_ns.doc(description="Get draft workflow variables")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
|
|
@ -215,8 +221,7 @@ class WorkflowVariableCollectionApi(Resource):
|
||||||
"""
|
"""
|
||||||
Get draft workflow
|
Get draft workflow
|
||||||
"""
|
"""
|
||||||
parser = _create_pagination_parser()
|
args = WorkflowDraftVariableListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# fetch draft workflow by app_model
|
# fetch draft workflow by app_model
|
||||||
workflow_service = WorkflowService()
|
workflow_service = WorkflowService()
|
||||||
|
|
@ -323,15 +328,7 @@ class VariableApi(Resource):
|
||||||
|
|
||||||
@console_ns.doc("update_variable")
|
@console_ns.doc("update_variable")
|
||||||
@console_ns.doc(description="Update a workflow variable")
|
@console_ns.doc(description="Update a workflow variable")
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[WorkflowDraftVariableUpdatePayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"UpdateVariableRequest",
|
|
||||||
{
|
|
||||||
"name": fields.String(description="Variable name"),
|
|
||||||
"value": fields.Raw(description="Variable value"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Variable updated successfully", workflow_draft_variable_model)
|
@console_ns.response(200, "Variable updated successfully", workflow_draft_variable_model)
|
||||||
@console_ns.response(404, "Variable not found")
|
@console_ns.response(404, "Variable not found")
|
||||||
@_api_prerequisite
|
@_api_prerequisite
|
||||||
|
|
@ -358,16 +355,10 @@ class VariableApi(Resource):
|
||||||
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
|
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
|
||||||
# }
|
# }
|
||||||
|
|
||||||
parser = (
|
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
|
|
||||||
.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
|
|
||||||
)
|
|
||||||
|
|
||||||
draft_var_srv = WorkflowDraftVariableService(
|
draft_var_srv = WorkflowDraftVariableService(
|
||||||
session=db.session(),
|
session=db.session(),
|
||||||
)
|
)
|
||||||
args = parser.parse_args(strict=True)
|
args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {})
|
||||||
|
|
||||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||||
if variable is None:
|
if variable is None:
|
||||||
|
|
@ -375,8 +366,8 @@ class VariableApi(Resource):
|
||||||
if variable.app_id != app_model.id:
|
if variable.app_id != app_model.id:
|
||||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||||
|
|
||||||
new_name = args.get(self._PATCH_NAME_FIELD, None)
|
new_name = args_model.name
|
||||||
raw_value = args.get(self._PATCH_VALUE_FIELD, None)
|
raw_value = args_model.value
|
||||||
if new_name is None and raw_value is None:
|
if new_name is None and raw_value is None:
|
||||||
return variable
|
return variable
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,8 @@
|
||||||
from typing import cast
|
from typing import Literal, cast
|
||||||
|
|
||||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
from flask import request
|
||||||
from flask_restx.inputs import int_range
|
from flask_restx import Resource, fields, marshal_with
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
from controllers.console.app.wraps import get_app_model
|
from controllers.console.app.wraps import get_app_model
|
||||||
|
|
@ -92,70 +93,51 @@ workflow_run_node_execution_list_model = console_ns.model(
|
||||||
"WorkflowRunNodeExecutionList", workflow_run_node_execution_list_fields_copy
|
"WorkflowRunNodeExecutionList", workflow_run_node_execution_list_fields_copy
|
||||||
)
|
)
|
||||||
|
|
||||||
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
def _parse_workflow_run_list_args():
|
|
||||||
"""
|
|
||||||
Parse common arguments for workflow run list endpoints.
|
|
||||||
|
|
||||||
Returns:
|
class WorkflowRunListQuery(BaseModel):
|
||||||
Parsed arguments containing last_id, limit, status, and triggered_from filters
|
last_id: str | None = Field(default=None, description="Last run ID for pagination")
|
||||||
"""
|
limit: int = Field(default=20, ge=1, le=100, description="Number of items per page (1-100)")
|
||||||
parser = (
|
status: Literal["running", "succeeded", "failed", "stopped", "partial-succeeded"] | None = Field(
|
||||||
reqparse.RequestParser()
|
default=None, description="Workflow run status filter"
|
||||||
.add_argument("last_id", type=uuid_value, location="args")
|
|
||||||
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
|
||||||
.add_argument(
|
|
||||||
"status",
|
|
||||||
type=str,
|
|
||||||
choices=WORKFLOW_RUN_STATUS_CHOICES,
|
|
||||||
location="args",
|
|
||||||
required=False,
|
|
||||||
)
|
|
||||||
.add_argument(
|
|
||||||
"triggered_from",
|
|
||||||
type=str,
|
|
||||||
choices=["debugging", "app-run"],
|
|
||||||
location="args",
|
|
||||||
required=False,
|
|
||||||
help="Filter by trigger source: debugging or app-run",
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
return parser.parse_args()
|
triggered_from: Literal["debugging", "app-run"] | None = Field(
|
||||||
|
default=None, description="Filter by trigger source: debugging or app-run"
|
||||||
|
|
||||||
def _parse_workflow_run_count_args():
|
|
||||||
"""
|
|
||||||
Parse common arguments for workflow run count endpoints.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Parsed arguments containing status, time_range, and triggered_from filters
|
|
||||||
"""
|
|
||||||
parser = (
|
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument(
|
|
||||||
"status",
|
|
||||||
type=str,
|
|
||||||
choices=WORKFLOW_RUN_STATUS_CHOICES,
|
|
||||||
location="args",
|
|
||||||
required=False,
|
|
||||||
)
|
|
||||||
.add_argument(
|
|
||||||
"time_range",
|
|
||||||
type=time_duration,
|
|
||||||
location="args",
|
|
||||||
required=False,
|
|
||||||
help="Time range filter (e.g., 7d, 4h, 30m, 30s)",
|
|
||||||
)
|
|
||||||
.add_argument(
|
|
||||||
"triggered_from",
|
|
||||||
type=str,
|
|
||||||
choices=["debugging", "app-run"],
|
|
||||||
location="args",
|
|
||||||
required=False,
|
|
||||||
help="Filter by trigger source: debugging or app-run",
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
return parser.parse_args()
|
|
||||||
|
@field_validator("last_id")
|
||||||
|
@classmethod
|
||||||
|
def validate_last_id(cls, value: str | None) -> str | None:
|
||||||
|
if value is None:
|
||||||
|
return value
|
||||||
|
return uuid_value(value)
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowRunCountQuery(BaseModel):
|
||||||
|
status: Literal["running", "succeeded", "failed", "stopped", "partial-succeeded"] | None = Field(
|
||||||
|
default=None, description="Workflow run status filter"
|
||||||
|
)
|
||||||
|
time_range: str | None = Field(default=None, description="Time range filter (e.g., 7d, 4h, 30m, 30s)")
|
||||||
|
triggered_from: Literal["debugging", "app-run"] | None = Field(
|
||||||
|
default=None, description="Filter by trigger source: debugging or app-run"
|
||||||
|
)
|
||||||
|
|
||||||
|
@field_validator("time_range")
|
||||||
|
@classmethod
|
||||||
|
def validate_time_range(cls, value: str | None) -> str | None:
|
||||||
|
if value is None:
|
||||||
|
return value
|
||||||
|
return time_duration(value)
|
||||||
|
|
||||||
|
|
||||||
|
console_ns.schema_model(
|
||||||
|
WorkflowRunListQuery.__name__, WorkflowRunListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||||
|
)
|
||||||
|
console_ns.schema_model(
|
||||||
|
WorkflowRunCountQuery.__name__,
|
||||||
|
WorkflowRunCountQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs")
|
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs")
|
||||||
|
|
@ -170,6 +152,7 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
|
||||||
@console_ns.doc(
|
@console_ns.doc(
|
||||||
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
|
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
|
||||||
)
|
)
|
||||||
|
@console_ns.expect(console_ns.models[WorkflowRunListQuery.__name__])
|
||||||
@console_ns.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_model)
|
@console_ns.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_model)
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
|
@ -180,12 +163,13 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
|
||||||
"""
|
"""
|
||||||
Get advanced chat app workflow run list
|
Get advanced chat app workflow run list
|
||||||
"""
|
"""
|
||||||
args = _parse_workflow_run_list_args()
|
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
|
args = args_model.model_dump(exclude_none=True)
|
||||||
|
|
||||||
# Default to DEBUGGING if not specified
|
# Default to DEBUGGING if not specified
|
||||||
triggered_from = (
|
triggered_from = (
|
||||||
WorkflowRunTriggeredFrom(args.get("triggered_from"))
|
WorkflowRunTriggeredFrom(args_model.triggered_from)
|
||||||
if args.get("triggered_from")
|
if args_model.triggered_from
|
||||||
else WorkflowRunTriggeredFrom.DEBUGGING
|
else WorkflowRunTriggeredFrom.DEBUGGING
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -217,6 +201,7 @@ class AdvancedChatAppWorkflowRunCountApi(Resource):
|
||||||
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
|
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
|
||||||
)
|
)
|
||||||
@console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_model)
|
@console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_model)
|
||||||
|
@console_ns.expect(console_ns.models[WorkflowRunCountQuery.__name__])
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
|
@ -226,12 +211,13 @@ class AdvancedChatAppWorkflowRunCountApi(Resource):
|
||||||
"""
|
"""
|
||||||
Get advanced chat workflow runs count statistics
|
Get advanced chat workflow runs count statistics
|
||||||
"""
|
"""
|
||||||
args = _parse_workflow_run_count_args()
|
args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
|
args = args_model.model_dump(exclude_none=True)
|
||||||
|
|
||||||
# Default to DEBUGGING if not specified
|
# Default to DEBUGGING if not specified
|
||||||
triggered_from = (
|
triggered_from = (
|
||||||
WorkflowRunTriggeredFrom(args.get("triggered_from"))
|
WorkflowRunTriggeredFrom(args_model.triggered_from)
|
||||||
if args.get("triggered_from")
|
if args_model.triggered_from
|
||||||
else WorkflowRunTriggeredFrom.DEBUGGING
|
else WorkflowRunTriggeredFrom.DEBUGGING
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -259,6 +245,7 @@ class WorkflowRunListApi(Resource):
|
||||||
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
|
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
|
||||||
)
|
)
|
||||||
@console_ns.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_model)
|
@console_ns.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_model)
|
||||||
|
@console_ns.expect(console_ns.models[WorkflowRunListQuery.__name__])
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
|
@ -268,12 +255,13 @@ class WorkflowRunListApi(Resource):
|
||||||
"""
|
"""
|
||||||
Get workflow run list
|
Get workflow run list
|
||||||
"""
|
"""
|
||||||
args = _parse_workflow_run_list_args()
|
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
|
args = args_model.model_dump(exclude_none=True)
|
||||||
|
|
||||||
# Default to DEBUGGING for workflow if not specified (backward compatibility)
|
# Default to DEBUGGING for workflow if not specified (backward compatibility)
|
||||||
triggered_from = (
|
triggered_from = (
|
||||||
WorkflowRunTriggeredFrom(args.get("triggered_from"))
|
WorkflowRunTriggeredFrom(args_model.triggered_from)
|
||||||
if args.get("triggered_from")
|
if args_model.triggered_from
|
||||||
else WorkflowRunTriggeredFrom.DEBUGGING
|
else WorkflowRunTriggeredFrom.DEBUGGING
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -305,6 +293,7 @@ class WorkflowRunCountApi(Resource):
|
||||||
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
|
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
|
||||||
)
|
)
|
||||||
@console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_model)
|
@console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_model)
|
||||||
|
@console_ns.expect(console_ns.models[WorkflowRunCountQuery.__name__])
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
|
@ -314,12 +303,13 @@ class WorkflowRunCountApi(Resource):
|
||||||
"""
|
"""
|
||||||
Get workflow runs count statistics
|
Get workflow runs count statistics
|
||||||
"""
|
"""
|
||||||
args = _parse_workflow_run_count_args()
|
args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
|
args = args_model.model_dump(exclude_none=True)
|
||||||
|
|
||||||
# Default to DEBUGGING for workflow if not specified (backward compatibility)
|
# Default to DEBUGGING for workflow if not specified (backward compatibility)
|
||||||
triggered_from = (
|
triggered_from = (
|
||||||
WorkflowRunTriggeredFrom(args.get("triggered_from"))
|
WorkflowRunTriggeredFrom(args_model.triggered_from)
|
||||||
if args.get("triggered_from")
|
if args_model.triggered_from
|
||||||
else WorkflowRunTriggeredFrom.DEBUGGING
|
else WorkflowRunTriggeredFrom.DEBUGGING
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
from flask import abort, jsonify
|
from flask import abort, jsonify, request
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
|
|
@ -7,12 +8,31 @@ from controllers.console.app.wraps import get_app_model
|
||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.datetime_utils import parse_time_range
|
from libs.datetime_utils import parse_time_range
|
||||||
from libs.helper import DatetimeString
|
|
||||||
from libs.login import current_account_with_tenant, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.enums import WorkflowRunTriggeredFrom
|
from models.enums import WorkflowRunTriggeredFrom
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
from repositories.factory import DifyAPIRepositoryFactory
|
from repositories.factory import DifyAPIRepositoryFactory
|
||||||
|
|
||||||
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowStatisticQuery(BaseModel):
|
||||||
|
start: str | None = Field(default=None, description="Start date and time (YYYY-MM-DD HH:MM)")
|
||||||
|
end: str | None = Field(default=None, description="End date and time (YYYY-MM-DD HH:MM)")
|
||||||
|
|
||||||
|
@field_validator("start", "end", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def blank_to_none(cls, value: str | None) -> str | None:
|
||||||
|
if value == "":
|
||||||
|
return None
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
console_ns.schema_model(
|
||||||
|
WorkflowStatisticQuery.__name__,
|
||||||
|
WorkflowStatisticQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/daily-conversations")
|
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/daily-conversations")
|
||||||
class WorkflowDailyRunsStatistic(Resource):
|
class WorkflowDailyRunsStatistic(Resource):
|
||||||
|
|
@ -24,9 +44,7 @@ class WorkflowDailyRunsStatistic(Resource):
|
||||||
@console_ns.doc("get_workflow_daily_runs_statistic")
|
@console_ns.doc("get_workflow_daily_runs_statistic")
|
||||||
@console_ns.doc(description="Get workflow daily runs statistics")
|
@console_ns.doc(description="Get workflow daily runs statistics")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.doc(
|
@console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
|
||||||
params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Daily runs statistics retrieved successfully")
|
@console_ns.response(200, "Daily runs statistics retrieved successfully")
|
||||||
@get_app_model
|
@get_app_model
|
||||||
@setup_required
|
@setup_required
|
||||||
|
|
@ -35,17 +53,12 @@ class WorkflowDailyRunsStatistic(Resource):
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account, _ = current_account_with_tenant()
|
account, _ = current_account_with_tenant()
|
||||||
|
|
||||||
parser = (
|
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
|
||||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
assert account.timezone is not None
|
assert account.timezone is not None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
|
start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
abort(400, description=str(e))
|
abort(400, description=str(e))
|
||||||
|
|
||||||
|
|
@ -71,9 +84,7 @@ class WorkflowDailyTerminalsStatistic(Resource):
|
||||||
@console_ns.doc("get_workflow_daily_terminals_statistic")
|
@console_ns.doc("get_workflow_daily_terminals_statistic")
|
||||||
@console_ns.doc(description="Get workflow daily terminals statistics")
|
@console_ns.doc(description="Get workflow daily terminals statistics")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.doc(
|
@console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
|
||||||
params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Daily terminals statistics retrieved successfully")
|
@console_ns.response(200, "Daily terminals statistics retrieved successfully")
|
||||||
@get_app_model
|
@get_app_model
|
||||||
@setup_required
|
@setup_required
|
||||||
|
|
@ -82,17 +93,12 @@ class WorkflowDailyTerminalsStatistic(Resource):
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account, _ = current_account_with_tenant()
|
account, _ = current_account_with_tenant()
|
||||||
|
|
||||||
parser = (
|
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
|
||||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
assert account.timezone is not None
|
assert account.timezone is not None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
|
start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
abort(400, description=str(e))
|
abort(400, description=str(e))
|
||||||
|
|
||||||
|
|
@ -118,9 +124,7 @@ class WorkflowDailyTokenCostStatistic(Resource):
|
||||||
@console_ns.doc("get_workflow_daily_token_cost_statistic")
|
@console_ns.doc("get_workflow_daily_token_cost_statistic")
|
||||||
@console_ns.doc(description="Get workflow daily token cost statistics")
|
@console_ns.doc(description="Get workflow daily token cost statistics")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.doc(
|
@console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
|
||||||
params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Daily token cost statistics retrieved successfully")
|
@console_ns.response(200, "Daily token cost statistics retrieved successfully")
|
||||||
@get_app_model
|
@get_app_model
|
||||||
@setup_required
|
@setup_required
|
||||||
|
|
@ -129,17 +133,12 @@ class WorkflowDailyTokenCostStatistic(Resource):
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account, _ = current_account_with_tenant()
|
account, _ = current_account_with_tenant()
|
||||||
|
|
||||||
parser = (
|
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
|
||||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
assert account.timezone is not None
|
assert account.timezone is not None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
|
start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
abort(400, description=str(e))
|
abort(400, description=str(e))
|
||||||
|
|
||||||
|
|
@ -165,9 +164,7 @@ class WorkflowAverageAppInteractionStatistic(Resource):
|
||||||
@console_ns.doc("get_workflow_average_app_interaction_statistic")
|
@console_ns.doc("get_workflow_average_app_interaction_statistic")
|
||||||
@console_ns.doc(description="Get workflow average app interaction statistics")
|
@console_ns.doc(description="Get workflow average app interaction statistics")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.doc(
|
@console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
|
||||||
params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Average app interaction statistics retrieved successfully")
|
@console_ns.response(200, "Average app interaction statistics retrieved successfully")
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
|
@ -176,17 +173,12 @@ class WorkflowAverageAppInteractionStatistic(Resource):
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account, _ = current_account_with_tenant()
|
account, _ = current_account_with_tenant()
|
||||||
|
|
||||||
parser = (
|
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
|
||||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
assert account.timezone is not None
|
assert account.timezone is not None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
|
start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
abort(400, description=str(e))
|
abort(400, description=str(e))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -114,7 +114,7 @@ class AppTriggersApi(Resource):
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:app_id>/trigger-enable")
|
@console_ns.route("/apps/<uuid:app_id>/trigger-enable")
|
||||||
class AppTriggerEnableApi(Resource):
|
class AppTriggerEnableApi(Resource):
|
||||||
@console_ns.expect(console_ns.models[ParserEnable.__name__], validate=True)
|
@console_ns.expect(console_ns.models[ParserEnable.__name__])
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
|
|
||||||
|
|
@ -1,28 +1,53 @@
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource, fields, reqparse
|
from flask_restx import Resource, fields
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
from constants.languages import supported_language
|
from constants.languages import supported_language
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
from controllers.console.error import AlreadyActivateError
|
from controllers.console.error import AlreadyActivateError
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from libs.helper import StrLen, email, extract_remote_ip, timezone
|
from libs.helper import EmailStr, timezone
|
||||||
from models import AccountStatus
|
from models import AccountStatus
|
||||||
from services.account_service import AccountService, RegisterService
|
from services.account_service import RegisterService
|
||||||
|
|
||||||
active_check_parser = (
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("workspace_id", type=str, required=False, nullable=True, location="args", help="Workspace ID")
|
|
||||||
.add_argument("email", type=email, required=False, nullable=True, location="args", help="Email address")
|
class ActivateCheckQuery(BaseModel):
|
||||||
.add_argument("token", type=str, required=True, nullable=False, location="args", help="Activation token")
|
workspace_id: str | None = Field(default=None)
|
||||||
)
|
email: EmailStr | None = Field(default=None)
|
||||||
|
token: str
|
||||||
|
|
||||||
|
|
||||||
|
class ActivatePayload(BaseModel):
|
||||||
|
workspace_id: str | None = Field(default=None)
|
||||||
|
email: EmailStr | None = Field(default=None)
|
||||||
|
token: str
|
||||||
|
name: str = Field(..., max_length=30)
|
||||||
|
interface_language: str = Field(...)
|
||||||
|
timezone: str = Field(...)
|
||||||
|
|
||||||
|
@field_validator("interface_language")
|
||||||
|
@classmethod
|
||||||
|
def validate_lang(cls, value: str) -> str:
|
||||||
|
return supported_language(value)
|
||||||
|
|
||||||
|
@field_validator("timezone")
|
||||||
|
@classmethod
|
||||||
|
def validate_tz(cls, value: str) -> str:
|
||||||
|
return timezone(value)
|
||||||
|
|
||||||
|
|
||||||
|
for model in (ActivateCheckQuery, ActivatePayload):
|
||||||
|
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/activate/check")
|
@console_ns.route("/activate/check")
|
||||||
class ActivateCheckApi(Resource):
|
class ActivateCheckApi(Resource):
|
||||||
@console_ns.doc("check_activation_token")
|
@console_ns.doc("check_activation_token")
|
||||||
@console_ns.doc(description="Check if activation token is valid")
|
@console_ns.doc(description="Check if activation token is valid")
|
||||||
@console_ns.expect(active_check_parser)
|
@console_ns.expect(console_ns.models[ActivateCheckQuery.__name__])
|
||||||
@console_ns.response(
|
@console_ns.response(
|
||||||
200,
|
200,
|
||||||
"Success",
|
"Success",
|
||||||
|
|
@ -35,11 +60,11 @@ class ActivateCheckApi(Resource):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
def get(self):
|
def get(self):
|
||||||
args = active_check_parser.parse_args()
|
args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
|
|
||||||
workspaceId = args["workspace_id"]
|
workspaceId = args.workspace_id
|
||||||
reg_email = args["email"]
|
reg_email = args.email
|
||||||
token = args["token"]
|
token = args.token
|
||||||
|
|
||||||
invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
|
invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
|
||||||
if invitation:
|
if invitation:
|
||||||
|
|
@ -56,22 +81,11 @@ class ActivateCheckApi(Resource):
|
||||||
return {"is_valid": False}
|
return {"is_valid": False}
|
||||||
|
|
||||||
|
|
||||||
active_parser = (
|
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
|
|
||||||
.add_argument("email", type=email, required=False, nullable=True, location="json")
|
|
||||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
|
||||||
.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
|
||||||
.add_argument("interface_language", type=supported_language, required=True, nullable=False, location="json")
|
|
||||||
.add_argument("timezone", type=timezone, required=True, nullable=False, location="json")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/activate")
|
@console_ns.route("/activate")
|
||||||
class ActivateApi(Resource):
|
class ActivateApi(Resource):
|
||||||
@console_ns.doc("activate_account")
|
@console_ns.doc("activate_account")
|
||||||
@console_ns.doc(description="Activate account with invitation token")
|
@console_ns.doc(description="Activate account with invitation token")
|
||||||
@console_ns.expect(active_parser)
|
@console_ns.expect(console_ns.models[ActivatePayload.__name__])
|
||||||
@console_ns.response(
|
@console_ns.response(
|
||||||
200,
|
200,
|
||||||
"Account activated successfully",
|
"Account activated successfully",
|
||||||
|
|
@ -79,30 +93,27 @@ class ActivateApi(Resource):
|
||||||
"ActivationResponse",
|
"ActivationResponse",
|
||||||
{
|
{
|
||||||
"result": fields.String(description="Operation result"),
|
"result": fields.String(description="Operation result"),
|
||||||
"data": fields.Raw(description="Login token data"),
|
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@console_ns.response(400, "Already activated or invalid token")
|
@console_ns.response(400, "Already activated or invalid token")
|
||||||
def post(self):
|
def post(self):
|
||||||
args = active_parser.parse_args()
|
args = ActivatePayload.model_validate(console_ns.payload)
|
||||||
|
|
||||||
invitation = RegisterService.get_invitation_if_token_valid(args["workspace_id"], args["email"], args["token"])
|
invitation = RegisterService.get_invitation_if_token_valid(args.workspace_id, args.email, args.token)
|
||||||
if invitation is None:
|
if invitation is None:
|
||||||
raise AlreadyActivateError()
|
raise AlreadyActivateError()
|
||||||
|
|
||||||
RegisterService.revoke_token(args["workspace_id"], args["email"], args["token"])
|
RegisterService.revoke_token(args.workspace_id, args.email, args.token)
|
||||||
|
|
||||||
account = invitation["account"]
|
account = invitation["account"]
|
||||||
account.name = args["name"]
|
account.name = args.name
|
||||||
|
|
||||||
account.interface_language = args["interface_language"]
|
account.interface_language = args.interface_language
|
||||||
account.timezone = args["timezone"]
|
account.timezone = args.timezone
|
||||||
account.interface_theme = "light"
|
account.interface_theme = "light"
|
||||||
account.status = AccountStatus.ACTIVE
|
account.status = AccountStatus.ACTIVE
|
||||||
account.initialized_at = naive_utc_now()
|
account.initialized_at = naive_utc_now()
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
|
return {"result": "success"}
|
||||||
|
|
||||||
return {"result": "success", "data": token_pair.model_dump()}
|
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,26 @@
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from controllers.console import console_ns
|
|
||||||
from controllers.console.auth.error import ApiKeyAuthFailedError
|
|
||||||
from controllers.console.wraps import is_admin_or_owner_required
|
|
||||||
from libs.login import current_account_with_tenant, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from services.auth.api_key_auth_service import ApiKeyAuthService
|
from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||||
|
|
||||||
from ..wraps import account_initialization_required, setup_required
|
from .. import console_ns
|
||||||
|
from ..auth.error import ApiKeyAuthFailedError
|
||||||
|
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||||
|
|
||||||
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
|
|
||||||
|
class ApiKeyAuthBindingPayload(BaseModel):
|
||||||
|
category: str = Field(...)
|
||||||
|
provider: str = Field(...)
|
||||||
|
credentials: dict = Field(...)
|
||||||
|
|
||||||
|
|
||||||
|
console_ns.schema_model(
|
||||||
|
ApiKeyAuthBindingPayload.__name__,
|
||||||
|
ApiKeyAuthBindingPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/api-key-auth/data-source")
|
@console_ns.route("/api-key-auth/data-source")
|
||||||
|
|
@ -40,19 +54,15 @@ class ApiKeyAuthDataSourceBinding(Resource):
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@is_admin_or_owner_required
|
@is_admin_or_owner_required
|
||||||
|
@console_ns.expect(console_ns.models[ApiKeyAuthBindingPayload.__name__])
|
||||||
def post(self):
|
def post(self):
|
||||||
# The role of the current user in the table must be admin or owner
|
# The role of the current user in the table must be admin or owner
|
||||||
_, current_tenant_id = current_account_with_tenant()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
parser = (
|
payload = ApiKeyAuthBindingPayload.model_validate(console_ns.payload)
|
||||||
reqparse.RequestParser()
|
data = payload.model_dump()
|
||||||
.add_argument("category", type=str, required=True, nullable=False, location="json")
|
ApiKeyAuthService.validate_api_key_auth_args(data)
|
||||||
.add_argument("provider", type=str, required=True, nullable=False, location="json")
|
|
||||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
|
||||||
try:
|
try:
|
||||||
ApiKeyAuthService.create_provider_auth(current_tenant_id, args)
|
ApiKeyAuthService.create_provider_auth(current_tenant_id, data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ApiKeyAuthFailedError(str(e))
|
raise ApiKeyAuthFailedError(str(e))
|
||||||
return {"result": "success"}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
|
||||||
|
|
@ -5,12 +5,11 @@ from flask import current_app, redirect, request
|
||||||
from flask_restx import Resource, fields
|
from flask_restx import Resource, fields
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from controllers.console import console_ns
|
|
||||||
from controllers.console.wraps import is_admin_or_owner_required
|
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from libs.oauth_data_source import NotionOAuth
|
from libs.oauth_data_source import NotionOAuth
|
||||||
|
|
||||||
from ..wraps import account_initialization_required, setup_required
|
from .. import console_ns
|
||||||
|
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
|
@ -14,16 +15,45 @@ from controllers.console.auth.error import (
|
||||||
InvalidTokenError,
|
InvalidTokenError,
|
||||||
PasswordMismatchError,
|
PasswordMismatchError,
|
||||||
)
|
)
|
||||||
from controllers.console.error import AccountInFreezeError, EmailSendIpLimitError
|
|
||||||
from controllers.console.wraps import email_password_login_enabled, email_register_enabled, setup_required
|
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.helper import email, extract_remote_ip
|
from libs.helper import EmailStr, extract_remote_ip
|
||||||
from libs.password import valid_password
|
from libs.password import valid_password
|
||||||
from models import Account
|
from models import Account
|
||||||
from services.account_service import AccountService
|
from services.account_service import AccountService
|
||||||
from services.billing_service import BillingService
|
from services.billing_service import BillingService
|
||||||
from services.errors.account import AccountNotFoundError, AccountRegisterError
|
from services.errors.account import AccountNotFoundError, AccountRegisterError
|
||||||
|
|
||||||
|
from ..error import AccountInFreezeError, EmailSendIpLimitError
|
||||||
|
from ..wraps import email_password_login_enabled, email_register_enabled, setup_required
|
||||||
|
|
||||||
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
|
|
||||||
|
class EmailRegisterSendPayload(BaseModel):
|
||||||
|
email: EmailStr = Field(..., description="Email address")
|
||||||
|
language: str | None = Field(default=None, description="Language code")
|
||||||
|
|
||||||
|
|
||||||
|
class EmailRegisterValidityPayload(BaseModel):
|
||||||
|
email: EmailStr = Field(...)
|
||||||
|
code: str = Field(...)
|
||||||
|
token: str = Field(...)
|
||||||
|
|
||||||
|
|
||||||
|
class EmailRegisterResetPayload(BaseModel):
|
||||||
|
token: str = Field(...)
|
||||||
|
new_password: str = Field(...)
|
||||||
|
password_confirm: str = Field(...)
|
||||||
|
|
||||||
|
@field_validator("new_password", "password_confirm")
|
||||||
|
@classmethod
|
||||||
|
def validate_password(cls, value: str) -> str:
|
||||||
|
return valid_password(value)
|
||||||
|
|
||||||
|
|
||||||
|
for model in (EmailRegisterSendPayload, EmailRegisterValidityPayload, EmailRegisterResetPayload):
|
||||||
|
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/email-register/send-email")
|
@console_ns.route("/email-register/send-email")
|
||||||
class EmailRegisterSendEmailApi(Resource):
|
class EmailRegisterSendEmailApi(Resource):
|
||||||
|
|
@ -31,27 +61,22 @@ class EmailRegisterSendEmailApi(Resource):
|
||||||
@email_password_login_enabled
|
@email_password_login_enabled
|
||||||
@email_register_enabled
|
@email_register_enabled
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = (
|
args = EmailRegisterSendPayload.model_validate(console_ns.payload)
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("email", type=email, required=True, location="json")
|
|
||||||
.add_argument("language", type=str, required=False, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
ip_address = extract_remote_ip(request)
|
ip_address = extract_remote_ip(request)
|
||||||
if AccountService.is_email_send_ip_limit(ip_address):
|
if AccountService.is_email_send_ip_limit(ip_address):
|
||||||
raise EmailSendIpLimitError()
|
raise EmailSendIpLimitError()
|
||||||
language = "en-US"
|
language = "en-US"
|
||||||
if args["language"] in languages:
|
if args.language in languages:
|
||||||
language = args["language"]
|
language = args.language
|
||||||
|
|
||||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]):
|
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
|
||||||
raise AccountInFreezeError()
|
raise AccountInFreezeError()
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
|
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
|
||||||
token = None
|
token = None
|
||||||
token = AccountService.send_email_register_email(email=args["email"], account=account, language=language)
|
token = AccountService.send_email_register_email(email=args.email, account=account, language=language)
|
||||||
return {"result": "success", "data": token}
|
return {"result": "success", "data": token}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -61,40 +86,34 @@ class EmailRegisterCheckApi(Resource):
|
||||||
@email_password_login_enabled
|
@email_password_login_enabled
|
||||||
@email_register_enabled
|
@email_register_enabled
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = (
|
args = EmailRegisterValidityPayload.model_validate(console_ns.payload)
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("email", type=str, required=True, location="json")
|
|
||||||
.add_argument("code", type=str, required=True, location="json")
|
|
||||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
user_email = args["email"]
|
user_email = args.email
|
||||||
|
|
||||||
is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args["email"])
|
is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args.email)
|
||||||
if is_email_register_error_rate_limit:
|
if is_email_register_error_rate_limit:
|
||||||
raise EmailRegisterLimitError()
|
raise EmailRegisterLimitError()
|
||||||
|
|
||||||
token_data = AccountService.get_email_register_data(args["token"])
|
token_data = AccountService.get_email_register_data(args.token)
|
||||||
if token_data is None:
|
if token_data is None:
|
||||||
raise InvalidTokenError()
|
raise InvalidTokenError()
|
||||||
|
|
||||||
if user_email != token_data.get("email"):
|
if user_email != token_data.get("email"):
|
||||||
raise InvalidEmailError()
|
raise InvalidEmailError()
|
||||||
|
|
||||||
if args["code"] != token_data.get("code"):
|
if args.code != token_data.get("code"):
|
||||||
AccountService.add_email_register_error_rate_limit(args["email"])
|
AccountService.add_email_register_error_rate_limit(args.email)
|
||||||
raise EmailCodeError()
|
raise EmailCodeError()
|
||||||
|
|
||||||
# Verified, revoke the first token
|
# Verified, revoke the first token
|
||||||
AccountService.revoke_email_register_token(args["token"])
|
AccountService.revoke_email_register_token(args.token)
|
||||||
|
|
||||||
# Refresh token data by generating a new token
|
# Refresh token data by generating a new token
|
||||||
_, new_token = AccountService.generate_email_register_token(
|
_, new_token = AccountService.generate_email_register_token(
|
||||||
user_email, code=args["code"], additional_data={"phase": "register"}
|
user_email, code=args.code, additional_data={"phase": "register"}
|
||||||
)
|
)
|
||||||
|
|
||||||
AccountService.reset_email_register_error_rate_limit(args["email"])
|
AccountService.reset_email_register_error_rate_limit(args.email)
|
||||||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -104,20 +123,14 @@ class EmailRegisterResetApi(Resource):
|
||||||
@email_password_login_enabled
|
@email_password_login_enabled
|
||||||
@email_register_enabled
|
@email_register_enabled
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = (
|
args = EmailRegisterResetPayload.model_validate(console_ns.payload)
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
|
||||||
.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
|
|
||||||
.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# Validate passwords match
|
# Validate passwords match
|
||||||
if args["new_password"] != args["password_confirm"]:
|
if args.new_password != args.password_confirm:
|
||||||
raise PasswordMismatchError()
|
raise PasswordMismatchError()
|
||||||
|
|
||||||
# Validate token and get register data
|
# Validate token and get register data
|
||||||
register_data = AccountService.get_email_register_data(args["token"])
|
register_data = AccountService.get_email_register_data(args.token)
|
||||||
if not register_data:
|
if not register_data:
|
||||||
raise InvalidTokenError()
|
raise InvalidTokenError()
|
||||||
# Must use token in reset phase
|
# Must use token in reset phase
|
||||||
|
|
@ -125,7 +138,7 @@ class EmailRegisterResetApi(Resource):
|
||||||
raise InvalidTokenError()
|
raise InvalidTokenError()
|
||||||
|
|
||||||
# Revoke token to prevent reuse
|
# Revoke token to prevent reuse
|
||||||
AccountService.revoke_email_register_token(args["token"])
|
AccountService.revoke_email_register_token(args.token)
|
||||||
|
|
||||||
email = register_data.get("email", "")
|
email = register_data.get("email", "")
|
||||||
|
|
||||||
|
|
@ -135,7 +148,7 @@ class EmailRegisterResetApi(Resource):
|
||||||
if account:
|
if account:
|
||||||
raise EmailAlreadyInUseError()
|
raise EmailAlreadyInUseError()
|
||||||
else:
|
else:
|
||||||
account = self._create_new_account(email, args["password_confirm"])
|
account = self._create_new_account(email, args.password_confirm)
|
||||||
if not account:
|
if not account:
|
||||||
raise AccountNotFoundError()
|
raise AccountNotFoundError()
|
||||||
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,8 @@ import base64
|
||||||
import secrets
|
import secrets
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource, fields, reqparse
|
from flask_restx import Resource, fields
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
|
@ -18,26 +19,46 @@ from controllers.console.error import AccountNotFound, EmailSendIpLimitError
|
||||||
from controllers.console.wraps import email_password_login_enabled, setup_required
|
from controllers.console.wraps import email_password_login_enabled, setup_required
|
||||||
from events.tenant_event import tenant_was_created
|
from events.tenant_event import tenant_was_created
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.helper import email, extract_remote_ip
|
from libs.helper import EmailStr, extract_remote_ip
|
||||||
from libs.password import hash_password, valid_password
|
from libs.password import hash_password, valid_password
|
||||||
from models import Account
|
from models import Account
|
||||||
from services.account_service import AccountService, TenantService
|
from services.account_service import AccountService, TenantService
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
|
|
||||||
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
|
|
||||||
|
class ForgotPasswordSendPayload(BaseModel):
|
||||||
|
email: EmailStr = Field(...)
|
||||||
|
language: str | None = Field(default=None)
|
||||||
|
|
||||||
|
|
||||||
|
class ForgotPasswordCheckPayload(BaseModel):
|
||||||
|
email: EmailStr = Field(...)
|
||||||
|
code: str = Field(...)
|
||||||
|
token: str = Field(...)
|
||||||
|
|
||||||
|
|
||||||
|
class ForgotPasswordResetPayload(BaseModel):
|
||||||
|
token: str = Field(...)
|
||||||
|
new_password: str = Field(...)
|
||||||
|
password_confirm: str = Field(...)
|
||||||
|
|
||||||
|
@field_validator("new_password", "password_confirm")
|
||||||
|
@classmethod
|
||||||
|
def validate_password(cls, value: str) -> str:
|
||||||
|
return valid_password(value)
|
||||||
|
|
||||||
|
|
||||||
|
for model in (ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload):
|
||||||
|
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/forgot-password")
|
@console_ns.route("/forgot-password")
|
||||||
class ForgotPasswordSendEmailApi(Resource):
|
class ForgotPasswordSendEmailApi(Resource):
|
||||||
@console_ns.doc("send_forgot_password_email")
|
@console_ns.doc("send_forgot_password_email")
|
||||||
@console_ns.doc(description="Send password reset email")
|
@console_ns.doc(description="Send password reset email")
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[ForgotPasswordSendPayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"ForgotPasswordEmailRequest",
|
|
||||||
{
|
|
||||||
"email": fields.String(required=True, description="Email address"),
|
|
||||||
"language": fields.String(description="Language for email (zh-Hans/en-US)"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(
|
@console_ns.response(
|
||||||
200,
|
200,
|
||||||
"Email sent successfully",
|
"Email sent successfully",
|
||||||
|
|
@ -54,28 +75,23 @@ class ForgotPasswordSendEmailApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@email_password_login_enabled
|
@email_password_login_enabled
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = (
|
args = ForgotPasswordSendPayload.model_validate(console_ns.payload)
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("email", type=email, required=True, location="json")
|
|
||||||
.add_argument("language", type=str, required=False, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
ip_address = extract_remote_ip(request)
|
ip_address = extract_remote_ip(request)
|
||||||
if AccountService.is_email_send_ip_limit(ip_address):
|
if AccountService.is_email_send_ip_limit(ip_address):
|
||||||
raise EmailSendIpLimitError()
|
raise EmailSendIpLimitError()
|
||||||
|
|
||||||
if args["language"] is not None and args["language"] == "zh-Hans":
|
if args.language is not None and args.language == "zh-Hans":
|
||||||
language = "zh-Hans"
|
language = "zh-Hans"
|
||||||
else:
|
else:
|
||||||
language = "en-US"
|
language = "en-US"
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
|
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
|
||||||
|
|
||||||
token = AccountService.send_reset_password_email(
|
token = AccountService.send_reset_password_email(
|
||||||
account=account,
|
account=account,
|
||||||
email=args["email"],
|
email=args.email,
|
||||||
language=language,
|
language=language,
|
||||||
is_allow_register=FeatureService.get_system_features().is_allow_register,
|
is_allow_register=FeatureService.get_system_features().is_allow_register,
|
||||||
)
|
)
|
||||||
|
|
@ -87,16 +103,7 @@ class ForgotPasswordSendEmailApi(Resource):
|
||||||
class ForgotPasswordCheckApi(Resource):
|
class ForgotPasswordCheckApi(Resource):
|
||||||
@console_ns.doc("check_forgot_password_code")
|
@console_ns.doc("check_forgot_password_code")
|
||||||
@console_ns.doc(description="Verify password reset code")
|
@console_ns.doc(description="Verify password reset code")
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[ForgotPasswordCheckPayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"ForgotPasswordCheckRequest",
|
|
||||||
{
|
|
||||||
"email": fields.String(required=True, description="Email address"),
|
|
||||||
"code": fields.String(required=True, description="Verification code"),
|
|
||||||
"token": fields.String(required=True, description="Reset token"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(
|
@console_ns.response(
|
||||||
200,
|
200,
|
||||||
"Code verified successfully",
|
"Code verified successfully",
|
||||||
|
|
@ -113,40 +120,34 @@ class ForgotPasswordCheckApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@email_password_login_enabled
|
@email_password_login_enabled
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = (
|
args = ForgotPasswordCheckPayload.model_validate(console_ns.payload)
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("email", type=str, required=True, location="json")
|
|
||||||
.add_argument("code", type=str, required=True, location="json")
|
|
||||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
user_email = args["email"]
|
user_email = args.email
|
||||||
|
|
||||||
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args["email"])
|
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args.email)
|
||||||
if is_forgot_password_error_rate_limit:
|
if is_forgot_password_error_rate_limit:
|
||||||
raise EmailPasswordResetLimitError()
|
raise EmailPasswordResetLimitError()
|
||||||
|
|
||||||
token_data = AccountService.get_reset_password_data(args["token"])
|
token_data = AccountService.get_reset_password_data(args.token)
|
||||||
if token_data is None:
|
if token_data is None:
|
||||||
raise InvalidTokenError()
|
raise InvalidTokenError()
|
||||||
|
|
||||||
if user_email != token_data.get("email"):
|
if user_email != token_data.get("email"):
|
||||||
raise InvalidEmailError()
|
raise InvalidEmailError()
|
||||||
|
|
||||||
if args["code"] != token_data.get("code"):
|
if args.code != token_data.get("code"):
|
||||||
AccountService.add_forgot_password_error_rate_limit(args["email"])
|
AccountService.add_forgot_password_error_rate_limit(args.email)
|
||||||
raise EmailCodeError()
|
raise EmailCodeError()
|
||||||
|
|
||||||
# Verified, revoke the first token
|
# Verified, revoke the first token
|
||||||
AccountService.revoke_reset_password_token(args["token"])
|
AccountService.revoke_reset_password_token(args.token)
|
||||||
|
|
||||||
# Refresh token data by generating a new token
|
# Refresh token data by generating a new token
|
||||||
_, new_token = AccountService.generate_reset_password_token(
|
_, new_token = AccountService.generate_reset_password_token(
|
||||||
user_email, code=args["code"], additional_data={"phase": "reset"}
|
user_email, code=args.code, additional_data={"phase": "reset"}
|
||||||
)
|
)
|
||||||
|
|
||||||
AccountService.reset_forgot_password_error_rate_limit(args["email"])
|
AccountService.reset_forgot_password_error_rate_limit(args.email)
|
||||||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -154,16 +155,7 @@ class ForgotPasswordCheckApi(Resource):
|
||||||
class ForgotPasswordResetApi(Resource):
|
class ForgotPasswordResetApi(Resource):
|
||||||
@console_ns.doc("reset_password")
|
@console_ns.doc("reset_password")
|
||||||
@console_ns.doc(description="Reset password with verification token")
|
@console_ns.doc(description="Reset password with verification token")
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[ForgotPasswordResetPayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"ForgotPasswordResetRequest",
|
|
||||||
{
|
|
||||||
"token": fields.String(required=True, description="Verification token"),
|
|
||||||
"new_password": fields.String(required=True, description="New password"),
|
|
||||||
"password_confirm": fields.String(required=True, description="Password confirmation"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(
|
@console_ns.response(
|
||||||
200,
|
200,
|
||||||
"Password reset successfully",
|
"Password reset successfully",
|
||||||
|
|
@ -173,20 +165,14 @@ class ForgotPasswordResetApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@email_password_login_enabled
|
@email_password_login_enabled
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = (
|
args = ForgotPasswordResetPayload.model_validate(console_ns.payload)
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
|
||||||
.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
|
|
||||||
.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# Validate passwords match
|
# Validate passwords match
|
||||||
if args["new_password"] != args["password_confirm"]:
|
if args.new_password != args.password_confirm:
|
||||||
raise PasswordMismatchError()
|
raise PasswordMismatchError()
|
||||||
|
|
||||||
# Validate token and get reset data
|
# Validate token and get reset data
|
||||||
reset_data = AccountService.get_reset_password_data(args["token"])
|
reset_data = AccountService.get_reset_password_data(args.token)
|
||||||
if not reset_data:
|
if not reset_data:
|
||||||
raise InvalidTokenError()
|
raise InvalidTokenError()
|
||||||
# Must use token in reset phase
|
# Must use token in reset phase
|
||||||
|
|
@ -194,11 +180,11 @@ class ForgotPasswordResetApi(Resource):
|
||||||
raise InvalidTokenError()
|
raise InvalidTokenError()
|
||||||
|
|
||||||
# Revoke token to prevent reuse
|
# Revoke token to prevent reuse
|
||||||
AccountService.revoke_reset_password_token(args["token"])
|
AccountService.revoke_reset_password_token(args.token)
|
||||||
|
|
||||||
# Generate secure salt and hash password
|
# Generate secure salt and hash password
|
||||||
salt = secrets.token_bytes(16)
|
salt = secrets.token_bytes(16)
|
||||||
password_hashed = hash_password(args["new_password"], salt)
|
password_hashed = hash_password(args.new_password, salt)
|
||||||
|
|
||||||
email = reset_data.get("email", "")
|
email = reset_data.get("email", "")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
import flask_login
|
import flask_login
|
||||||
from flask import make_response, request
|
from flask import make_response, request
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
import services
|
import services
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
|
|
@ -21,9 +22,14 @@ from controllers.console.error import (
|
||||||
NotAllowedCreateWorkspace,
|
NotAllowedCreateWorkspace,
|
||||||
WorkspacesLimitExceeded,
|
WorkspacesLimitExceeded,
|
||||||
)
|
)
|
||||||
from controllers.console.wraps import email_password_login_enabled, setup_required
|
from controllers.console.wraps import (
|
||||||
|
decrypt_code_field,
|
||||||
|
decrypt_password_field,
|
||||||
|
email_password_login_enabled,
|
||||||
|
setup_required,
|
||||||
|
)
|
||||||
from events.tenant_event import tenant_was_created
|
from events.tenant_event import tenant_was_created
|
||||||
from libs.helper import email, extract_remote_ip
|
from libs.helper import EmailStr, extract_remote_ip
|
||||||
from libs.login import current_account_with_tenant
|
from libs.login import current_account_with_tenant
|
||||||
from libs.token import (
|
from libs.token import (
|
||||||
clear_access_token_from_cookie,
|
clear_access_token_from_cookie,
|
||||||
|
|
@ -40,6 +46,36 @@ from services.errors.account import AccountRegisterError
|
||||||
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
|
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
|
|
||||||
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
|
|
||||||
|
class LoginPayload(BaseModel):
|
||||||
|
email: EmailStr = Field(..., description="Email address")
|
||||||
|
password: str = Field(..., description="Password")
|
||||||
|
remember_me: bool = Field(default=False, description="Remember me flag")
|
||||||
|
invite_token: str | None = Field(default=None, description="Invitation token")
|
||||||
|
|
||||||
|
|
||||||
|
class EmailPayload(BaseModel):
|
||||||
|
email: EmailStr = Field(...)
|
||||||
|
language: str | None = Field(default=None)
|
||||||
|
|
||||||
|
|
||||||
|
class EmailCodeLoginPayload(BaseModel):
|
||||||
|
email: EmailStr = Field(...)
|
||||||
|
code: str = Field(...)
|
||||||
|
token: str = Field(...)
|
||||||
|
language: str | None = Field(default=None)
|
||||||
|
|
||||||
|
|
||||||
|
def reg(cls: type[BaseModel]):
|
||||||
|
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||||
|
|
||||||
|
|
||||||
|
reg(LoginPayload)
|
||||||
|
reg(EmailPayload)
|
||||||
|
reg(EmailCodeLoginPayload)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/login")
|
@console_ns.route("/login")
|
||||||
class LoginApi(Resource):
|
class LoginApi(Resource):
|
||||||
|
|
@ -47,41 +83,37 @@ class LoginApi(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@email_password_login_enabled
|
@email_password_login_enabled
|
||||||
|
@console_ns.expect(console_ns.models[LoginPayload.__name__])
|
||||||
|
@decrypt_password_field
|
||||||
def post(self):
|
def post(self):
|
||||||
"""Authenticate user and login."""
|
"""Authenticate user and login."""
|
||||||
parser = (
|
args = LoginPayload.model_validate(console_ns.payload)
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("email", type=email, required=True, location="json")
|
|
||||||
.add_argument("password", type=str, required=True, location="json")
|
|
||||||
.add_argument("remember_me", type=bool, required=False, default=False, location="json")
|
|
||||||
.add_argument("invite_token", type=str, required=False, default=None, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]):
|
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
|
||||||
raise AccountInFreezeError()
|
raise AccountInFreezeError()
|
||||||
|
|
||||||
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args["email"])
|
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args.email)
|
||||||
if is_login_error_rate_limit:
|
if is_login_error_rate_limit:
|
||||||
raise EmailPasswordLoginLimitError()
|
raise EmailPasswordLoginLimitError()
|
||||||
|
|
||||||
invitation = args["invite_token"]
|
# TODO: why invitation is re-assigned with different type?
|
||||||
|
invitation = args.invite_token # type: ignore
|
||||||
if invitation:
|
if invitation:
|
||||||
invitation = RegisterService.get_invitation_if_token_valid(None, args["email"], invitation)
|
invitation = RegisterService.get_invitation_if_token_valid(None, args.email, invitation) # type: ignore
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if invitation:
|
if invitation:
|
||||||
data = invitation.get("data", {})
|
data = invitation.get("data", {}) # type: ignore
|
||||||
invitee_email = data.get("email") if data else None
|
invitee_email = data.get("email") if data else None
|
||||||
if invitee_email != args["email"]:
|
if invitee_email != args.email:
|
||||||
raise InvalidEmailError()
|
raise InvalidEmailError()
|
||||||
account = AccountService.authenticate(args["email"], args["password"], args["invite_token"])
|
account = AccountService.authenticate(args.email, args.password, args.invite_token)
|
||||||
else:
|
else:
|
||||||
account = AccountService.authenticate(args["email"], args["password"])
|
account = AccountService.authenticate(args.email, args.password)
|
||||||
except services.errors.account.AccountLoginError:
|
except services.errors.account.AccountLoginError:
|
||||||
raise AccountBannedError()
|
raise AccountBannedError()
|
||||||
except services.errors.account.AccountPasswordError:
|
except services.errors.account.AccountPasswordError:
|
||||||
AccountService.add_login_error_rate_limit(args["email"])
|
AccountService.add_login_error_rate_limit(args.email)
|
||||||
raise AuthenticationFailedError()
|
raise AuthenticationFailedError()
|
||||||
# SELF_HOSTED only have one workspace
|
# SELF_HOSTED only have one workspace
|
||||||
tenants = TenantService.get_join_tenants(account)
|
tenants = TenantService.get_join_tenants(account)
|
||||||
|
|
@ -97,7 +129,7 @@ class LoginApi(Resource):
|
||||||
}
|
}
|
||||||
|
|
||||||
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
||||||
AccountService.reset_login_error_rate_limit(args["email"])
|
AccountService.reset_login_error_rate_limit(args.email)
|
||||||
|
|
||||||
# Create response with cookies instead of returning tokens in body
|
# Create response with cookies instead of returning tokens in body
|
||||||
response = make_response({"result": "success"})
|
response = make_response({"result": "success"})
|
||||||
|
|
@ -134,25 +166,21 @@ class LogoutApi(Resource):
|
||||||
class ResetPasswordSendEmailApi(Resource):
|
class ResetPasswordSendEmailApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@email_password_login_enabled
|
@email_password_login_enabled
|
||||||
|
@console_ns.expect(console_ns.models[EmailPayload.__name__])
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = (
|
args = EmailPayload.model_validate(console_ns.payload)
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("email", type=email, required=True, location="json")
|
|
||||||
.add_argument("language", type=str, required=False, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
if args["language"] is not None and args["language"] == "zh-Hans":
|
if args.language is not None and args.language == "zh-Hans":
|
||||||
language = "zh-Hans"
|
language = "zh-Hans"
|
||||||
else:
|
else:
|
||||||
language = "en-US"
|
language = "en-US"
|
||||||
try:
|
try:
|
||||||
account = AccountService.get_user_through_email(args["email"])
|
account = AccountService.get_user_through_email(args.email)
|
||||||
except AccountRegisterError:
|
except AccountRegisterError:
|
||||||
raise AccountInFreezeError()
|
raise AccountInFreezeError()
|
||||||
|
|
||||||
token = AccountService.send_reset_password_email(
|
token = AccountService.send_reset_password_email(
|
||||||
email=args["email"],
|
email=args.email,
|
||||||
account=account,
|
account=account,
|
||||||
language=language,
|
language=language,
|
||||||
is_allow_register=FeatureService.get_system_features().is_allow_register,
|
is_allow_register=FeatureService.get_system_features().is_allow_register,
|
||||||
|
|
@ -164,30 +192,26 @@ class ResetPasswordSendEmailApi(Resource):
|
||||||
@console_ns.route("/email-code-login")
|
@console_ns.route("/email-code-login")
|
||||||
class EmailCodeLoginSendEmailApi(Resource):
|
class EmailCodeLoginSendEmailApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
|
@console_ns.expect(console_ns.models[EmailPayload.__name__])
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = (
|
args = EmailPayload.model_validate(console_ns.payload)
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("email", type=email, required=True, location="json")
|
|
||||||
.add_argument("language", type=str, required=False, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
ip_address = extract_remote_ip(request)
|
ip_address = extract_remote_ip(request)
|
||||||
if AccountService.is_email_send_ip_limit(ip_address):
|
if AccountService.is_email_send_ip_limit(ip_address):
|
||||||
raise EmailSendIpLimitError()
|
raise EmailSendIpLimitError()
|
||||||
|
|
||||||
if args["language"] is not None and args["language"] == "zh-Hans":
|
if args.language is not None and args.language == "zh-Hans":
|
||||||
language = "zh-Hans"
|
language = "zh-Hans"
|
||||||
else:
|
else:
|
||||||
language = "en-US"
|
language = "en-US"
|
||||||
try:
|
try:
|
||||||
account = AccountService.get_user_through_email(args["email"])
|
account = AccountService.get_user_through_email(args.email)
|
||||||
except AccountRegisterError:
|
except AccountRegisterError:
|
||||||
raise AccountInFreezeError()
|
raise AccountInFreezeError()
|
||||||
|
|
||||||
if account is None:
|
if account is None:
|
||||||
if FeatureService.get_system_features().is_allow_register:
|
if FeatureService.get_system_features().is_allow_register:
|
||||||
token = AccountService.send_email_code_login_email(email=args["email"], language=language)
|
token = AccountService.send_email_code_login_email(email=args.email, language=language)
|
||||||
else:
|
else:
|
||||||
raise AccountNotFound()
|
raise AccountNotFound()
|
||||||
else:
|
else:
|
||||||
|
|
@ -199,30 +223,25 @@ class EmailCodeLoginSendEmailApi(Resource):
|
||||||
@console_ns.route("/email-code-login/validity")
|
@console_ns.route("/email-code-login/validity")
|
||||||
class EmailCodeLoginApi(Resource):
|
class EmailCodeLoginApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
|
@console_ns.expect(console_ns.models[EmailCodeLoginPayload.__name__])
|
||||||
|
@decrypt_code_field
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = (
|
args = EmailCodeLoginPayload.model_validate(console_ns.payload)
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("email", type=str, required=True, location="json")
|
|
||||||
.add_argument("code", type=str, required=True, location="json")
|
|
||||||
.add_argument("token", type=str, required=True, location="json")
|
|
||||||
.add_argument("language", type=str, required=False, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
user_email = args["email"]
|
user_email = args.email
|
||||||
language = args["language"]
|
language = args.language
|
||||||
|
|
||||||
token_data = AccountService.get_email_code_login_data(args["token"])
|
token_data = AccountService.get_email_code_login_data(args.token)
|
||||||
if token_data is None:
|
if token_data is None:
|
||||||
raise InvalidTokenError()
|
raise InvalidTokenError()
|
||||||
|
|
||||||
if token_data["email"] != args["email"]:
|
if token_data["email"] != args.email:
|
||||||
raise InvalidEmailError()
|
raise InvalidEmailError()
|
||||||
|
|
||||||
if token_data["code"] != args["code"]:
|
if token_data["code"] != args.code:
|
||||||
raise EmailCodeError()
|
raise EmailCodeError()
|
||||||
|
|
||||||
AccountService.revoke_email_code_login_token(args["token"])
|
AccountService.revoke_email_code_login_token(args.token)
|
||||||
try:
|
try:
|
||||||
account = AccountService.get_user_through_email(user_email)
|
account = AccountService.get_user_through_email(user_email)
|
||||||
except AccountRegisterError:
|
except AccountRegisterError:
|
||||||
|
|
@ -255,7 +274,7 @@ class EmailCodeLoginApi(Resource):
|
||||||
except WorkspacesLimitExceededError:
|
except WorkspacesLimitExceededError:
|
||||||
raise WorkspacesLimitExceeded()
|
raise WorkspacesLimitExceeded()
|
||||||
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
|
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
|
||||||
AccountService.reset_login_error_rate_limit(args["email"])
|
AccountService.reset_login_error_rate_limit(args.email)
|
||||||
|
|
||||||
# Create response with cookies instead of returning tokens in body
|
# Create response with cookies instead of returning tokens in body
|
||||||
response = make_response({"result": "success"})
|
response = make_response({"result": "success"})
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,8 @@ from functools import wraps
|
||||||
from typing import Concatenate, ParamSpec, TypeVar
|
from typing import Concatenate, ParamSpec, TypeVar
|
||||||
|
|
||||||
from flask import jsonify, request
|
from flask import jsonify, request
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource
|
||||||
|
from pydantic import BaseModel
|
||||||
from werkzeug.exceptions import BadRequest, NotFound
|
from werkzeug.exceptions import BadRequest, NotFound
|
||||||
|
|
||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
|
|
@ -20,15 +21,34 @@ R = TypeVar("R")
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthClientPayload(BaseModel):
|
||||||
|
client_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthProviderRequest(BaseModel):
|
||||||
|
client_id: str
|
||||||
|
redirect_uri: str
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthTokenRequest(BaseModel):
|
||||||
|
client_id: str
|
||||||
|
grant_type: str
|
||||||
|
code: str | None = None
|
||||||
|
client_secret: str | None = None
|
||||||
|
redirect_uri: str | None = None
|
||||||
|
refresh_token: str | None = None
|
||||||
|
|
||||||
|
|
||||||
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
|
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
|
def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
|
||||||
parser = reqparse.RequestParser().add_argument("client_id", type=str, required=True, location="json")
|
json_data = request.get_json()
|
||||||
parsed_args = parser.parse_args()
|
if json_data is None:
|
||||||
client_id = parsed_args.get("client_id")
|
|
||||||
if not client_id:
|
|
||||||
raise BadRequest("client_id is required")
|
raise BadRequest("client_id is required")
|
||||||
|
|
||||||
|
payload = OAuthClientPayload.model_validate(json_data)
|
||||||
|
client_id = payload.client_id
|
||||||
|
|
||||||
oauth_provider_app = OAuthServerService.get_oauth_provider_app(client_id)
|
oauth_provider_app = OAuthServerService.get_oauth_provider_app(client_id)
|
||||||
if not oauth_provider_app:
|
if not oauth_provider_app:
|
||||||
raise NotFound("client_id is invalid")
|
raise NotFound("client_id is invalid")
|
||||||
|
|
@ -89,9 +109,8 @@ class OAuthServerAppApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@oauth_server_client_id_required
|
@oauth_server_client_id_required
|
||||||
def post(self, oauth_provider_app: OAuthProviderApp):
|
def post(self, oauth_provider_app: OAuthProviderApp):
|
||||||
parser = reqparse.RequestParser().add_argument("redirect_uri", type=str, required=True, location="json")
|
payload = OAuthProviderRequest.model_validate(request.get_json())
|
||||||
parsed_args = parser.parse_args()
|
redirect_uri = payload.redirect_uri
|
||||||
redirect_uri = parsed_args.get("redirect_uri")
|
|
||||||
|
|
||||||
# check if redirect_uri is valid
|
# check if redirect_uri is valid
|
||||||
if redirect_uri not in oauth_provider_app.redirect_uris:
|
if redirect_uri not in oauth_provider_app.redirect_uris:
|
||||||
|
|
@ -130,33 +149,25 @@ class OAuthServerUserTokenApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@oauth_server_client_id_required
|
@oauth_server_client_id_required
|
||||||
def post(self, oauth_provider_app: OAuthProviderApp):
|
def post(self, oauth_provider_app: OAuthProviderApp):
|
||||||
parser = (
|
payload = OAuthTokenRequest.model_validate(request.get_json())
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("grant_type", type=str, required=True, location="json")
|
|
||||||
.add_argument("code", type=str, required=False, location="json")
|
|
||||||
.add_argument("client_secret", type=str, required=False, location="json")
|
|
||||||
.add_argument("redirect_uri", type=str, required=False, location="json")
|
|
||||||
.add_argument("refresh_token", type=str, required=False, location="json")
|
|
||||||
)
|
|
||||||
parsed_args = parser.parse_args()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
grant_type = OAuthGrantType(parsed_args["grant_type"])
|
grant_type = OAuthGrantType(payload.grant_type)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise BadRequest("invalid grant_type")
|
raise BadRequest("invalid grant_type")
|
||||||
|
|
||||||
if grant_type == OAuthGrantType.AUTHORIZATION_CODE:
|
if grant_type == OAuthGrantType.AUTHORIZATION_CODE:
|
||||||
if not parsed_args["code"]:
|
if not payload.code:
|
||||||
raise BadRequest("code is required")
|
raise BadRequest("code is required")
|
||||||
|
|
||||||
if parsed_args["client_secret"] != oauth_provider_app.client_secret:
|
if payload.client_secret != oauth_provider_app.client_secret:
|
||||||
raise BadRequest("client_secret is invalid")
|
raise BadRequest("client_secret is invalid")
|
||||||
|
|
||||||
if parsed_args["redirect_uri"] not in oauth_provider_app.redirect_uris:
|
if payload.redirect_uri not in oauth_provider_app.redirect_uris:
|
||||||
raise BadRequest("redirect_uri is invalid")
|
raise BadRequest("redirect_uri is invalid")
|
||||||
|
|
||||||
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||||
grant_type, code=parsed_args["code"], client_id=oauth_provider_app.client_id
|
grant_type, code=payload.code, client_id=oauth_provider_app.client_id
|
||||||
)
|
)
|
||||||
return jsonable_encoder(
|
return jsonable_encoder(
|
||||||
{
|
{
|
||||||
|
|
@ -167,11 +178,11 @@ class OAuthServerUserTokenApi(Resource):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
elif grant_type == OAuthGrantType.REFRESH_TOKEN:
|
elif grant_type == OAuthGrantType.REFRESH_TOKEN:
|
||||||
if not parsed_args["refresh_token"]:
|
if not payload.refresh_token:
|
||||||
raise BadRequest("refresh_token is required")
|
raise BadRequest("refresh_token is required")
|
||||||
|
|
||||||
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||||
grant_type, refresh_token=parsed_args["refresh_token"], client_id=oauth_provider_app.client_id
|
grant_type, refresh_token=payload.refresh_token, client_id=oauth_provider_app.client_id
|
||||||
)
|
)
|
||||||
return jsonable_encoder(
|
return jsonable_encoder(
|
||||||
{
|
{
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,9 @@
|
||||||
import base64
|
import base64
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
from flask_restx import Resource, fields, reqparse
|
from flask import request
|
||||||
|
from flask_restx import Resource, fields
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
from werkzeug.exceptions import BadRequest
|
from werkzeug.exceptions import BadRequest
|
||||||
|
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
|
|
@ -9,6 +12,21 @@ from enums.cloud_plan import CloudPlan
|
||||||
from libs.login import current_account_with_tenant, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from services.billing_service import BillingService
|
from services.billing_service import BillingService
|
||||||
|
|
||||||
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
|
|
||||||
|
class SubscriptionQuery(BaseModel):
|
||||||
|
plan: Literal[CloudPlan.PROFESSIONAL, CloudPlan.TEAM] = Field(..., description="Subscription plan")
|
||||||
|
interval: Literal["month", "year"] = Field(..., description="Billing interval")
|
||||||
|
|
||||||
|
|
||||||
|
class PartnerTenantsPayload(BaseModel):
|
||||||
|
click_id: str = Field(..., description="Click Id from partner referral link")
|
||||||
|
|
||||||
|
|
||||||
|
for model in (SubscriptionQuery, PartnerTenantsPayload):
|
||||||
|
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/billing/subscription")
|
@console_ns.route("/billing/subscription")
|
||||||
class Subscription(Resource):
|
class Subscription(Resource):
|
||||||
|
|
@ -18,20 +36,9 @@ class Subscription(Resource):
|
||||||
@only_edition_cloud
|
@only_edition_cloud
|
||||||
def get(self):
|
def get(self):
|
||||||
current_user, current_tenant_id = current_account_with_tenant()
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
parser = (
|
args = SubscriptionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument(
|
|
||||||
"plan",
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
location="args",
|
|
||||||
choices=[CloudPlan.PROFESSIONAL, CloudPlan.TEAM],
|
|
||||||
)
|
|
||||||
.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
BillingService.is_tenant_owner_or_admin(current_user)
|
BillingService.is_tenant_owner_or_admin(current_user)
|
||||||
return BillingService.get_subscription(args["plan"], args["interval"], current_user.email, current_tenant_id)
|
return BillingService.get_subscription(args.plan, args.interval, current_user.email, current_tenant_id)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/billing/invoices")
|
@console_ns.route("/billing/invoices")
|
||||||
|
|
@ -65,11 +72,10 @@ class PartnerTenants(Resource):
|
||||||
@only_edition_cloud
|
@only_edition_cloud
|
||||||
def put(self, partner_key: str):
|
def put(self, partner_key: str):
|
||||||
current_user, _ = current_account_with_tenant()
|
current_user, _ = current_account_with_tenant()
|
||||||
parser = reqparse.RequestParser().add_argument("click_id", required=True, type=str, location="json")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
click_id = args["click_id"]
|
args = PartnerTenantsPayload.model_validate(console_ns.payload or {})
|
||||||
|
click_id = args.click_id
|
||||||
decoded_partner_key = base64.b64decode(partner_key).decode("utf-8")
|
decoded_partner_key = base64.b64decode(partner_key).decode("utf-8")
|
||||||
except Exception:
|
except Exception:
|
||||||
raise BadRequest("Invalid partner_key")
|
raise BadRequest("Invalid partner_key")
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from libs.helper import extract_remote_ip
|
from libs.helper import extract_remote_ip
|
||||||
from libs.login import current_account_with_tenant, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
|
|
@ -9,16 +10,28 @@ from .. import console_ns
|
||||||
from ..wraps import account_initialization_required, only_edition_cloud, setup_required
|
from ..wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||||
|
|
||||||
|
|
||||||
|
class ComplianceDownloadQuery(BaseModel):
|
||||||
|
doc_name: str = Field(..., description="Compliance document name")
|
||||||
|
|
||||||
|
|
||||||
|
console_ns.schema_model(
|
||||||
|
ComplianceDownloadQuery.__name__,
|
||||||
|
ComplianceDownloadQuery.model_json_schema(ref_template="#/definitions/{model}"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/compliance/download")
|
@console_ns.route("/compliance/download")
|
||||||
class ComplianceApi(Resource):
|
class ComplianceApi(Resource):
|
||||||
|
@console_ns.expect(console_ns.models[ComplianceDownloadQuery.__name__])
|
||||||
|
@console_ns.doc("download_compliance_document")
|
||||||
|
@console_ns.doc(description="Get compliance document download link")
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@only_edition_cloud
|
@only_edition_cloud
|
||||||
def get(self):
|
def get(self):
|
||||||
current_user, current_tenant_id = current_account_with_tenant()
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
parser = reqparse.RequestParser().add_argument("doc_name", type=str, required=True, location="args")
|
args = ComplianceDownloadQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
ip_address = extract_remote_ip(request)
|
ip_address = extract_remote_ip(request)
|
||||||
device_info = request.headers.get("User-Agent", "Unknown device")
|
device_info = request.headers.get("User-Agent", "Unknown device")
|
||||||
|
|
|
||||||
|
|
@ -1,15 +1,15 @@
|
||||||
import json
|
import json
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from typing import cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource, marshal_with, reqparse
|
from flask_restx import Resource, marshal_with
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from werkzeug.exceptions import NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
from controllers.console import console_ns
|
from controllers.common.schema import register_schema_model
|
||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
|
||||||
from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage
|
from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage
|
||||||
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||||
from core.indexing_runner import IndexingRunner
|
from core.indexing_runner import IndexingRunner
|
||||||
|
|
@ -25,6 +25,19 @@ from services.dataset_service import DatasetService, DocumentService
|
||||||
from services.datasource_provider_service import DatasourceProviderService
|
from services.datasource_provider_service import DatasourceProviderService
|
||||||
from tasks.document_indexing_sync_task import document_indexing_sync_task
|
from tasks.document_indexing_sync_task import document_indexing_sync_task
|
||||||
|
|
||||||
|
from .. import console_ns
|
||||||
|
from ..wraps import account_initialization_required, setup_required
|
||||||
|
|
||||||
|
|
||||||
|
class NotionEstimatePayload(BaseModel):
|
||||||
|
notion_info_list: list[dict[str, Any]]
|
||||||
|
process_rule: dict[str, Any]
|
||||||
|
doc_form: str = Field(default="text_model")
|
||||||
|
doc_language: str = Field(default="English")
|
||||||
|
|
||||||
|
|
||||||
|
register_schema_model(console_ns, NotionEstimatePayload)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route(
|
@console_ns.route(
|
||||||
"/data-source/integrates",
|
"/data-source/integrates",
|
||||||
|
|
@ -127,6 +140,18 @@ class DataSourceNotionListApi(Resource):
|
||||||
credential_id = request.args.get("credential_id", default=None, type=str)
|
credential_id = request.args.get("credential_id", default=None, type=str)
|
||||||
if not credential_id:
|
if not credential_id:
|
||||||
raise ValueError("Credential id is required.")
|
raise ValueError("Credential id is required.")
|
||||||
|
|
||||||
|
# Get datasource_parameters from query string (optional, for GitHub and other datasources)
|
||||||
|
datasource_parameters_str = request.args.get("datasource_parameters", default=None, type=str)
|
||||||
|
datasource_parameters = {}
|
||||||
|
if datasource_parameters_str:
|
||||||
|
try:
|
||||||
|
datasource_parameters = json.loads(datasource_parameters_str)
|
||||||
|
if not isinstance(datasource_parameters, dict):
|
||||||
|
raise ValueError("datasource_parameters must be a JSON object.")
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
raise ValueError("Invalid datasource_parameters JSON format.")
|
||||||
|
|
||||||
datasource_provider_service = DatasourceProviderService()
|
datasource_provider_service = DatasourceProviderService()
|
||||||
credential = datasource_provider_service.get_datasource_credentials(
|
credential = datasource_provider_service.get_datasource_credentials(
|
||||||
tenant_id=current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
|
|
@ -174,7 +199,7 @@ class DataSourceNotionListApi(Resource):
|
||||||
online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = (
|
online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = (
|
||||||
datasource_runtime.get_online_document_pages(
|
datasource_runtime.get_online_document_pages(
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
datasource_parameters={},
|
datasource_parameters=datasource_parameters,
|
||||||
provider_type=datasource_runtime.datasource_provider_type(),
|
provider_type=datasource_runtime.datasource_provider_type(),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
@ -205,14 +230,14 @@ class DataSourceNotionListApi(Resource):
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route(
|
@console_ns.route(
|
||||||
"/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview",
|
"/notion/pages/<uuid:page_id>/<string:page_type>/preview",
|
||||||
"/datasets/notion-indexing-estimate",
|
"/datasets/notion-indexing-estimate",
|
||||||
)
|
)
|
||||||
class DataSourceNotionApi(Resource):
|
class DataSourceNotionApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, workspace_id, page_id, page_type):
|
def get(self, page_id, page_type):
|
||||||
_, current_tenant_id = current_account_with_tenant()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
credential_id = request.args.get("credential_id", default=None, type=str)
|
credential_id = request.args.get("credential_id", default=None, type=str)
|
||||||
|
|
@ -226,11 +251,10 @@ class DataSourceNotionApi(Resource):
|
||||||
plugin_id="langgenius/notion_datasource",
|
plugin_id="langgenius/notion_datasource",
|
||||||
)
|
)
|
||||||
|
|
||||||
workspace_id = str(workspace_id)
|
|
||||||
page_id = str(page_id)
|
page_id = str(page_id)
|
||||||
|
|
||||||
extractor = NotionExtractor(
|
extractor = NotionExtractor(
|
||||||
notion_workspace_id=workspace_id,
|
notion_workspace_id="",
|
||||||
notion_obj_id=page_id,
|
notion_obj_id=page_id,
|
||||||
notion_page_type=page_type,
|
notion_page_type=page_type,
|
||||||
notion_access_token=credential.get("integration_secret"),
|
notion_access_token=credential.get("integration_secret"),
|
||||||
|
|
@ -243,20 +267,15 @@ class DataSourceNotionApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@console_ns.expect(console_ns.models[NotionEstimatePayload.__name__])
|
||||||
def post(self):
|
def post(self):
|
||||||
_, current_tenant_id = current_account_with_tenant()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
parser = (
|
payload = NotionEstimatePayload.model_validate(console_ns.payload or {})
|
||||||
reqparse.RequestParser()
|
args = payload.model_dump()
|
||||||
.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json")
|
|
||||||
.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
|
|
||||||
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
|
||||||
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
# validate args
|
# validate args
|
||||||
DocumentService.estimate_args_validate(args)
|
DocumentService.estimate_args_validate(args)
|
||||||
notion_info_list = args["notion_info_list"]
|
notion_info_list = payload.notion_info_list
|
||||||
extract_settings = []
|
extract_settings = []
|
||||||
for notion_info in notion_info_list:
|
for notion_info in notion_info_list:
|
||||||
workspace_id = notion_info["workspace_id"]
|
workspace_id = notion_info["workspace_id"]
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,14 @@
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
from flask_restx import Resource, fields, marshal, marshal_with
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from werkzeug.exceptions import Forbidden, NotFound
|
from werkzeug.exceptions import Forbidden, NotFound
|
||||||
|
|
||||||
import services
|
import services
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
|
from controllers.common.schema import register_schema_models
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
from controllers.console.apikey import (
|
from controllers.console.apikey import (
|
||||||
api_key_item_model,
|
api_key_item_model,
|
||||||
|
|
@ -48,7 +50,6 @@ from fields.dataset_fields import (
|
||||||
)
|
)
|
||||||
from fields.document_fields import document_status_fields
|
from fields.document_fields import document_status_fields
|
||||||
from libs.login import current_account_with_tenant, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from libs.validators import validate_description_length
|
|
||||||
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
|
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
|
||||||
from models.dataset import DatasetPermissionEnum
|
from models.dataset import DatasetPermissionEnum
|
||||||
from models.provider_ids import ModelProviderID
|
from models.provider_ids import ModelProviderID
|
||||||
|
|
@ -107,10 +108,75 @@ related_app_list_copy["data"] = fields.List(fields.Nested(app_detail_kernel_mode
|
||||||
related_app_list_model = _get_or_create_model("RelatedAppList", related_app_list_copy)
|
related_app_list_model = _get_or_create_model("RelatedAppList", related_app_list_copy)
|
||||||
|
|
||||||
|
|
||||||
def _validate_name(name: str) -> str:
|
def _validate_indexing_technique(value: str | None) -> str | None:
|
||||||
if not name or len(name) < 1 or len(name) > 40:
|
if value is None:
|
||||||
raise ValueError("Name must be between 1 to 40 characters.")
|
return value
|
||||||
return name
|
if value not in Dataset.INDEXING_TECHNIQUE_LIST:
|
||||||
|
raise ValueError("Invalid indexing technique.")
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetCreatePayload(BaseModel):
|
||||||
|
name: str = Field(..., min_length=1, max_length=40)
|
||||||
|
description: str = Field("", max_length=400)
|
||||||
|
indexing_technique: str | None = None
|
||||||
|
permission: DatasetPermissionEnum | None = DatasetPermissionEnum.ONLY_ME
|
||||||
|
provider: str = "vendor"
|
||||||
|
external_knowledge_api_id: str | None = None
|
||||||
|
external_knowledge_id: str | None = None
|
||||||
|
|
||||||
|
@field_validator("indexing_technique")
|
||||||
|
@classmethod
|
||||||
|
def validate_indexing(cls, value: str | None) -> str | None:
|
||||||
|
return _validate_indexing_technique(value)
|
||||||
|
|
||||||
|
@field_validator("provider")
|
||||||
|
@classmethod
|
||||||
|
def validate_provider(cls, value: str) -> str:
|
||||||
|
if value not in Dataset.PROVIDER_LIST:
|
||||||
|
raise ValueError("Invalid provider.")
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetUpdatePayload(BaseModel):
|
||||||
|
name: str | None = Field(None, min_length=1, max_length=40)
|
||||||
|
description: str | None = Field(None, max_length=400)
|
||||||
|
permission: DatasetPermissionEnum | None = None
|
||||||
|
indexing_technique: str | None = None
|
||||||
|
embedding_model: str | None = None
|
||||||
|
embedding_model_provider: str | None = None
|
||||||
|
retrieval_model: dict[str, Any] | None = None
|
||||||
|
partial_member_list: list[dict[str, str]] | None = None
|
||||||
|
external_retrieval_model: dict[str, Any] | None = None
|
||||||
|
external_knowledge_id: str | None = None
|
||||||
|
external_knowledge_api_id: str | None = None
|
||||||
|
icon_info: dict[str, Any] | None = None
|
||||||
|
is_multimodal: bool | None = False
|
||||||
|
|
||||||
|
@field_validator("indexing_technique")
|
||||||
|
@classmethod
|
||||||
|
def validate_indexing(cls, value: str | None) -> str | None:
|
||||||
|
return _validate_indexing_technique(value)
|
||||||
|
|
||||||
|
|
||||||
|
class IndexingEstimatePayload(BaseModel):
|
||||||
|
info_list: dict[str, Any]
|
||||||
|
process_rule: dict[str, Any]
|
||||||
|
indexing_technique: str
|
||||||
|
doc_form: str = "text_model"
|
||||||
|
dataset_id: str | None = None
|
||||||
|
doc_language: str = "English"
|
||||||
|
|
||||||
|
@field_validator("indexing_technique")
|
||||||
|
@classmethod
|
||||||
|
def validate_indexing(cls, value: str) -> str:
|
||||||
|
result = _validate_indexing_technique(value)
|
||||||
|
if result is None:
|
||||||
|
raise ValueError("indexing_technique is required.")
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
register_schema_models(console_ns, DatasetCreatePayload, DatasetUpdatePayload, IndexingEstimatePayload)
|
||||||
|
|
||||||
|
|
||||||
def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool = False) -> dict[str, list[str]]:
|
def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool = False) -> dict[str, list[str]]:
|
||||||
|
|
@ -157,6 +223,7 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool
|
||||||
VectorType.COUCHBASE,
|
VectorType.COUCHBASE,
|
||||||
VectorType.OPENGAUSS,
|
VectorType.OPENGAUSS,
|
||||||
VectorType.OCEANBASE,
|
VectorType.OCEANBASE,
|
||||||
|
VectorType.SEEKDB,
|
||||||
VectorType.TABLESTORE,
|
VectorType.TABLESTORE,
|
||||||
VectorType.HUAWEI_CLOUD,
|
VectorType.HUAWEI_CLOUD,
|
||||||
VectorType.TENCENT,
|
VectorType.TENCENT,
|
||||||
|
|
@ -164,6 +231,7 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool
|
||||||
VectorType.CLICKZETTA,
|
VectorType.CLICKZETTA,
|
||||||
VectorType.BAIDU,
|
VectorType.BAIDU,
|
||||||
VectorType.ALIBABACLOUD_MYSQL,
|
VectorType.ALIBABACLOUD_MYSQL,
|
||||||
|
VectorType.IRIS,
|
||||||
}
|
}
|
||||||
|
|
||||||
semantic_methods = {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
|
semantic_methods = {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
|
||||||
|
|
@ -255,20 +323,7 @@ class DatasetListApi(Resource):
|
||||||
|
|
||||||
@console_ns.doc("create_dataset")
|
@console_ns.doc("create_dataset")
|
||||||
@console_ns.doc(description="Create a new dataset")
|
@console_ns.doc(description="Create a new dataset")
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[DatasetCreatePayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"CreateDatasetRequest",
|
|
||||||
{
|
|
||||||
"name": fields.String(required=True, description="Dataset name (1-40 characters)"),
|
|
||||||
"description": fields.String(description="Dataset description (max 400 characters)"),
|
|
||||||
"indexing_technique": fields.String(description="Indexing technique"),
|
|
||||||
"permission": fields.String(description="Dataset permission"),
|
|
||||||
"provider": fields.String(description="Provider"),
|
|
||||||
"external_knowledge_api_id": fields.String(description="External knowledge API ID"),
|
|
||||||
"external_knowledge_id": fields.String(description="External knowledge ID"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(201, "Dataset created successfully")
|
@console_ns.response(201, "Dataset created successfully")
|
||||||
@console_ns.response(400, "Invalid request parameters")
|
@console_ns.response(400, "Invalid request parameters")
|
||||||
@setup_required
|
@setup_required
|
||||||
|
|
@ -276,52 +331,7 @@ class DatasetListApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = (
|
payload = DatasetCreatePayload.model_validate(console_ns.payload or {})
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument(
|
|
||||||
"name",
|
|
||||||
nullable=False,
|
|
||||||
required=True,
|
|
||||||
help="type is required. Name must be between 1 to 40 characters.",
|
|
||||||
type=_validate_name,
|
|
||||||
)
|
|
||||||
.add_argument(
|
|
||||||
"description",
|
|
||||||
type=validate_description_length,
|
|
||||||
nullable=True,
|
|
||||||
required=False,
|
|
||||||
default="",
|
|
||||||
)
|
|
||||||
.add_argument(
|
|
||||||
"indexing_technique",
|
|
||||||
type=str,
|
|
||||||
location="json",
|
|
||||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
|
||||||
nullable=True,
|
|
||||||
help="Invalid indexing technique.",
|
|
||||||
)
|
|
||||||
.add_argument(
|
|
||||||
"external_knowledge_api_id",
|
|
||||||
type=str,
|
|
||||||
nullable=True,
|
|
||||||
required=False,
|
|
||||||
)
|
|
||||||
.add_argument(
|
|
||||||
"provider",
|
|
||||||
type=str,
|
|
||||||
nullable=True,
|
|
||||||
choices=Dataset.PROVIDER_LIST,
|
|
||||||
required=False,
|
|
||||||
default="vendor",
|
|
||||||
)
|
|
||||||
.add_argument(
|
|
||||||
"external_knowledge_id",
|
|
||||||
type=str,
|
|
||||||
nullable=True,
|
|
||||||
required=False,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
current_user, current_tenant_id = current_account_with_tenant()
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||||
|
|
@ -331,14 +341,14 @@ class DatasetListApi(Resource):
|
||||||
try:
|
try:
|
||||||
dataset = DatasetService.create_empty_dataset(
|
dataset = DatasetService.create_empty_dataset(
|
||||||
tenant_id=current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
name=args["name"],
|
name=payload.name,
|
||||||
description=args["description"],
|
description=payload.description,
|
||||||
indexing_technique=args["indexing_technique"],
|
indexing_technique=payload.indexing_technique,
|
||||||
account=current_user,
|
account=current_user,
|
||||||
permission=DatasetPermissionEnum.ONLY_ME,
|
permission=payload.permission or DatasetPermissionEnum.ONLY_ME,
|
||||||
provider=args["provider"],
|
provider=payload.provider,
|
||||||
external_knowledge_api_id=args["external_knowledge_api_id"],
|
external_knowledge_api_id=payload.external_knowledge_api_id,
|
||||||
external_knowledge_id=args["external_knowledge_id"],
|
external_knowledge_id=payload.external_knowledge_id,
|
||||||
)
|
)
|
||||||
except services.errors.dataset.DatasetNameDuplicateError:
|
except services.errors.dataset.DatasetNameDuplicateError:
|
||||||
raise DatasetNameDuplicateError()
|
raise DatasetNameDuplicateError()
|
||||||
|
|
@ -399,18 +409,7 @@ class DatasetApi(Resource):
|
||||||
|
|
||||||
@console_ns.doc("update_dataset")
|
@console_ns.doc("update_dataset")
|
||||||
@console_ns.doc(description="Update dataset details")
|
@console_ns.doc(description="Update dataset details")
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[DatasetUpdatePayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"UpdateDatasetRequest",
|
|
||||||
{
|
|
||||||
"name": fields.String(description="Dataset name"),
|
|
||||||
"description": fields.String(description="Dataset description"),
|
|
||||||
"permission": fields.String(description="Dataset permission"),
|
|
||||||
"indexing_technique": fields.String(description="Indexing technique"),
|
|
||||||
"external_retrieval_model": fields.Raw(description="External retrieval model settings"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Dataset updated successfully", dataset_detail_model)
|
@console_ns.response(200, "Dataset updated successfully", dataset_detail_model)
|
||||||
@console_ns.response(404, "Dataset not found")
|
@console_ns.response(404, "Dataset not found")
|
||||||
@console_ns.response(403, "Permission denied")
|
@console_ns.response(403, "Permission denied")
|
||||||
|
|
@ -424,93 +423,25 @@ class DatasetApi(Resource):
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
raise NotFound("Dataset not found.")
|
raise NotFound("Dataset not found.")
|
||||||
|
|
||||||
parser = (
|
payload = DatasetUpdatePayload.model_validate(console_ns.payload or {})
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument(
|
|
||||||
"name",
|
|
||||||
nullable=False,
|
|
||||||
help="type is required. Name must be between 1 to 40 characters.",
|
|
||||||
type=_validate_name,
|
|
||||||
)
|
|
||||||
.add_argument("description", location="json", store_missing=False, type=validate_description_length)
|
|
||||||
.add_argument(
|
|
||||||
"indexing_technique",
|
|
||||||
type=str,
|
|
||||||
location="json",
|
|
||||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
|
||||||
nullable=True,
|
|
||||||
help="Invalid indexing technique.",
|
|
||||||
)
|
|
||||||
.add_argument(
|
|
||||||
"permission",
|
|
||||||
type=str,
|
|
||||||
location="json",
|
|
||||||
choices=(
|
|
||||||
DatasetPermissionEnum.ONLY_ME,
|
|
||||||
DatasetPermissionEnum.ALL_TEAM,
|
|
||||||
DatasetPermissionEnum.PARTIAL_TEAM,
|
|
||||||
),
|
|
||||||
help="Invalid permission.",
|
|
||||||
)
|
|
||||||
.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
|
|
||||||
.add_argument(
|
|
||||||
"embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
|
|
||||||
)
|
|
||||||
.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
|
|
||||||
.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
|
|
||||||
.add_argument(
|
|
||||||
"external_retrieval_model",
|
|
||||||
type=dict,
|
|
||||||
required=False,
|
|
||||||
nullable=True,
|
|
||||||
location="json",
|
|
||||||
help="Invalid external retrieval model.",
|
|
||||||
)
|
|
||||||
.add_argument(
|
|
||||||
"external_knowledge_id",
|
|
||||||
type=str,
|
|
||||||
required=False,
|
|
||||||
nullable=True,
|
|
||||||
location="json",
|
|
||||||
help="Invalid external knowledge id.",
|
|
||||||
)
|
|
||||||
.add_argument(
|
|
||||||
"external_knowledge_api_id",
|
|
||||||
type=str,
|
|
||||||
required=False,
|
|
||||||
nullable=True,
|
|
||||||
location="json",
|
|
||||||
help="Invalid external knowledge api id.",
|
|
||||||
)
|
|
||||||
.add_argument(
|
|
||||||
"icon_info",
|
|
||||||
type=dict,
|
|
||||||
required=False,
|
|
||||||
nullable=True,
|
|
||||||
location="json",
|
|
||||||
help="Invalid icon info.",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
data = request.get_json()
|
|
||||||
current_user, current_tenant_id = current_account_with_tenant()
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
# check embedding model setting
|
# check embedding model setting
|
||||||
if (
|
if (
|
||||||
data.get("indexing_technique") == "high_quality"
|
payload.indexing_technique == "high_quality"
|
||||||
and data.get("embedding_model_provider") is not None
|
and payload.embedding_model_provider is not None
|
||||||
and data.get("embedding_model") is not None
|
and payload.embedding_model is not None
|
||||||
):
|
):
|
||||||
DatasetService.check_embedding_model_setting(
|
is_multimodal = DatasetService.check_is_multimodal_model(
|
||||||
dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
|
dataset.tenant_id, payload.embedding_model_provider, payload.embedding_model
|
||||||
)
|
)
|
||||||
|
payload.is_multimodal = is_multimodal
|
||||||
|
payload_data = payload.model_dump(exclude_unset=True)
|
||||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||||
DatasetPermissionService.check_permission(
|
DatasetPermissionService.check_permission(
|
||||||
current_user, dataset, data.get("permission"), data.get("partial_member_list")
|
current_user, dataset, payload.permission, payload.partial_member_list
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset = DatasetService.update_dataset(dataset_id_str, args, current_user)
|
dataset = DatasetService.update_dataset(dataset_id_str, payload_data, current_user)
|
||||||
|
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
raise NotFound("Dataset not found.")
|
raise NotFound("Dataset not found.")
|
||||||
|
|
@ -518,15 +449,10 @@ class DatasetApi(Resource):
|
||||||
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||||
tenant_id = current_tenant_id
|
tenant_id = current_tenant_id
|
||||||
|
|
||||||
if data.get("partial_member_list") and data.get("permission") == "partial_members":
|
if payload.partial_member_list is not None and payload.permission == DatasetPermissionEnum.PARTIAL_TEAM:
|
||||||
DatasetPermissionService.update_partial_member_list(
|
DatasetPermissionService.update_partial_member_list(tenant_id, dataset_id_str, payload.partial_member_list)
|
||||||
tenant_id, dataset_id_str, data.get("partial_member_list")
|
|
||||||
)
|
|
||||||
# clear partial member list when permission is only_me or all_team_members
|
# clear partial member list when permission is only_me or all_team_members
|
||||||
elif (
|
elif payload.permission in {DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM}:
|
||||||
data.get("permission") == DatasetPermissionEnum.ONLY_ME
|
|
||||||
or data.get("permission") == DatasetPermissionEnum.ALL_TEAM
|
|
||||||
):
|
|
||||||
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
|
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
|
||||||
|
|
||||||
partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
||||||
|
|
@ -615,24 +541,10 @@ class DatasetIndexingEstimateApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@console_ns.expect(console_ns.models[IndexingEstimatePayload.__name__])
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = (
|
payload = IndexingEstimatePayload.model_validate(console_ns.payload or {})
|
||||||
reqparse.RequestParser()
|
args = payload.model_dump()
|
||||||
.add_argument("info_list", type=dict, required=True, nullable=True, location="json")
|
|
||||||
.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
|
|
||||||
.add_argument(
|
|
||||||
"indexing_technique",
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
|
||||||
nullable=True,
|
|
||||||
location="json",
|
|
||||||
)
|
|
||||||
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
|
||||||
.add_argument("dataset_id", type=str, required=False, nullable=False, location="json")
|
|
||||||
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
_, current_tenant_id = current_account_with_tenant()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
# validate args
|
# validate args
|
||||||
DocumentService.estimate_args_validate(args)
|
DocumentService.estimate_args_validate(args)
|
||||||
|
|
|
||||||
|
|
@ -6,31 +6,14 @@ from typing import Literal, cast
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
from flask_restx import Resource, fields, marshal, marshal_with
|
||||||
|
from pydantic import BaseModel
|
||||||
from sqlalchemy import asc, desc, select
|
from sqlalchemy import asc, desc, select
|
||||||
from werkzeug.exceptions import Forbidden, NotFound
|
from werkzeug.exceptions import Forbidden, NotFound
|
||||||
|
|
||||||
import services
|
import services
|
||||||
|
from controllers.common.schema import register_schema_models
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
from controllers.console.app.error import (
|
|
||||||
ProviderModelCurrentlyNotSupportError,
|
|
||||||
ProviderNotInitializeError,
|
|
||||||
ProviderQuotaExceededError,
|
|
||||||
)
|
|
||||||
from controllers.console.datasets.error import (
|
|
||||||
ArchivedDocumentImmutableError,
|
|
||||||
DocumentAlreadyFinishedError,
|
|
||||||
DocumentIndexingError,
|
|
||||||
IndexingEstimateError,
|
|
||||||
InvalidActionError,
|
|
||||||
InvalidMetadataError,
|
|
||||||
)
|
|
||||||
from controllers.console.wraps import (
|
|
||||||
account_initialization_required,
|
|
||||||
cloud_edition_billing_rate_limit_check,
|
|
||||||
cloud_edition_billing_resource_check,
|
|
||||||
setup_required,
|
|
||||||
)
|
|
||||||
from core.errors.error import (
|
from core.errors.error import (
|
||||||
LLMBadRequestError,
|
LLMBadRequestError,
|
||||||
ModelCurrentlyNotSupportError,
|
ModelCurrentlyNotSupportError,
|
||||||
|
|
@ -55,10 +38,30 @@ from fields.document_fields import (
|
||||||
)
|
)
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from libs.login import current_account_with_tenant, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
|
from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
|
||||||
from models.dataset import DocumentPipelineExecutionLog
|
from models.dataset import DocumentPipelineExecutionLog
|
||||||
from services.dataset_service import DatasetService, DocumentService
|
from services.dataset_service import DatasetService, DocumentService
|
||||||
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
|
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
|
||||||
|
|
||||||
|
from ..app.error import (
|
||||||
|
ProviderModelCurrentlyNotSupportError,
|
||||||
|
ProviderNotInitializeError,
|
||||||
|
ProviderQuotaExceededError,
|
||||||
|
)
|
||||||
|
from ..datasets.error import (
|
||||||
|
ArchivedDocumentImmutableError,
|
||||||
|
DocumentAlreadyFinishedError,
|
||||||
|
DocumentIndexingError,
|
||||||
|
IndexingEstimateError,
|
||||||
|
InvalidActionError,
|
||||||
|
InvalidMetadataError,
|
||||||
|
)
|
||||||
|
from ..wraps import (
|
||||||
|
account_initialization_required,
|
||||||
|
cloud_edition_billing_rate_limit_check,
|
||||||
|
cloud_edition_billing_resource_check,
|
||||||
|
setup_required,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -93,6 +96,24 @@ dataset_and_document_fields_copy["documents"] = fields.List(fields.Nested(docume
|
||||||
dataset_and_document_model = _get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy)
|
dataset_and_document_model = _get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy)
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentRetryPayload(BaseModel):
|
||||||
|
document_ids: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentRenamePayload(BaseModel):
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
register_schema_models(
|
||||||
|
console_ns,
|
||||||
|
KnowledgeConfig,
|
||||||
|
ProcessRule,
|
||||||
|
RetrievalModel,
|
||||||
|
DocumentRetryPayload,
|
||||||
|
DocumentRenamePayload,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class DocumentResource(Resource):
|
class DocumentResource(Resource):
|
||||||
def get_document(self, dataset_id: str, document_id: str) -> Document:
|
def get_document(self, dataset_id: str, document_id: str) -> Document:
|
||||||
current_user, current_tenant_id = current_account_with_tenant()
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
@ -201,8 +222,9 @@ class DatasetDocumentListApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, dataset_id: str):
|
def get(self, dataset_id):
|
||||||
current_user, current_tenant_id = current_account_with_tenant()
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
dataset_id = str(dataset_id)
|
||||||
page = request.args.get("page", default=1, type=int)
|
page = request.args.get("page", default=1, type=int)
|
||||||
limit = request.args.get("limit", default=20, type=int)
|
limit = request.args.get("limit", default=20, type=int)
|
||||||
search = request.args.get("keyword", default=None, type=str)
|
search = request.args.get("keyword", default=None, type=str)
|
||||||
|
|
@ -310,6 +332,7 @@ class DatasetDocumentListApi(Resource):
|
||||||
@marshal_with(dataset_and_document_model)
|
@marshal_with(dataset_and_document_model)
|
||||||
@cloud_edition_billing_resource_check("vector_space")
|
@cloud_edition_billing_resource_check("vector_space")
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
|
@console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
|
||||||
def post(self, dataset_id):
|
def post(self, dataset_id):
|
||||||
current_user, _ = current_account_with_tenant()
|
current_user, _ = current_account_with_tenant()
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
|
|
@ -328,23 +351,7 @@ class DatasetDocumentListApi(Resource):
|
||||||
except services.errors.account.NoPermissionError as e:
|
except services.errors.account.NoPermissionError as e:
|
||||||
raise Forbidden(str(e))
|
raise Forbidden(str(e))
|
||||||
|
|
||||||
parser = (
|
knowledge_config = KnowledgeConfig.model_validate(console_ns.payload or {})
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument(
|
|
||||||
"indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
|
|
||||||
)
|
|
||||||
.add_argument("data_source", type=dict, required=False, location="json")
|
|
||||||
.add_argument("process_rule", type=dict, required=False, location="json")
|
|
||||||
.add_argument("duplicate", type=bool, default=True, nullable=False, location="json")
|
|
||||||
.add_argument("original_document_id", type=str, required=False, location="json")
|
|
||||||
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
|
||||||
.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
|
|
||||||
.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
|
|
||||||
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
|
|
||||||
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
knowledge_config = KnowledgeConfig.model_validate(args)
|
|
||||||
|
|
||||||
if not dataset.indexing_technique and not knowledge_config.indexing_technique:
|
if not dataset.indexing_technique and not knowledge_config.indexing_technique:
|
||||||
raise ValueError("indexing_technique is required.")
|
raise ValueError("indexing_technique is required.")
|
||||||
|
|
@ -390,17 +397,7 @@ class DatasetDocumentListApi(Resource):
|
||||||
class DatasetInitApi(Resource):
|
class DatasetInitApi(Resource):
|
||||||
@console_ns.doc("init_dataset")
|
@console_ns.doc("init_dataset")
|
||||||
@console_ns.doc(description="Initialize dataset with documents")
|
@console_ns.doc(description="Initialize dataset with documents")
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
|
||||||
console_ns.model(
|
|
||||||
"DatasetInitRequest",
|
|
||||||
{
|
|
||||||
"upload_file_id": fields.String(required=True, description="Upload file ID"),
|
|
||||||
"indexing_technique": fields.String(description="Indexing technique"),
|
|
||||||
"process_rule": fields.Raw(description="Processing rules"),
|
|
||||||
"data_source": fields.Raw(description="Data source configuration"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(201, "Dataset initialized successfully", dataset_and_document_model)
|
@console_ns.response(201, "Dataset initialized successfully", dataset_and_document_model)
|
||||||
@console_ns.response(400, "Invalid request parameters")
|
@console_ns.response(400, "Invalid request parameters")
|
||||||
@setup_required
|
@setup_required
|
||||||
|
|
@ -415,27 +412,7 @@ class DatasetInitApi(Resource):
|
||||||
if not current_user.is_dataset_editor:
|
if not current_user.is_dataset_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = (
|
knowledge_config = KnowledgeConfig.model_validate(console_ns.payload or {})
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument(
|
|
||||||
"indexing_technique",
|
|
||||||
type=str,
|
|
||||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
|
||||||
required=True,
|
|
||||||
nullable=False,
|
|
||||||
location="json",
|
|
||||||
)
|
|
||||||
.add_argument("data_source", type=dict, required=True, nullable=True, location="json")
|
|
||||||
.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
|
|
||||||
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
|
||||||
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
|
|
||||||
.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
|
|
||||||
.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
|
|
||||||
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
knowledge_config = KnowledgeConfig.model_validate(args)
|
|
||||||
if knowledge_config.indexing_technique == "high_quality":
|
if knowledge_config.indexing_technique == "high_quality":
|
||||||
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
|
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
|
||||||
raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
|
raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
|
||||||
|
|
@ -443,10 +420,14 @@ class DatasetInitApi(Resource):
|
||||||
model_manager = ModelManager()
|
model_manager = ModelManager()
|
||||||
model_manager.get_model_instance(
|
model_manager.get_model_instance(
|
||||||
tenant_id=current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
provider=args["embedding_model_provider"],
|
provider=knowledge_config.embedding_model_provider,
|
||||||
model_type=ModelType.TEXT_EMBEDDING,
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
model=args["embedding_model"],
|
model=knowledge_config.embedding_model,
|
||||||
)
|
)
|
||||||
|
is_multimodal = DatasetService.check_is_multimodal_model(
|
||||||
|
current_tenant_id, knowledge_config.embedding_model_provider, knowledge_config.embedding_model
|
||||||
|
)
|
||||||
|
knowledge_config.is_multimodal = is_multimodal
|
||||||
except InvokeAuthorizationError:
|
except InvokeAuthorizationError:
|
||||||
raise ProviderNotInitializeError(
|
raise ProviderNotInitializeError(
|
||||||
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
|
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
|
||||||
|
|
@ -591,7 +572,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||||
datasource_type=DatasourceType.NOTION,
|
datasource_type=DatasourceType.NOTION,
|
||||||
notion_info=NotionInfo.model_validate(
|
notion_info=NotionInfo.model_validate(
|
||||||
{
|
{
|
||||||
"credential_id": data_source_info["credential_id"],
|
"credential_id": data_source_info.get("credential_id"),
|
||||||
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
||||||
"notion_obj_id": data_source_info["notion_page_id"],
|
"notion_obj_id": data_source_info["notion_page_id"],
|
||||||
"notion_page_type": data_source_info["type"],
|
"notion_page_type": data_source_info["type"],
|
||||||
|
|
@ -1076,19 +1057,16 @@ class DocumentRetryApi(DocumentResource):
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
|
@console_ns.expect(console_ns.models[DocumentRetryPayload.__name__])
|
||||||
def post(self, dataset_id):
|
def post(self, dataset_id):
|
||||||
"""retry document."""
|
"""retry document."""
|
||||||
|
payload = DocumentRetryPayload.model_validate(console_ns.payload or {})
|
||||||
parser = reqparse.RequestParser().add_argument(
|
|
||||||
"document_ids", type=list, required=True, nullable=False, location="json"
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
retry_documents = []
|
retry_documents = []
|
||||||
if not dataset:
|
if not dataset:
|
||||||
raise NotFound("Dataset not found.")
|
raise NotFound("Dataset not found.")
|
||||||
for document_id in args["document_ids"]:
|
for document_id in payload.document_ids:
|
||||||
try:
|
try:
|
||||||
document_id = str(document_id)
|
document_id = str(document_id)
|
||||||
|
|
||||||
|
|
@ -1121,6 +1099,7 @@ class DocumentRenameApi(DocumentResource):
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(document_fields)
|
@marshal_with(document_fields)
|
||||||
|
@console_ns.expect(console_ns.models[DocumentRenamePayload.__name__])
|
||||||
def post(self, dataset_id, document_id):
|
def post(self, dataset_id, document_id):
|
||||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||||
current_user, _ = current_account_with_tenant()
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
@ -1130,11 +1109,10 @@ class DocumentRenameApi(DocumentResource):
|
||||||
if not dataset:
|
if not dataset:
|
||||||
raise NotFound("Dataset not found.")
|
raise NotFound("Dataset not found.")
|
||||||
DatasetService.check_dataset_operator_permission(current_user, dataset)
|
DatasetService.check_dataset_operator_permission(current_user, dataset)
|
||||||
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, nullable=False, location="json")
|
payload = DocumentRenamePayload.model_validate(console_ns.payload or {})
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
document = DocumentService.rename_document(dataset_id, document_id, args["name"])
|
document = DocumentService.rename_document(dataset_id, document_id, payload.name)
|
||||||
except services.errors.document.DocumentIndexingError:
|
except services.errors.document.DocumentIndexingError:
|
||||||
raise DocumentIndexingError("Cannot delete document during indexing.")
|
raise DocumentIndexingError("Cannot delete document during indexing.")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,13 @@
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource, marshal, reqparse
|
from flask_restx import Resource, marshal
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from werkzeug.exceptions import Forbidden, NotFound
|
from werkzeug.exceptions import Forbidden, NotFound
|
||||||
|
|
||||||
import services
|
import services
|
||||||
|
from controllers.common.schema import register_schema_models
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
from controllers.console.app.error import ProviderNotInitializeError
|
from controllers.console.app.error import ProviderNotInitializeError
|
||||||
from controllers.console.datasets.error import (
|
from controllers.console.datasets.error import (
|
||||||
|
|
@ -36,6 +38,58 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
|
||||||
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
|
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
|
||||||
|
|
||||||
|
|
||||||
|
class SegmentListQuery(BaseModel):
|
||||||
|
limit: int = Field(default=20, ge=1, le=100)
|
||||||
|
status: list[str] = Field(default_factory=list)
|
||||||
|
hit_count_gte: int | None = None
|
||||||
|
enabled: str = Field(default="all")
|
||||||
|
keyword: str | None = None
|
||||||
|
page: int = Field(default=1, ge=1)
|
||||||
|
|
||||||
|
|
||||||
|
class SegmentCreatePayload(BaseModel):
|
||||||
|
content: str
|
||||||
|
answer: str | None = None
|
||||||
|
keywords: list[str] | None = None
|
||||||
|
attachment_ids: list[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class SegmentUpdatePayload(BaseModel):
|
||||||
|
content: str
|
||||||
|
answer: str | None = None
|
||||||
|
keywords: list[str] | None = None
|
||||||
|
regenerate_child_chunks: bool = False
|
||||||
|
attachment_ids: list[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class BatchImportPayload(BaseModel):
|
||||||
|
upload_file_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class ChildChunkCreatePayload(BaseModel):
|
||||||
|
content: str
|
||||||
|
|
||||||
|
|
||||||
|
class ChildChunkUpdatePayload(BaseModel):
|
||||||
|
content: str
|
||||||
|
|
||||||
|
|
||||||
|
class ChildChunkBatchUpdatePayload(BaseModel):
|
||||||
|
chunks: list[ChildChunkUpdateArgs]
|
||||||
|
|
||||||
|
|
||||||
|
register_schema_models(
|
||||||
|
console_ns,
|
||||||
|
SegmentListQuery,
|
||||||
|
SegmentCreatePayload,
|
||||||
|
SegmentUpdatePayload,
|
||||||
|
BatchImportPayload,
|
||||||
|
ChildChunkCreatePayload,
|
||||||
|
ChildChunkUpdatePayload,
|
||||||
|
ChildChunkBatchUpdatePayload,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
|
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
|
||||||
class DatasetDocumentSegmentListApi(Resource):
|
class DatasetDocumentSegmentListApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
|
|
@ -60,23 +114,18 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||||
if not document:
|
if not document:
|
||||||
raise NotFound("Document not found.")
|
raise NotFound("Document not found.")
|
||||||
|
|
||||||
parser = (
|
args = SegmentListQuery.model_validate(
|
||||||
reqparse.RequestParser()
|
{
|
||||||
.add_argument("limit", type=int, default=20, location="args")
|
**request.args.to_dict(),
|
||||||
.add_argument("status", type=str, action="append", default=[], location="args")
|
"status": request.args.getlist("status"),
|
||||||
.add_argument("hit_count_gte", type=int, default=None, location="args")
|
}
|
||||||
.add_argument("enabled", type=str, default="all", location="args")
|
|
||||||
.add_argument("keyword", type=str, default=None, location="args")
|
|
||||||
.add_argument("page", type=int, default=1, location="args")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
page = args.page
|
||||||
|
limit = min(args.limit, 100)
|
||||||
page = args["page"]
|
status_list = args.status
|
||||||
limit = min(args["limit"], 100)
|
hit_count_gte = args.hit_count_gte
|
||||||
status_list = args["status"]
|
keyword = args.keyword
|
||||||
hit_count_gte = args["hit_count_gte"]
|
|
||||||
keyword = args["keyword"]
|
|
||||||
|
|
||||||
query = (
|
query = (
|
||||||
select(DocumentSegment)
|
select(DocumentSegment)
|
||||||
|
|
@ -96,10 +145,10 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||||
if keyword:
|
if keyword:
|
||||||
query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
|
query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
|
||||||
|
|
||||||
if args["enabled"].lower() != "all":
|
if args.enabled.lower() != "all":
|
||||||
if args["enabled"].lower() == "true":
|
if args.enabled.lower() == "true":
|
||||||
query = query.where(DocumentSegment.enabled == True)
|
query = query.where(DocumentSegment.enabled == True)
|
||||||
elif args["enabled"].lower() == "false":
|
elif args.enabled.lower() == "false":
|
||||||
query = query.where(DocumentSegment.enabled == False)
|
query = query.where(DocumentSegment.enabled == False)
|
||||||
|
|
||||||
segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
|
segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
|
||||||
|
|
@ -210,6 +259,7 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||||
@cloud_edition_billing_resource_check("vector_space")
|
@cloud_edition_billing_resource_check("vector_space")
|
||||||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
|
@console_ns.expect(console_ns.models[SegmentCreatePayload.__name__])
|
||||||
def post(self, dataset_id, document_id):
|
def post(self, dataset_id, document_id):
|
||||||
current_user, current_tenant_id = current_account_with_tenant()
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
|
|
@ -246,15 +296,10 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||||
except services.errors.account.NoPermissionError as e:
|
except services.errors.account.NoPermissionError as e:
|
||||||
raise Forbidden(str(e))
|
raise Forbidden(str(e))
|
||||||
# validate args
|
# validate args
|
||||||
parser = (
|
payload = SegmentCreatePayload.model_validate(console_ns.payload or {})
|
||||||
reqparse.RequestParser()
|
payload_dict = payload.model_dump(exclude_none=True)
|
||||||
.add_argument("content", type=str, required=True, nullable=False, location="json")
|
SegmentService.segment_create_args_validate(payload_dict, document)
|
||||||
.add_argument("answer", type=str, required=False, nullable=True, location="json")
|
segment = SegmentService.create_segment(payload_dict, document, dataset)
|
||||||
.add_argument("keywords", type=list, required=False, nullable=True, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
SegmentService.segment_create_args_validate(args, document)
|
|
||||||
segment = SegmentService.create_segment(args, document, dataset)
|
|
||||||
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
|
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -265,6 +310,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_resource_check("vector_space")
|
@cloud_edition_billing_resource_check("vector_space")
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
|
@console_ns.expect(console_ns.models[SegmentUpdatePayload.__name__])
|
||||||
def patch(self, dataset_id, document_id, segment_id):
|
def patch(self, dataset_id, document_id, segment_id):
|
||||||
current_user, current_tenant_id = current_account_with_tenant()
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
|
|
@ -313,18 +359,12 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||||
except services.errors.account.NoPermissionError as e:
|
except services.errors.account.NoPermissionError as e:
|
||||||
raise Forbidden(str(e))
|
raise Forbidden(str(e))
|
||||||
# validate args
|
# validate args
|
||||||
parser = (
|
payload = SegmentUpdatePayload.model_validate(console_ns.payload or {})
|
||||||
reqparse.RequestParser()
|
payload_dict = payload.model_dump(exclude_none=True)
|
||||||
.add_argument("content", type=str, required=True, nullable=False, location="json")
|
SegmentService.segment_create_args_validate(payload_dict, document)
|
||||||
.add_argument("answer", type=str, required=False, nullable=True, location="json")
|
segment = SegmentService.update_segment(
|
||||||
.add_argument("keywords", type=list, required=False, nullable=True, location="json")
|
SegmentUpdateArgs.model_validate(payload.model_dump(exclude_none=True)), segment, document, dataset
|
||||||
.add_argument(
|
|
||||||
"regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json"
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
|
||||||
SegmentService.segment_create_args_validate(args, document)
|
|
||||||
segment = SegmentService.update_segment(SegmentUpdateArgs.model_validate(args), segment, document, dataset)
|
|
||||||
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
|
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
|
|
@ -377,6 +417,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||||
@cloud_edition_billing_resource_check("vector_space")
|
@cloud_edition_billing_resource_check("vector_space")
|
||||||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
|
@console_ns.expect(console_ns.models[BatchImportPayload.__name__])
|
||||||
def post(self, dataset_id, document_id):
|
def post(self, dataset_id, document_id):
|
||||||
current_user, current_tenant_id = current_account_with_tenant()
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
|
|
@ -391,11 +432,8 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||||
if not document:
|
if not document:
|
||||||
raise NotFound("Document not found.")
|
raise NotFound("Document not found.")
|
||||||
|
|
||||||
parser = reqparse.RequestParser().add_argument(
|
payload = BatchImportPayload.model_validate(console_ns.payload or {})
|
||||||
"upload_file_id", type=str, required=True, nullable=False, location="json"
|
upload_file_id = payload.upload_file_id
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
upload_file_id = args["upload_file_id"]
|
|
||||||
|
|
||||||
upload_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
|
upload_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
|
||||||
if not upload_file:
|
if not upload_file:
|
||||||
|
|
@ -446,6 +484,7 @@ class ChildChunkAddApi(Resource):
|
||||||
@cloud_edition_billing_resource_check("vector_space")
|
@cloud_edition_billing_resource_check("vector_space")
|
||||||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
|
@console_ns.expect(console_ns.models[ChildChunkCreatePayload.__name__])
|
||||||
def post(self, dataset_id, document_id, segment_id):
|
def post(self, dataset_id, document_id, segment_id):
|
||||||
current_user, current_tenant_id = current_account_with_tenant()
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
|
|
@ -491,13 +530,9 @@ class ChildChunkAddApi(Resource):
|
||||||
except services.errors.account.NoPermissionError as e:
|
except services.errors.account.NoPermissionError as e:
|
||||||
raise Forbidden(str(e))
|
raise Forbidden(str(e))
|
||||||
# validate args
|
# validate args
|
||||||
parser = reqparse.RequestParser().add_argument(
|
|
||||||
"content", type=str, required=True, nullable=False, location="json"
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
try:
|
try:
|
||||||
content = args["content"]
|
payload = ChildChunkCreatePayload.model_validate(console_ns.payload or {})
|
||||||
child_chunk = SegmentService.create_child_chunk(content, segment, document, dataset)
|
child_chunk = SegmentService.create_child_chunk(payload.content, segment, document, dataset)
|
||||||
except ChildChunkIndexingServiceError as e:
|
except ChildChunkIndexingServiceError as e:
|
||||||
raise ChildChunkIndexingError(str(e))
|
raise ChildChunkIndexingError(str(e))
|
||||||
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
|
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
|
||||||
|
|
@ -529,18 +564,17 @@ class ChildChunkAddApi(Resource):
|
||||||
)
|
)
|
||||||
if not segment:
|
if not segment:
|
||||||
raise NotFound("Segment not found.")
|
raise NotFound("Segment not found.")
|
||||||
parser = (
|
args = SegmentListQuery.model_validate(
|
||||||
reqparse.RequestParser()
|
{
|
||||||
.add_argument("limit", type=int, default=20, location="args")
|
"limit": request.args.get("limit", default=20, type=int),
|
||||||
.add_argument("keyword", type=str, default=None, location="args")
|
"keyword": request.args.get("keyword"),
|
||||||
.add_argument("page", type=int, default=1, location="args")
|
"page": request.args.get("page", default=1, type=int),
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
page = args.page
|
||||||
|
limit = min(args.limit, 100)
|
||||||
page = args["page"]
|
keyword = args.keyword
|
||||||
limit = min(args["limit"], 100)
|
|
||||||
keyword = args["keyword"]
|
|
||||||
|
|
||||||
child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword)
|
child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword)
|
||||||
return {
|
return {
|
||||||
|
|
@ -588,14 +622,9 @@ class ChildChunkAddApi(Resource):
|
||||||
except services.errors.account.NoPermissionError as e:
|
except services.errors.account.NoPermissionError as e:
|
||||||
raise Forbidden(str(e))
|
raise Forbidden(str(e))
|
||||||
# validate args
|
# validate args
|
||||||
parser = reqparse.RequestParser().add_argument(
|
payload = ChildChunkBatchUpdatePayload.model_validate(console_ns.payload or {})
|
||||||
"chunks", type=list, required=True, nullable=False, location="json"
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
try:
|
try:
|
||||||
chunks_data = args["chunks"]
|
child_chunks = SegmentService.update_child_chunks(payload.chunks, segment, document, dataset)
|
||||||
chunks = [ChildChunkUpdateArgs.model_validate(chunk) for chunk in chunks_data]
|
|
||||||
child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset)
|
|
||||||
except ChildChunkIndexingServiceError as e:
|
except ChildChunkIndexingServiceError as e:
|
||||||
raise ChildChunkIndexingError(str(e))
|
raise ChildChunkIndexingError(str(e))
|
||||||
return {"data": marshal(child_chunks, child_chunk_fields)}, 200
|
return {"data": marshal(child_chunks, child_chunk_fields)}, 200
|
||||||
|
|
@ -665,6 +694,7 @@ class ChildChunkUpdateApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_resource_check("vector_space")
|
@cloud_edition_billing_resource_check("vector_space")
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
|
@console_ns.expect(console_ns.models[ChildChunkUpdatePayload.__name__])
|
||||||
def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
|
def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
|
||||||
current_user, current_tenant_id = current_account_with_tenant()
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
|
|
@ -711,13 +741,9 @@ class ChildChunkUpdateApi(Resource):
|
||||||
except services.errors.account.NoPermissionError as e:
|
except services.errors.account.NoPermissionError as e:
|
||||||
raise Forbidden(str(e))
|
raise Forbidden(str(e))
|
||||||
# validate args
|
# validate args
|
||||||
parser = reqparse.RequestParser().add_argument(
|
|
||||||
"content", type=str, required=True, nullable=False, location="json"
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
try:
|
try:
|
||||||
content = args["content"]
|
payload = ChildChunkUpdatePayload.model_validate(console_ns.payload or {})
|
||||||
child_chunk = SegmentService.update_child_chunk(content, child_chunk, segment, document, dataset)
|
child_chunk = SegmentService.update_child_chunk(payload.content, child_chunk, segment, document, dataset)
|
||||||
except ChildChunkIndexingServiceError as e:
|
except ChildChunkIndexingServiceError as e:
|
||||||
raise ChildChunkIndexingError(str(e))
|
raise ChildChunkIndexingError(str(e))
|
||||||
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
|
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,10 @@
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource, fields, marshal, reqparse
|
from flask_restx import Resource, fields, marshal
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||||
|
|
||||||
import services
|
import services
|
||||||
|
from controllers.common.schema import register_schema_models
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||||
|
|
@ -71,10 +73,38 @@ except KeyError:
|
||||||
dataset_detail_model = _build_dataset_detail_model()
|
dataset_detail_model = _build_dataset_detail_model()
|
||||||
|
|
||||||
|
|
||||||
def _validate_name(name: str) -> str:
|
class ExternalKnowledgeApiPayload(BaseModel):
|
||||||
if not name or len(name) < 1 or len(name) > 100:
|
name: str = Field(..., min_length=1, max_length=40)
|
||||||
raise ValueError("Name must be between 1 to 100 characters.")
|
settings: dict[str, object]
|
||||||
return name
|
|
||||||
|
|
||||||
|
class ExternalDatasetCreatePayload(BaseModel):
|
||||||
|
external_knowledge_api_id: str
|
||||||
|
external_knowledge_id: str
|
||||||
|
name: str = Field(..., min_length=1, max_length=40)
|
||||||
|
description: str | None = Field(None, max_length=400)
|
||||||
|
external_retrieval_model: dict[str, object] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ExternalHitTestingPayload(BaseModel):
|
||||||
|
query: str
|
||||||
|
external_retrieval_model: dict[str, object] | None = None
|
||||||
|
metadata_filtering_conditions: dict[str, object] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class BedrockRetrievalPayload(BaseModel):
|
||||||
|
retrieval_setting: dict[str, object]
|
||||||
|
query: str
|
||||||
|
knowledge_id: str
|
||||||
|
|
||||||
|
|
||||||
|
register_schema_models(
|
||||||
|
console_ns,
|
||||||
|
ExternalKnowledgeApiPayload,
|
||||||
|
ExternalDatasetCreatePayload,
|
||||||
|
ExternalHitTestingPayload,
|
||||||
|
BedrockRetrievalPayload,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/datasets/external-knowledge-api")
|
@console_ns.route("/datasets/external-knowledge-api")
|
||||||
|
|
@ -113,28 +143,12 @@ class ExternalApiTemplateListApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@console_ns.expect(console_ns.models[ExternalKnowledgeApiPayload.__name__])
|
||||||
def post(self):
|
def post(self):
|
||||||
current_user, current_tenant_id = current_account_with_tenant()
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
parser = (
|
payload = ExternalKnowledgeApiPayload.model_validate(console_ns.payload or {})
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument(
|
|
||||||
"name",
|
|
||||||
nullable=False,
|
|
||||||
required=True,
|
|
||||||
help="Name is required. Name must be between 1 to 100 characters.",
|
|
||||||
type=_validate_name,
|
|
||||||
)
|
|
||||||
.add_argument(
|
|
||||||
"settings",
|
|
||||||
type=dict,
|
|
||||||
location="json",
|
|
||||||
nullable=False,
|
|
||||||
required=True,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
ExternalDatasetService.validate_api_list(args["settings"])
|
ExternalDatasetService.validate_api_list(payload.settings)
|
||||||
|
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||||
if not current_user.is_dataset_editor:
|
if not current_user.is_dataset_editor:
|
||||||
|
|
@ -142,7 +156,7 @@ class ExternalApiTemplateListApi(Resource):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
external_knowledge_api = ExternalDatasetService.create_external_knowledge_api(
|
external_knowledge_api = ExternalDatasetService.create_external_knowledge_api(
|
||||||
tenant_id=current_tenant_id, user_id=current_user.id, args=args
|
tenant_id=current_tenant_id, user_id=current_user.id, args=payload.model_dump()
|
||||||
)
|
)
|
||||||
except services.errors.dataset.DatasetNameDuplicateError:
|
except services.errors.dataset.DatasetNameDuplicateError:
|
||||||
raise DatasetNameDuplicateError()
|
raise DatasetNameDuplicateError()
|
||||||
|
|
@ -171,35 +185,19 @@ class ExternalApiTemplateApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@console_ns.expect(console_ns.models[ExternalKnowledgeApiPayload.__name__])
|
||||||
def patch(self, external_knowledge_api_id):
|
def patch(self, external_knowledge_api_id):
|
||||||
current_user, current_tenant_id = current_account_with_tenant()
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
external_knowledge_api_id = str(external_knowledge_api_id)
|
external_knowledge_api_id = str(external_knowledge_api_id)
|
||||||
|
|
||||||
parser = (
|
payload = ExternalKnowledgeApiPayload.model_validate(console_ns.payload or {})
|
||||||
reqparse.RequestParser()
|
ExternalDatasetService.validate_api_list(payload.settings)
|
||||||
.add_argument(
|
|
||||||
"name",
|
|
||||||
nullable=False,
|
|
||||||
required=True,
|
|
||||||
help="type is required. Name must be between 1 to 100 characters.",
|
|
||||||
type=_validate_name,
|
|
||||||
)
|
|
||||||
.add_argument(
|
|
||||||
"settings",
|
|
||||||
type=dict,
|
|
||||||
location="json",
|
|
||||||
nullable=False,
|
|
||||||
required=True,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
ExternalDatasetService.validate_api_list(args["settings"])
|
|
||||||
|
|
||||||
external_knowledge_api = ExternalDatasetService.update_external_knowledge_api(
|
external_knowledge_api = ExternalDatasetService.update_external_knowledge_api(
|
||||||
tenant_id=current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
external_knowledge_api_id=external_knowledge_api_id,
|
external_knowledge_api_id=external_knowledge_api_id,
|
||||||
args=args,
|
args=payload.model_dump(),
|
||||||
)
|
)
|
||||||
|
|
||||||
return external_knowledge_api.to_dict(), 200
|
return external_knowledge_api.to_dict(), 200
|
||||||
|
|
@ -240,17 +238,7 @@ class ExternalApiUseCheckApi(Resource):
|
||||||
class ExternalDatasetCreateApi(Resource):
|
class ExternalDatasetCreateApi(Resource):
|
||||||
@console_ns.doc("create_external_dataset")
|
@console_ns.doc("create_external_dataset")
|
||||||
@console_ns.doc(description="Create external knowledge dataset")
|
@console_ns.doc(description="Create external knowledge dataset")
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[ExternalDatasetCreatePayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"CreateExternalDatasetRequest",
|
|
||||||
{
|
|
||||||
"external_knowledge_api_id": fields.String(required=True, description="External knowledge API ID"),
|
|
||||||
"external_knowledge_id": fields.String(required=True, description="External knowledge ID"),
|
|
||||||
"name": fields.String(required=True, description="Dataset name"),
|
|
||||||
"description": fields.String(description="Dataset description"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(201, "External dataset created successfully", dataset_detail_model)
|
@console_ns.response(201, "External dataset created successfully", dataset_detail_model)
|
||||||
@console_ns.response(400, "Invalid parameters")
|
@console_ns.response(400, "Invalid parameters")
|
||||||
@console_ns.response(403, "Permission denied")
|
@console_ns.response(403, "Permission denied")
|
||||||
|
|
@ -261,22 +249,8 @@ class ExternalDatasetCreateApi(Resource):
|
||||||
def post(self):
|
def post(self):
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
current_user, current_tenant_id = current_account_with_tenant()
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
parser = (
|
payload = ExternalDatasetCreatePayload.model_validate(console_ns.payload or {})
|
||||||
reqparse.RequestParser()
|
args = payload.model_dump(exclude_none=True)
|
||||||
.add_argument("external_knowledge_api_id", type=str, required=True, nullable=False, location="json")
|
|
||||||
.add_argument("external_knowledge_id", type=str, required=True, nullable=False, location="json")
|
|
||||||
.add_argument(
|
|
||||||
"name",
|
|
||||||
nullable=False,
|
|
||||||
required=True,
|
|
||||||
help="name is required. Name must be between 1 to 100 characters.",
|
|
||||||
type=_validate_name,
|
|
||||||
)
|
|
||||||
.add_argument("description", type=str, required=False, nullable=True, location="json")
|
|
||||||
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||||
if not current_user.is_dataset_editor:
|
if not current_user.is_dataset_editor:
|
||||||
|
|
@ -299,16 +273,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
|
||||||
@console_ns.doc("test_external_knowledge_retrieval")
|
@console_ns.doc("test_external_knowledge_retrieval")
|
||||||
@console_ns.doc(description="Test external knowledge retrieval for dataset")
|
@console_ns.doc(description="Test external knowledge retrieval for dataset")
|
||||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[ExternalHitTestingPayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"ExternalHitTestingRequest",
|
|
||||||
{
|
|
||||||
"query": fields.String(required=True, description="Query text for testing"),
|
|
||||||
"retrieval_model": fields.Raw(description="Retrieval model configuration"),
|
|
||||||
"external_retrieval_model": fields.Raw(description="External retrieval model configuration"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "External hit testing completed successfully")
|
@console_ns.response(200, "External hit testing completed successfully")
|
||||||
@console_ns.response(404, "Dataset not found")
|
@console_ns.response(404, "Dataset not found")
|
||||||
@console_ns.response(400, "Invalid parameters")
|
@console_ns.response(400, "Invalid parameters")
|
||||||
|
|
@ -327,23 +292,16 @@ class ExternalKnowledgeHitTestingApi(Resource):
|
||||||
except services.errors.account.NoPermissionError as e:
|
except services.errors.account.NoPermissionError as e:
|
||||||
raise Forbidden(str(e))
|
raise Forbidden(str(e))
|
||||||
|
|
||||||
parser = (
|
payload = ExternalHitTestingPayload.model_validate(console_ns.payload or {})
|
||||||
reqparse.RequestParser()
|
HitTestingService.hit_testing_args_check(payload.model_dump())
|
||||||
.add_argument("query", type=str, location="json")
|
|
||||||
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
|
||||||
.add_argument("metadata_filtering_conditions", type=dict, required=False, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
HitTestingService.hit_testing_args_check(args)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = HitTestingService.external_retrieve(
|
response = HitTestingService.external_retrieve(
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
query=args["query"],
|
query=payload.query,
|
||||||
account=current_user,
|
account=current_user,
|
||||||
external_retrieval_model=args["external_retrieval_model"],
|
external_retrieval_model=payload.external_retrieval_model,
|
||||||
metadata_filtering_conditions=args["metadata_filtering_conditions"],
|
metadata_filtering_conditions=payload.metadata_filtering_conditions,
|
||||||
)
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
@ -356,33 +314,13 @@ class BedrockRetrievalApi(Resource):
|
||||||
# this api is only for internal testing
|
# this api is only for internal testing
|
||||||
@console_ns.doc("bedrock_retrieval_test")
|
@console_ns.doc("bedrock_retrieval_test")
|
||||||
@console_ns.doc(description="Bedrock retrieval test (internal use only)")
|
@console_ns.doc(description="Bedrock retrieval test (internal use only)")
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[BedrockRetrievalPayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"BedrockRetrievalTestRequest",
|
|
||||||
{
|
|
||||||
"retrieval_setting": fields.Raw(required=True, description="Retrieval settings"),
|
|
||||||
"query": fields.String(required=True, description="Query text"),
|
|
||||||
"knowledge_id": fields.String(required=True, description="Knowledge ID"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Bedrock retrieval test completed")
|
@console_ns.response(200, "Bedrock retrieval test completed")
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = (
|
payload = BedrockRetrievalPayload.model_validate(console_ns.payload or {})
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json")
|
|
||||||
.add_argument(
|
|
||||||
"query",
|
|
||||||
nullable=False,
|
|
||||||
required=True,
|
|
||||||
type=str,
|
|
||||||
)
|
|
||||||
.add_argument("knowledge_id", nullable=False, required=True, type=str)
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# Call the knowledge retrieval service
|
# Call the knowledge retrieval service
|
||||||
result = ExternalDatasetTestService.knowledge_retrieval(
|
result = ExternalDatasetTestService.knowledge_retrieval(
|
||||||
args["retrieval_setting"], args["query"], args["knowledge_id"]
|
payload.retrieval_setting, payload.query, payload.knowledge_id
|
||||||
)
|
)
|
||||||
return result, 200
|
return result, 200
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,17 @@
|
||||||
from flask_restx import Resource, fields
|
from flask_restx import Resource
|
||||||
|
|
||||||
from controllers.console import console_ns
|
from controllers.common.schema import register_schema_model
|
||||||
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
|
from libs.login import login_required
|
||||||
from controllers.console.wraps import (
|
|
||||||
|
from .. import console_ns
|
||||||
|
from ..datasets.hit_testing_base import DatasetsHitTestingBase, HitTestingPayload
|
||||||
|
from ..wraps import (
|
||||||
account_initialization_required,
|
account_initialization_required,
|
||||||
cloud_edition_billing_rate_limit_check,
|
cloud_edition_billing_rate_limit_check,
|
||||||
setup_required,
|
setup_required,
|
||||||
)
|
)
|
||||||
from libs.login import login_required
|
|
||||||
|
register_schema_model(console_ns, HitTestingPayload)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/datasets/<uuid:dataset_id>/hit-testing")
|
@console_ns.route("/datasets/<uuid:dataset_id>/hit-testing")
|
||||||
|
|
@ -15,17 +19,7 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
|
||||||
@console_ns.doc("test_dataset_retrieval")
|
@console_ns.doc("test_dataset_retrieval")
|
||||||
@console_ns.doc(description="Test dataset knowledge retrieval")
|
@console_ns.doc(description="Test dataset knowledge retrieval")
|
||||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[HitTestingPayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"HitTestingRequest",
|
|
||||||
{
|
|
||||||
"query": fields.String(required=True, description="Query text for testing"),
|
|
||||||
"retrieval_model": fields.Raw(description="Retrieval model configuration"),
|
|
||||||
"top_k": fields.Integer(description="Number of top results to return"),
|
|
||||||
"score_threshold": fields.Float(description="Score threshold for filtering results"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Hit testing completed successfully")
|
@console_ns.response(200, "Hit testing completed successfully")
|
||||||
@console_ns.response(404, "Dataset not found")
|
@console_ns.response(404, "Dataset not found")
|
||||||
@console_ns.response(400, "Invalid parameters")
|
@console_ns.response(400, "Invalid parameters")
|
||||||
|
|
@ -37,7 +31,8 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
|
||||||
dataset_id_str = str(dataset_id)
|
dataset_id_str = str(dataset_id)
|
||||||
|
|
||||||
dataset = self.get_and_validate_dataset(dataset_id_str)
|
dataset = self.get_and_validate_dataset(dataset_id_str)
|
||||||
args = self.parse_args()
|
payload = HitTestingPayload.model_validate(console_ns.payload or {})
|
||||||
|
args = payload.model_dump(exclude_none=True)
|
||||||
self.hit_testing_args_check(args)
|
self.hit_testing_args_check(args)
|
||||||
|
|
||||||
return self.perform_hit_testing(dataset, args)
|
return self.perform_hit_testing(dataset, args)
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,8 @@
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from flask_restx import marshal, reqparse
|
from flask_restx import marshal, reqparse
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||||
|
|
||||||
import services
|
import services
|
||||||
|
|
@ -27,6 +29,13 @@ from services.hit_testing_service import HitTestingService
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class HitTestingPayload(BaseModel):
|
||||||
|
query: str = Field(max_length=250)
|
||||||
|
retrieval_model: dict[str, Any] | None = None
|
||||||
|
external_retrieval_model: dict[str, Any] | None = None
|
||||||
|
attachment_ids: list[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
class DatasetsHitTestingBase:
|
class DatasetsHitTestingBase:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_and_validate_dataset(dataset_id: str):
|
def get_and_validate_dataset(dataset_id: str):
|
||||||
|
|
@ -43,14 +52,15 @@ class DatasetsHitTestingBase:
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def hit_testing_args_check(args):
|
def hit_testing_args_check(args: dict[str, Any]):
|
||||||
HitTestingService.hit_testing_args_check(args)
|
HitTestingService.hit_testing_args_check(args)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse_args():
|
def parse_args():
|
||||||
parser = (
|
parser = (
|
||||||
reqparse.RequestParser()
|
reqparse.RequestParser()
|
||||||
.add_argument("query", type=str, location="json")
|
.add_argument("query", type=str, required=False, location="json")
|
||||||
|
.add_argument("attachment_ids", type=list, required=False, location="json")
|
||||||
.add_argument("retrieval_model", type=dict, required=False, location="json")
|
.add_argument("retrieval_model", type=dict, required=False, location="json")
|
||||||
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||||
)
|
)
|
||||||
|
|
@ -62,10 +72,11 @@ class DatasetsHitTestingBase:
|
||||||
try:
|
try:
|
||||||
response = HitTestingService.retrieve(
|
response = HitTestingService.retrieve(
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
query=args["query"],
|
query=args.get("query"),
|
||||||
account=current_user,
|
account=current_user,
|
||||||
retrieval_model=args["retrieval_model"],
|
retrieval_model=args.get("retrieval_model"),
|
||||||
external_retrieval_model=args["external_retrieval_model"],
|
external_retrieval_model=args.get("external_retrieval_model"),
|
||||||
|
attachment_ids=args.get("attachment_ids"),
|
||||||
limit=10,
|
limit=10,
|
||||||
)
|
)
|
||||||
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
|
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,10 @@
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from flask_restx import Resource, marshal_with, reqparse
|
from flask_restx import Resource, marshal_with
|
||||||
|
from pydantic import BaseModel
|
||||||
from werkzeug.exceptions import NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
|
from controllers.common.schema import register_schema_model, register_schema_models
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
|
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
|
||||||
from fields.dataset_fields import dataset_metadata_fields
|
from fields.dataset_fields import dataset_metadata_fields
|
||||||
|
|
@ -15,6 +17,14 @@ from services.entities.knowledge_entities.knowledge_entities import (
|
||||||
from services.metadata_service import MetadataService
|
from services.metadata_service import MetadataService
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataUpdatePayload(BaseModel):
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
register_schema_models(console_ns, MetadataArgs, MetadataOperationData)
|
||||||
|
register_schema_model(console_ns, MetadataUpdatePayload)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/datasets/<uuid:dataset_id>/metadata")
|
@console_ns.route("/datasets/<uuid:dataset_id>/metadata")
|
||||||
class DatasetMetadataCreateApi(Resource):
|
class DatasetMetadataCreateApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
|
|
@ -22,15 +32,10 @@ class DatasetMetadataCreateApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@enterprise_license_required
|
@enterprise_license_required
|
||||||
@marshal_with(dataset_metadata_fields)
|
@marshal_with(dataset_metadata_fields)
|
||||||
|
@console_ns.expect(console_ns.models[MetadataArgs.__name__])
|
||||||
def post(self, dataset_id):
|
def post(self, dataset_id):
|
||||||
current_user, _ = current_account_with_tenant()
|
current_user, _ = current_account_with_tenant()
|
||||||
parser = (
|
metadata_args = MetadataArgs.model_validate(console_ns.payload or {})
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("type", type=str, required=True, nullable=False, location="json")
|
|
||||||
.add_argument("name", type=str, required=True, nullable=False, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
metadata_args = MetadataArgs.model_validate(args)
|
|
||||||
|
|
||||||
dataset_id_str = str(dataset_id)
|
dataset_id_str = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||||
|
|
@ -60,11 +65,11 @@ class DatasetMetadataApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@enterprise_license_required
|
@enterprise_license_required
|
||||||
@marshal_with(dataset_metadata_fields)
|
@marshal_with(dataset_metadata_fields)
|
||||||
|
@console_ns.expect(console_ns.models[MetadataUpdatePayload.__name__])
|
||||||
def patch(self, dataset_id, metadata_id):
|
def patch(self, dataset_id, metadata_id):
|
||||||
current_user, _ = current_account_with_tenant()
|
current_user, _ = current_account_with_tenant()
|
||||||
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, nullable=False, location="json")
|
payload = MetadataUpdatePayload.model_validate(console_ns.payload or {})
|
||||||
args = parser.parse_args()
|
name = payload.name
|
||||||
name = args["name"]
|
|
||||||
|
|
||||||
dataset_id_str = str(dataset_id)
|
dataset_id_str = str(dataset_id)
|
||||||
metadata_id_str = str(metadata_id)
|
metadata_id_str = str(metadata_id)
|
||||||
|
|
@ -131,6 +136,7 @@ class DocumentMetadataEditApi(Resource):
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@enterprise_license_required
|
@enterprise_license_required
|
||||||
|
@console_ns.expect(console_ns.models[MetadataOperationData.__name__])
|
||||||
def post(self, dataset_id):
|
def post(self, dataset_id):
|
||||||
current_user, _ = current_account_with_tenant()
|
current_user, _ = current_account_with_tenant()
|
||||||
dataset_id_str = str(dataset_id)
|
dataset_id_str = str(dataset_id)
|
||||||
|
|
@ -139,11 +145,7 @@ class DocumentMetadataEditApi(Resource):
|
||||||
raise NotFound("Dataset not found.")
|
raise NotFound("Dataset not found.")
|
||||||
DatasetService.check_dataset_permission(dataset, current_user)
|
DatasetService.check_dataset_permission(dataset, current_user)
|
||||||
|
|
||||||
parser = reqparse.RequestParser().add_argument(
|
metadata_args = MetadataOperationData.model_validate(console_ns.payload or {})
|
||||||
"operation_data", type=list, required=True, nullable=False, location="json"
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
metadata_args = MetadataOperationData.model_validate(args)
|
|
||||||
|
|
||||||
MetadataService.update_documents_metadata(dataset, metadata_args)
|
MetadataService.update_documents_metadata(dataset, metadata_args)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,20 +1,63 @@
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from flask import make_response, redirect, request
|
from flask import make_response, redirect, request
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
from werkzeug.exceptions import Forbidden, NotFound
|
from werkzeug.exceptions import Forbidden, NotFound
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
|
from controllers.common.schema import register_schema_models
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.plugin.impl.oauth import OAuthHandler
|
from core.plugin.impl.oauth import OAuthHandler
|
||||||
from libs.helper import StrLen
|
|
||||||
from libs.login import current_account_with_tenant, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.provider_ids import DatasourceProviderID
|
from models.provider_ids import DatasourceProviderID
|
||||||
from services.datasource_provider_service import DatasourceProviderService
|
from services.datasource_provider_service import DatasourceProviderService
|
||||||
from services.plugin.oauth_service import OAuthProxyService
|
from services.plugin.oauth_service import OAuthProxyService
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceCredentialPayload(BaseModel):
|
||||||
|
name: str | None = Field(default=None, max_length=100)
|
||||||
|
credentials: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceCredentialDeletePayload(BaseModel):
|
||||||
|
credential_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceCredentialUpdatePayload(BaseModel):
|
||||||
|
credential_id: str
|
||||||
|
name: str | None = Field(default=None, max_length=100)
|
||||||
|
credentials: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceCustomClientPayload(BaseModel):
|
||||||
|
client_params: dict[str, Any] | None = None
|
||||||
|
enable_oauth_custom_client: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceDefaultPayload(BaseModel):
|
||||||
|
id: str
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceUpdateNamePayload(BaseModel):
|
||||||
|
credential_id: str
|
||||||
|
name: str = Field(max_length=100)
|
||||||
|
|
||||||
|
|
||||||
|
register_schema_models(
|
||||||
|
console_ns,
|
||||||
|
DatasourceCredentialPayload,
|
||||||
|
DatasourceCredentialDeletePayload,
|
||||||
|
DatasourceCredentialUpdatePayload,
|
||||||
|
DatasourceCustomClientPayload,
|
||||||
|
DatasourceDefaultPayload,
|
||||||
|
DatasourceUpdateNamePayload,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/oauth/plugin/<path:provider_id>/datasource/get-authorization-url")
|
@console_ns.route("/oauth/plugin/<path:provider_id>/datasource/get-authorization-url")
|
||||||
class DatasourcePluginOAuthAuthorizationUrl(Resource):
|
class DatasourcePluginOAuthAuthorizationUrl(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
|
|
@ -121,16 +164,9 @@ class DatasourceOAuthCallback(Resource):
|
||||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||||
|
|
||||||
|
|
||||||
parser_datasource = (
|
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None)
|
|
||||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>")
|
@console_ns.route("/auth/plugin/datasource/<path:provider_id>")
|
||||||
class DatasourceAuth(Resource):
|
class DatasourceAuth(Resource):
|
||||||
@console_ns.expect(parser_datasource)
|
@console_ns.expect(console_ns.models[DatasourceCredentialPayload.__name__])
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
|
@ -138,7 +174,7 @@ class DatasourceAuth(Resource):
|
||||||
def post(self, provider_id: str):
|
def post(self, provider_id: str):
|
||||||
_, current_tenant_id = current_account_with_tenant()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
args = parser_datasource.parse_args()
|
payload = DatasourceCredentialPayload.model_validate(console_ns.payload or {})
|
||||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||||
datasource_provider_service = DatasourceProviderService()
|
datasource_provider_service = DatasourceProviderService()
|
||||||
|
|
||||||
|
|
@ -146,8 +182,8 @@ class DatasourceAuth(Resource):
|
||||||
datasource_provider_service.add_datasource_api_key_provider(
|
datasource_provider_service.add_datasource_api_key_provider(
|
||||||
tenant_id=current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
provider_id=datasource_provider_id,
|
provider_id=datasource_provider_id,
|
||||||
credentials=args["credentials"],
|
credentials=payload.credentials,
|
||||||
name=args["name"],
|
name=payload.name,
|
||||||
)
|
)
|
||||||
except CredentialsValidateFailedError as ex:
|
except CredentialsValidateFailedError as ex:
|
||||||
raise ValueError(str(ex))
|
raise ValueError(str(ex))
|
||||||
|
|
@ -169,14 +205,9 @@ class DatasourceAuth(Resource):
|
||||||
return {"result": datasources}, 200
|
return {"result": datasources}, 200
|
||||||
|
|
||||||
|
|
||||||
parser_datasource_delete = reqparse.RequestParser().add_argument(
|
|
||||||
"credential_id", type=str, required=True, nullable=False, location="json"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/delete")
|
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/delete")
|
||||||
class DatasourceAuthDeleteApi(Resource):
|
class DatasourceAuthDeleteApi(Resource):
|
||||||
@console_ns.expect(parser_datasource_delete)
|
@console_ns.expect(console_ns.models[DatasourceCredentialDeletePayload.__name__])
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
|
@ -188,28 +219,20 @@ class DatasourceAuthDeleteApi(Resource):
|
||||||
plugin_id = datasource_provider_id.plugin_id
|
plugin_id = datasource_provider_id.plugin_id
|
||||||
provider_name = datasource_provider_id.provider_name
|
provider_name = datasource_provider_id.provider_name
|
||||||
|
|
||||||
args = parser_datasource_delete.parse_args()
|
payload = DatasourceCredentialDeletePayload.model_validate(console_ns.payload or {})
|
||||||
datasource_provider_service = DatasourceProviderService()
|
datasource_provider_service = DatasourceProviderService()
|
||||||
datasource_provider_service.remove_datasource_credentials(
|
datasource_provider_service.remove_datasource_credentials(
|
||||||
tenant_id=current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
auth_id=args["credential_id"],
|
auth_id=payload.credential_id,
|
||||||
provider=provider_name,
|
provider=provider_name,
|
||||||
plugin_id=plugin_id,
|
plugin_id=plugin_id,
|
||||||
)
|
)
|
||||||
return {"result": "success"}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
|
|
||||||
parser_datasource_update = (
|
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
|
||||||
.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
|
|
||||||
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update")
|
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update")
|
||||||
class DatasourceAuthUpdateApi(Resource):
|
class DatasourceAuthUpdateApi(Resource):
|
||||||
@console_ns.expect(parser_datasource_update)
|
@console_ns.expect(console_ns.models[DatasourceCredentialUpdatePayload.__name__])
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
|
@ -218,16 +241,16 @@ class DatasourceAuthUpdateApi(Resource):
|
||||||
_, current_tenant_id = current_account_with_tenant()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||||
args = parser_datasource_update.parse_args()
|
payload = DatasourceCredentialUpdatePayload.model_validate(console_ns.payload or {})
|
||||||
|
|
||||||
datasource_provider_service = DatasourceProviderService()
|
datasource_provider_service = DatasourceProviderService()
|
||||||
datasource_provider_service.update_datasource_credentials(
|
datasource_provider_service.update_datasource_credentials(
|
||||||
tenant_id=current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
auth_id=args["credential_id"],
|
auth_id=payload.credential_id,
|
||||||
provider=datasource_provider_id.provider_name,
|
provider=datasource_provider_id.provider_name,
|
||||||
plugin_id=datasource_provider_id.plugin_id,
|
plugin_id=datasource_provider_id.plugin_id,
|
||||||
credentials=args.get("credentials", {}),
|
credentials=payload.credentials or {},
|
||||||
name=args.get("name", None),
|
name=payload.name,
|
||||||
)
|
)
|
||||||
return {"result": "success"}, 201
|
return {"result": "success"}, 201
|
||||||
|
|
||||||
|
|
@ -258,16 +281,9 @@ class DatasourceHardCodeAuthListApi(Resource):
|
||||||
return {"result": jsonable_encoder(datasources)}, 200
|
return {"result": jsonable_encoder(datasources)}, 200
|
||||||
|
|
||||||
|
|
||||||
parser_datasource_custom = (
|
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
|
|
||||||
.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/custom-client")
|
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/custom-client")
|
||||||
class DatasourceAuthOauthCustomClient(Resource):
|
class DatasourceAuthOauthCustomClient(Resource):
|
||||||
@console_ns.expect(parser_datasource_custom)
|
@console_ns.expect(console_ns.models[DatasourceCustomClientPayload.__name__])
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
|
@ -275,14 +291,14 @@ class DatasourceAuthOauthCustomClient(Resource):
|
||||||
def post(self, provider_id: str):
|
def post(self, provider_id: str):
|
||||||
_, current_tenant_id = current_account_with_tenant()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
args = parser_datasource_custom.parse_args()
|
payload = DatasourceCustomClientPayload.model_validate(console_ns.payload or {})
|
||||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||||
datasource_provider_service = DatasourceProviderService()
|
datasource_provider_service = DatasourceProviderService()
|
||||||
datasource_provider_service.setup_oauth_custom_client_params(
|
datasource_provider_service.setup_oauth_custom_client_params(
|
||||||
tenant_id=current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
datasource_provider_id=datasource_provider_id,
|
datasource_provider_id=datasource_provider_id,
|
||||||
client_params=args.get("client_params", {}),
|
client_params=payload.client_params or {},
|
||||||
enabled=args.get("enable_oauth_custom_client", False),
|
enabled=payload.enable_oauth_custom_client or False,
|
||||||
)
|
)
|
||||||
return {"result": "success"}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
|
|
@ -301,12 +317,9 @@ class DatasourceAuthOauthCustomClient(Resource):
|
||||||
return {"result": "success"}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
|
|
||||||
parser_default = reqparse.RequestParser().add_argument("id", type=str, required=True, nullable=False, location="json")
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/default")
|
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/default")
|
||||||
class DatasourceAuthDefaultApi(Resource):
|
class DatasourceAuthDefaultApi(Resource):
|
||||||
@console_ns.expect(parser_default)
|
@console_ns.expect(console_ns.models[DatasourceDefaultPayload.__name__])
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
|
@ -314,27 +327,20 @@ class DatasourceAuthDefaultApi(Resource):
|
||||||
def post(self, provider_id: str):
|
def post(self, provider_id: str):
|
||||||
_, current_tenant_id = current_account_with_tenant()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
args = parser_default.parse_args()
|
payload = DatasourceDefaultPayload.model_validate(console_ns.payload or {})
|
||||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||||
datasource_provider_service = DatasourceProviderService()
|
datasource_provider_service = DatasourceProviderService()
|
||||||
datasource_provider_service.set_default_datasource_provider(
|
datasource_provider_service.set_default_datasource_provider(
|
||||||
tenant_id=current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
datasource_provider_id=datasource_provider_id,
|
datasource_provider_id=datasource_provider_id,
|
||||||
credential_id=args["id"],
|
credential_id=payload.id,
|
||||||
)
|
)
|
||||||
return {"result": "success"}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
|
|
||||||
parser_update_name = (
|
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
|
|
||||||
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update-name")
|
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update-name")
|
||||||
class DatasourceUpdateProviderNameApi(Resource):
|
class DatasourceUpdateProviderNameApi(Resource):
|
||||||
@console_ns.expect(parser_update_name)
|
@console_ns.expect(console_ns.models[DatasourceUpdateNamePayload.__name__])
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
|
@ -342,13 +348,13 @@ class DatasourceUpdateProviderNameApi(Resource):
|
||||||
def post(self, provider_id: str):
|
def post(self, provider_id: str):
|
||||||
_, current_tenant_id = current_account_with_tenant()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
args = parser_update_name.parse_args()
|
payload = DatasourceUpdateNamePayload.model_validate(console_ns.payload or {})
|
||||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||||
datasource_provider_service = DatasourceProviderService()
|
datasource_provider_service = DatasourceProviderService()
|
||||||
datasource_provider_service.update_datasource_provider_name(
|
datasource_provider_service.update_datasource_provider_name(
|
||||||
tenant_id=current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
datasource_provider_id=datasource_provider_id,
|
datasource_provider_id=datasource_provider_id,
|
||||||
name=args["name"],
|
name=payload.name,
|
||||||
credential_id=args["credential_id"],
|
credential_id=payload.credential_id,
|
||||||
)
|
)
|
||||||
return {"result": "success"}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=D
|
||||||
|
|
||||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview")
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview")
|
||||||
class DataSourceContentPreviewApi(Resource):
|
class DataSourceContentPreviewApi(Resource):
|
||||||
@console_ns.expect(console_ns.models[Parser.__name__], validate=True)
|
@console_ns.expect(console_ns.models[Parser.__name__])
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,11 @@
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from controllers.common.schema import register_schema_models
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
from controllers.console.wraps import (
|
from controllers.console.wraps import (
|
||||||
account_initialization_required,
|
account_initialization_required,
|
||||||
|
|
@ -20,18 +22,6 @@ from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _validate_name(name: str) -> str:
|
|
||||||
if not name or len(name) < 1 or len(name) > 40:
|
|
||||||
raise ValueError("Name must be between 1 to 40 characters.")
|
|
||||||
return name
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_description_length(description: str) -> str:
|
|
||||||
if len(description) > 400:
|
|
||||||
raise ValueError("Description cannot exceed 400 characters.")
|
|
||||||
return description
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/rag/pipeline/templates")
|
@console_ns.route("/rag/pipeline/templates")
|
||||||
class PipelineTemplateListApi(Resource):
|
class PipelineTemplateListApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
|
|
@ -59,6 +49,15 @@ class PipelineTemplateDetailApi(Resource):
|
||||||
return pipeline_template, 200
|
return pipeline_template, 200
|
||||||
|
|
||||||
|
|
||||||
|
class Payload(BaseModel):
|
||||||
|
name: str = Field(..., min_length=1, max_length=40)
|
||||||
|
description: str = Field(default="", max_length=400)
|
||||||
|
icon_info: dict[str, object] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
register_schema_models(console_ns, Payload)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/rag/pipeline/customized/templates/<string:template_id>")
|
@console_ns.route("/rag/pipeline/customized/templates/<string:template_id>")
|
||||||
class CustomizedPipelineTemplateApi(Resource):
|
class CustomizedPipelineTemplateApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
|
|
@ -66,31 +65,8 @@ class CustomizedPipelineTemplateApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@enterprise_license_required
|
@enterprise_license_required
|
||||||
def patch(self, template_id: str):
|
def patch(self, template_id: str):
|
||||||
parser = (
|
payload = Payload.model_validate(console_ns.payload or {})
|
||||||
reqparse.RequestParser()
|
pipeline_template_info = PipelineTemplateInfoEntity.model_validate(payload.model_dump())
|
||||||
.add_argument(
|
|
||||||
"name",
|
|
||||||
nullable=False,
|
|
||||||
required=True,
|
|
||||||
help="Name must be between 1 to 40 characters.",
|
|
||||||
type=_validate_name,
|
|
||||||
)
|
|
||||||
.add_argument(
|
|
||||||
"description",
|
|
||||||
type=_validate_description_length,
|
|
||||||
nullable=True,
|
|
||||||
required=False,
|
|
||||||
default="",
|
|
||||||
)
|
|
||||||
.add_argument(
|
|
||||||
"icon_info",
|
|
||||||
type=dict,
|
|
||||||
location="json",
|
|
||||||
nullable=True,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
pipeline_template_info = PipelineTemplateInfoEntity.model_validate(args)
|
|
||||||
RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
|
RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
|
||||||
return 200
|
return 200
|
||||||
|
|
||||||
|
|
@ -119,36 +95,14 @@ class CustomizedPipelineTemplateApi(Resource):
|
||||||
|
|
||||||
@console_ns.route("/rag/pipelines/<string:pipeline_id>/customized/publish")
|
@console_ns.route("/rag/pipelines/<string:pipeline_id>/customized/publish")
|
||||||
class PublishCustomizedPipelineTemplateApi(Resource):
|
class PublishCustomizedPipelineTemplateApi(Resource):
|
||||||
|
@console_ns.expect(console_ns.models[Payload.__name__])
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@enterprise_license_required
|
@enterprise_license_required
|
||||||
@knowledge_pipeline_publish_enabled
|
@knowledge_pipeline_publish_enabled
|
||||||
def post(self, pipeline_id: str):
|
def post(self, pipeline_id: str):
|
||||||
parser = (
|
payload = Payload.model_validate(console_ns.payload or {})
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument(
|
|
||||||
"name",
|
|
||||||
nullable=False,
|
|
||||||
required=True,
|
|
||||||
help="Name must be between 1 to 40 characters.",
|
|
||||||
type=_validate_name,
|
|
||||||
)
|
|
||||||
.add_argument(
|
|
||||||
"description",
|
|
||||||
type=_validate_description_length,
|
|
||||||
nullable=True,
|
|
||||||
required=False,
|
|
||||||
default="",
|
|
||||||
)
|
|
||||||
.add_argument(
|
|
||||||
"icon_info",
|
|
||||||
type=dict,
|
|
||||||
location="json",
|
|
||||||
nullable=True,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
rag_pipeline_service = RagPipelineService()
|
rag_pipeline_service = RagPipelineService()
|
||||||
rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, args)
|
rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, payload.model_dump())
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,10 @@
|
||||||
from flask_restx import Resource, marshal, reqparse
|
from flask_restx import Resource, marshal
|
||||||
|
from pydantic import BaseModel
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
import services
|
import services
|
||||||
|
from controllers.common.schema import register_schema_model
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||||
from controllers.console.wraps import (
|
from controllers.console.wraps import (
|
||||||
|
|
@ -19,22 +21,22 @@ from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo,
|
||||||
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||||
|
|
||||||
|
|
||||||
|
class RagPipelineDatasetImportPayload(BaseModel):
|
||||||
|
yaml_content: str
|
||||||
|
|
||||||
|
|
||||||
|
register_schema_model(console_ns, RagPipelineDatasetImportPayload)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/rag/pipeline/dataset")
|
@console_ns.route("/rag/pipeline/dataset")
|
||||||
class CreateRagPipelineDatasetApi(Resource):
|
class CreateRagPipelineDatasetApi(Resource):
|
||||||
|
@console_ns.expect(console_ns.models[RagPipelineDatasetImportPayload.__name__])
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser().add_argument(
|
payload = RagPipelineDatasetImportPayload.model_validate(console_ns.payload or {})
|
||||||
"yaml_content",
|
|
||||||
type=str,
|
|
||||||
nullable=False,
|
|
||||||
required=True,
|
|
||||||
help="yaml_content is required.",
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
current_user, current_tenant_id = current_account_with_tenant()
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||||
if not current_user.is_dataset_editor:
|
if not current_user.is_dataset_editor:
|
||||||
|
|
@ -49,7 +51,7 @@ class CreateRagPipelineDatasetApi(Resource):
|
||||||
),
|
),
|
||||||
permission=DatasetPermissionEnum.ONLY_ME,
|
permission=DatasetPermissionEnum.ONLY_ME,
|
||||||
partial_member_list=None,
|
partial_member_list=None,
|
||||||
yaml_content=args["yaml_content"],
|
yaml_content=payload.yaml_content,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,13 @@
|
||||||
import logging
|
import logging
|
||||||
from typing import NoReturn
|
from typing import Any, NoReturn
|
||||||
|
|
||||||
from flask import Response
|
from flask import Response, request
|
||||||
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
from flask_restx import Resource, fields, marshal, marshal_with
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
|
from controllers.common.schema import register_schema_models
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
from controllers.console.app.error import (
|
from controllers.console.app.error import (
|
||||||
DraftWorkflowNotExist,
|
DraftWorkflowNotExist,
|
||||||
|
|
@ -33,19 +35,21 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _create_pagination_parser():
|
def _create_pagination_parser():
|
||||||
parser = (
|
class PaginationQuery(BaseModel):
|
||||||
reqparse.RequestParser()
|
page: int = Field(default=1, ge=1, le=100_000)
|
||||||
.add_argument(
|
limit: int = Field(default=20, ge=1, le=100)
|
||||||
"page",
|
|
||||||
type=inputs.int_range(1, 100_000),
|
register_schema_models(console_ns, PaginationQuery)
|
||||||
required=False,
|
|
||||||
default=1,
|
return PaginationQuery
|
||||||
location="args",
|
|
||||||
help="the page of data requested",
|
|
||||||
)
|
class WorkflowDraftVariablePatchPayload(BaseModel):
|
||||||
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
name: str | None = None
|
||||||
)
|
value: Any | None = None
|
||||||
return parser
|
|
||||||
|
|
||||||
|
register_schema_models(console_ns, WorkflowDraftVariablePatchPayload)
|
||||||
|
|
||||||
|
|
||||||
def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]:
|
def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]:
|
||||||
|
|
@ -93,8 +97,8 @@ class RagPipelineVariableCollectionApi(Resource):
|
||||||
"""
|
"""
|
||||||
Get draft workflow
|
Get draft workflow
|
||||||
"""
|
"""
|
||||||
parser = _create_pagination_parser()
|
pagination = _create_pagination_parser()
|
||||||
args = parser.parse_args()
|
query = pagination.model_validate(request.args.to_dict())
|
||||||
|
|
||||||
# fetch draft workflow by app_model
|
# fetch draft workflow by app_model
|
||||||
rag_pipeline_service = RagPipelineService()
|
rag_pipeline_service = RagPipelineService()
|
||||||
|
|
@ -109,8 +113,8 @@ class RagPipelineVariableCollectionApi(Resource):
|
||||||
)
|
)
|
||||||
workflow_vars = draft_var_srv.list_variables_without_values(
|
workflow_vars = draft_var_srv.list_variables_without_values(
|
||||||
app_id=pipeline.id,
|
app_id=pipeline.id,
|
||||||
page=args.page,
|
page=query.page,
|
||||||
limit=args.limit,
|
limit=query.limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
return workflow_vars
|
return workflow_vars
|
||||||
|
|
@ -186,6 +190,7 @@ class RagPipelineVariableApi(Resource):
|
||||||
|
|
||||||
@_api_prerequisite
|
@_api_prerequisite
|
||||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||||
|
@console_ns.expect(console_ns.models[WorkflowDraftVariablePatchPayload.__name__])
|
||||||
def patch(self, pipeline: Pipeline, variable_id: str):
|
def patch(self, pipeline: Pipeline, variable_id: str):
|
||||||
# Request payload for file types:
|
# Request payload for file types:
|
||||||
#
|
#
|
||||||
|
|
@ -208,16 +213,11 @@ class RagPipelineVariableApi(Resource):
|
||||||
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
|
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
|
||||||
# }
|
# }
|
||||||
|
|
||||||
parser = (
|
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
|
|
||||||
.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
|
|
||||||
)
|
|
||||||
|
|
||||||
draft_var_srv = WorkflowDraftVariableService(
|
draft_var_srv = WorkflowDraftVariableService(
|
||||||
session=db.session(),
|
session=db.session(),
|
||||||
)
|
)
|
||||||
args = parser.parse_args(strict=True)
|
payload = WorkflowDraftVariablePatchPayload.model_validate(console_ns.payload or {})
|
||||||
|
args = payload.model_dump(exclude_none=True)
|
||||||
|
|
||||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||||
if variable is None:
|
if variable is None:
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,9 @@
|
||||||
from flask_restx import Resource, marshal_with, reqparse # type: ignore
|
from flask import request
|
||||||
|
from flask_restx import Resource, marshal_with # type: ignore
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from controllers.common.schema import register_schema_models
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||||
from controllers.console.wraps import (
|
from controllers.console.wraps import (
|
||||||
|
|
@ -16,6 +19,25 @@ from services.app_dsl_service import ImportStatus
|
||||||
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||||
|
|
||||||
|
|
||||||
|
class RagPipelineImportPayload(BaseModel):
|
||||||
|
mode: str
|
||||||
|
yaml_content: str | None = None
|
||||||
|
yaml_url: str | None = None
|
||||||
|
name: str | None = None
|
||||||
|
description: str | None = None
|
||||||
|
icon_type: str | None = None
|
||||||
|
icon: str | None = None
|
||||||
|
icon_background: str | None = None
|
||||||
|
pipeline_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class IncludeSecretQuery(BaseModel):
|
||||||
|
include_secret: str = Field(default="false")
|
||||||
|
|
||||||
|
|
||||||
|
register_schema_models(console_ns, RagPipelineImportPayload, IncludeSecretQuery)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/rag/pipelines/imports")
|
@console_ns.route("/rag/pipelines/imports")
|
||||||
class RagPipelineImportApi(Resource):
|
class RagPipelineImportApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
|
|
@ -23,23 +45,11 @@ class RagPipelineImportApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@edit_permission_required
|
@edit_permission_required
|
||||||
@marshal_with(pipeline_import_fields)
|
@marshal_with(pipeline_import_fields)
|
||||||
|
@console_ns.expect(console_ns.models[RagPipelineImportPayload.__name__])
|
||||||
def post(self):
|
def post(self):
|
||||||
# Check user role first
|
# Check user role first
|
||||||
current_user, _ = current_account_with_tenant()
|
current_user, _ = current_account_with_tenant()
|
||||||
|
payload = RagPipelineImportPayload.model_validate(console_ns.payload or {})
|
||||||
parser = (
|
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("mode", type=str, required=True, location="json")
|
|
||||||
.add_argument("yaml_content", type=str, location="json")
|
|
||||||
.add_argument("yaml_url", type=str, location="json")
|
|
||||||
.add_argument("name", type=str, location="json")
|
|
||||||
.add_argument("description", type=str, location="json")
|
|
||||||
.add_argument("icon_type", type=str, location="json")
|
|
||||||
.add_argument("icon", type=str, location="json")
|
|
||||||
.add_argument("icon_background", type=str, location="json")
|
|
||||||
.add_argument("pipeline_id", type=str, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# Create service with session
|
# Create service with session
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
|
|
@ -48,11 +58,11 @@ class RagPipelineImportApi(Resource):
|
||||||
account = current_user
|
account = current_user
|
||||||
result = import_service.import_rag_pipeline(
|
result = import_service.import_rag_pipeline(
|
||||||
account=account,
|
account=account,
|
||||||
import_mode=args["mode"],
|
import_mode=payload.mode,
|
||||||
yaml_content=args.get("yaml_content"),
|
yaml_content=payload.yaml_content,
|
||||||
yaml_url=args.get("yaml_url"),
|
yaml_url=payload.yaml_url,
|
||||||
pipeline_id=args.get("pipeline_id"),
|
pipeline_id=payload.pipeline_id,
|
||||||
dataset_name=args.get("name"),
|
dataset_name=payload.name,
|
||||||
)
|
)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
|
|
@ -114,13 +124,12 @@ class RagPipelineExportApi(Resource):
|
||||||
@edit_permission_required
|
@edit_permission_required
|
||||||
def get(self, pipeline: Pipeline):
|
def get(self, pipeline: Pipeline):
|
||||||
# Add include_secret params
|
# Add include_secret params
|
||||||
parser = reqparse.RequestParser().add_argument("include_secret", type=str, default="false", location="args")
|
query = IncludeSecretQuery.model_validate(request.args.to_dict())
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
export_service = RagPipelineDslService(session)
|
export_service = RagPipelineDslService(session)
|
||||||
result = export_service.export_rag_pipeline_dsl(
|
result = export_service.export_rag_pipeline_dsl(
|
||||||
pipeline=pipeline, include_secret=args["include_secret"] == "true"
|
pipeline=pipeline, include_secret=query.include_secret == "true"
|
||||||
)
|
)
|
||||||
|
|
||||||
return {"data": result}, 200
|
return {"data": result}, 200
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,16 @@
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import cast
|
from typing import Any, Literal, cast
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
from flask import abort, request
|
from flask import abort, request
|
||||||
from flask_restx import Resource, inputs, marshal_with, reqparse # type: ignore # type: ignore
|
from flask_restx import Resource, marshal_with, reqparse # type: ignore
|
||||||
from flask_restx.inputs import int_range # type: ignore
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||||
|
|
||||||
import services
|
import services
|
||||||
|
from controllers.common.schema import register_schema_models
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
from controllers.console.app.error import (
|
from controllers.console.app.error import (
|
||||||
ConversationCompletedError,
|
ConversationCompletedError,
|
||||||
|
|
@ -36,7 +38,7 @@ from fields.workflow_run_fields import (
|
||||||
workflow_run_pagination_fields,
|
workflow_run_pagination_fields,
|
||||||
)
|
)
|
||||||
from libs import helper
|
from libs import helper
|
||||||
from libs.helper import TimestampField, uuid_value
|
from libs.helper import TimestampField
|
||||||
from libs.login import current_account_with_tenant, current_user, login_required
|
from libs.login import current_account_with_tenant, current_user, login_required
|
||||||
from models import Account
|
from models import Account
|
||||||
from models.dataset import Pipeline
|
from models.dataset import Pipeline
|
||||||
|
|
@ -51,6 +53,91 @@ from services.rag_pipeline.rag_pipeline_transform_service import RagPipelineTran
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DraftWorkflowSyncPayload(BaseModel):
|
||||||
|
graph: dict[str, Any]
|
||||||
|
hash: str | None = None
|
||||||
|
environment_variables: list[dict[str, Any]] | None = None
|
||||||
|
conversation_variables: list[dict[str, Any]] | None = None
|
||||||
|
rag_pipeline_variables: list[dict[str, Any]] | None = None
|
||||||
|
features: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class NodeRunPayload(BaseModel):
|
||||||
|
inputs: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class NodeRunRequiredPayload(BaseModel):
|
||||||
|
inputs: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceNodeRunPayload(BaseModel):
|
||||||
|
inputs: dict[str, Any]
|
||||||
|
datasource_type: str
|
||||||
|
credential_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class DraftWorkflowRunPayload(BaseModel):
|
||||||
|
inputs: dict[str, Any]
|
||||||
|
datasource_type: str
|
||||||
|
datasource_info_list: list[dict[str, Any]]
|
||||||
|
start_node_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class PublishedWorkflowRunPayload(DraftWorkflowRunPayload):
|
||||||
|
is_preview: bool = False
|
||||||
|
response_mode: Literal["streaming", "blocking"] = "streaming"
|
||||||
|
original_document_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class DefaultBlockConfigQuery(BaseModel):
|
||||||
|
q: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowListQuery(BaseModel):
|
||||||
|
page: int = Field(default=1, ge=1, le=99999)
|
||||||
|
limit: int = Field(default=10, ge=1, le=100)
|
||||||
|
user_id: str | None = None
|
||||||
|
named_only: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowUpdatePayload(BaseModel):
|
||||||
|
marked_name: str | None = Field(default=None, max_length=20)
|
||||||
|
marked_comment: str | None = Field(default=None, max_length=100)
|
||||||
|
|
||||||
|
|
||||||
|
class NodeIdQuery(BaseModel):
|
||||||
|
node_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowRunQuery(BaseModel):
|
||||||
|
last_id: UUID | None = None
|
||||||
|
limit: int = Field(default=20, ge=1, le=100)
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceVariablesPayload(BaseModel):
|
||||||
|
datasource_type: str
|
||||||
|
datasource_info: dict[str, Any]
|
||||||
|
start_node_id: str
|
||||||
|
start_node_title: str
|
||||||
|
|
||||||
|
|
||||||
|
register_schema_models(
|
||||||
|
console_ns,
|
||||||
|
DraftWorkflowSyncPayload,
|
||||||
|
NodeRunPayload,
|
||||||
|
NodeRunRequiredPayload,
|
||||||
|
DatasourceNodeRunPayload,
|
||||||
|
DraftWorkflowRunPayload,
|
||||||
|
PublishedWorkflowRunPayload,
|
||||||
|
DefaultBlockConfigQuery,
|
||||||
|
WorkflowListQuery,
|
||||||
|
WorkflowUpdatePayload,
|
||||||
|
NodeIdQuery,
|
||||||
|
WorkflowRunQuery,
|
||||||
|
DatasourceVariablesPayload,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft")
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft")
|
||||||
class DraftRagPipelineApi(Resource):
|
class DraftRagPipelineApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
|
|
@ -88,15 +175,7 @@ class DraftRagPipelineApi(Resource):
|
||||||
content_type = request.headers.get("Content-Type", "")
|
content_type = request.headers.get("Content-Type", "")
|
||||||
|
|
||||||
if "application/json" in content_type:
|
if "application/json" in content_type:
|
||||||
parser = (
|
payload_dict = console_ns.payload or {}
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("graph", type=dict, required=True, nullable=False, location="json")
|
|
||||||
.add_argument("hash", type=str, required=False, location="json")
|
|
||||||
.add_argument("environment_variables", type=list, required=False, location="json")
|
|
||||||
.add_argument("conversation_variables", type=list, required=False, location="json")
|
|
||||||
.add_argument("rag_pipeline_variables", type=list, required=False, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
elif "text/plain" in content_type:
|
elif "text/plain" in content_type:
|
||||||
try:
|
try:
|
||||||
data = json.loads(request.data.decode("utf-8"))
|
data = json.loads(request.data.decode("utf-8"))
|
||||||
|
|
@ -106,7 +185,7 @@ class DraftRagPipelineApi(Resource):
|
||||||
if not isinstance(data.get("graph"), dict):
|
if not isinstance(data.get("graph"), dict):
|
||||||
raise ValueError("graph is not a dict")
|
raise ValueError("graph is not a dict")
|
||||||
|
|
||||||
args = {
|
payload_dict = {
|
||||||
"graph": data.get("graph"),
|
"graph": data.get("graph"),
|
||||||
"features": data.get("features"),
|
"features": data.get("features"),
|
||||||
"hash": data.get("hash"),
|
"hash": data.get("hash"),
|
||||||
|
|
@ -119,24 +198,26 @@ class DraftRagPipelineApi(Resource):
|
||||||
else:
|
else:
|
||||||
abort(415)
|
abort(415)
|
||||||
|
|
||||||
|
payload = DraftWorkflowSyncPayload.model_validate(payload_dict)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
environment_variables_list = args.get("environment_variables") or []
|
environment_variables_list = payload.environment_variables or []
|
||||||
environment_variables = [
|
environment_variables = [
|
||||||
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
|
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
|
||||||
]
|
]
|
||||||
conversation_variables_list = args.get("conversation_variables") or []
|
conversation_variables_list = payload.conversation_variables or []
|
||||||
conversation_variables = [
|
conversation_variables = [
|
||||||
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
|
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
|
||||||
]
|
]
|
||||||
rag_pipeline_service = RagPipelineService()
|
rag_pipeline_service = RagPipelineService()
|
||||||
workflow = rag_pipeline_service.sync_draft_workflow(
|
workflow = rag_pipeline_service.sync_draft_workflow(
|
||||||
pipeline=pipeline,
|
pipeline=pipeline,
|
||||||
graph=args["graph"],
|
graph=payload.graph,
|
||||||
unique_hash=args.get("hash"),
|
unique_hash=payload.hash,
|
||||||
account=current_user,
|
account=current_user,
|
||||||
environment_variables=environment_variables,
|
environment_variables=environment_variables,
|
||||||
conversation_variables=conversation_variables,
|
conversation_variables=conversation_variables,
|
||||||
rag_pipeline_variables=args.get("rag_pipeline_variables") or [],
|
rag_pipeline_variables=payload.rag_pipeline_variables or [],
|
||||||
)
|
)
|
||||||
except WorkflowHashNotEqualError:
|
except WorkflowHashNotEqualError:
|
||||||
raise DraftWorkflowNotSync()
|
raise DraftWorkflowNotSync()
|
||||||
|
|
@ -148,12 +229,9 @@ class DraftRagPipelineApi(Resource):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
parser_run = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/iteration/nodes/<string:node_id>/run")
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/iteration/nodes/<string:node_id>/run")
|
||||||
class RagPipelineDraftRunIterationNodeApi(Resource):
|
class RagPipelineDraftRunIterationNodeApi(Resource):
|
||||||
@console_ns.expect(parser_run)
|
@console_ns.expect(console_ns.models[NodeRunPayload.__name__])
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
|
@ -166,7 +244,8 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
current_user, _ = current_account_with_tenant()
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
args = parser_run.parse_args()
|
payload = NodeRunPayload.model_validate(console_ns.payload or {})
|
||||||
|
args = payload.model_dump(exclude_none=True)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = PipelineGenerateService.generate_single_iteration(
|
response = PipelineGenerateService.generate_single_iteration(
|
||||||
|
|
@ -187,7 +266,7 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
|
||||||
|
|
||||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/loop/nodes/<string:node_id>/run")
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/loop/nodes/<string:node_id>/run")
|
||||||
class RagPipelineDraftRunLoopNodeApi(Resource):
|
class RagPipelineDraftRunLoopNodeApi(Resource):
|
||||||
@console_ns.expect(parser_run)
|
@console_ns.expect(console_ns.models[NodeRunPayload.__name__])
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
|
@ -200,7 +279,8 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
current_user, _ = current_account_with_tenant()
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
args = parser_run.parse_args()
|
payload = NodeRunPayload.model_validate(console_ns.payload or {})
|
||||||
|
args = payload.model_dump(exclude_none=True)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = PipelineGenerateService.generate_single_loop(
|
response = PipelineGenerateService.generate_single_loop(
|
||||||
|
|
@ -219,18 +299,9 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
|
||||||
raise InternalServerError()
|
raise InternalServerError()
|
||||||
|
|
||||||
|
|
||||||
parser_draft_run = (
|
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
|
||||||
.add_argument("datasource_type", type=str, required=True, location="json")
|
|
||||||
.add_argument("datasource_info_list", type=list, required=True, location="json")
|
|
||||||
.add_argument("start_node_id", type=str, required=True, location="json")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/run")
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/run")
|
||||||
class DraftRagPipelineRunApi(Resource):
|
class DraftRagPipelineRunApi(Resource):
|
||||||
@console_ns.expect(parser_draft_run)
|
@console_ns.expect(console_ns.models[DraftWorkflowRunPayload.__name__])
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
|
@ -243,7 +314,8 @@ class DraftRagPipelineRunApi(Resource):
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
current_user, _ = current_account_with_tenant()
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
args = parser_draft_run.parse_args()
|
payload = DraftWorkflowRunPayload.model_validate(console_ns.payload or {})
|
||||||
|
args = payload.model_dump()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = PipelineGenerateService.generate(
|
response = PipelineGenerateService.generate(
|
||||||
|
|
@ -259,21 +331,9 @@ class DraftRagPipelineRunApi(Resource):
|
||||||
raise InvokeRateLimitHttpError(ex.description)
|
raise InvokeRateLimitHttpError(ex.description)
|
||||||
|
|
||||||
|
|
||||||
parser_published_run = (
|
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
|
||||||
.add_argument("datasource_type", type=str, required=True, location="json")
|
|
||||||
.add_argument("datasource_info_list", type=list, required=True, location="json")
|
|
||||||
.add_argument("start_node_id", type=str, required=True, location="json")
|
|
||||||
.add_argument("is_preview", type=bool, required=True, location="json", default=False)
|
|
||||||
.add_argument("response_mode", type=str, required=True, location="json", default="streaming")
|
|
||||||
.add_argument("original_document_id", type=str, required=False, location="json")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/run")
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/run")
|
||||||
class PublishedRagPipelineRunApi(Resource):
|
class PublishedRagPipelineRunApi(Resource):
|
||||||
@console_ns.expect(parser_published_run)
|
@console_ns.expect(console_ns.models[PublishedWorkflowRunPayload.__name__])
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
|
@ -286,16 +346,16 @@ class PublishedRagPipelineRunApi(Resource):
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
current_user, _ = current_account_with_tenant()
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
args = parser_published_run.parse_args()
|
payload = PublishedWorkflowRunPayload.model_validate(console_ns.payload or {})
|
||||||
|
args = payload.model_dump(exclude_none=True)
|
||||||
streaming = args["response_mode"] == "streaming"
|
streaming = payload.response_mode == "streaming"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = PipelineGenerateService.generate(
|
response = PipelineGenerateService.generate(
|
||||||
pipeline=pipeline,
|
pipeline=pipeline,
|
||||||
user=current_user,
|
user=current_user,
|
||||||
args=args,
|
args=args,
|
||||||
invoke_from=InvokeFrom.DEBUGGER if args.get("is_preview") else InvokeFrom.PUBLISHED,
|
invoke_from=InvokeFrom.DEBUGGER if payload.is_preview else InvokeFrom.PUBLISHED,
|
||||||
streaming=streaming,
|
streaming=streaming,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -387,17 +447,9 @@ class PublishedRagPipelineRunApi(Resource):
|
||||||
#
|
#
|
||||||
# return result
|
# return result
|
||||||
#
|
#
|
||||||
parser_rag_run = (
|
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
|
||||||
.add_argument("datasource_type", type=str, required=True, location="json")
|
|
||||||
.add_argument("credential_id", type=str, required=False, location="json")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run")
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run")
|
||||||
class RagPipelinePublishedDatasourceNodeRunApi(Resource):
|
class RagPipelinePublishedDatasourceNodeRunApi(Resource):
|
||||||
@console_ns.expect(parser_rag_run)
|
@console_ns.expect(console_ns.models[DatasourceNodeRunPayload.__name__])
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
|
@ -410,14 +462,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
current_user, _ = current_account_with_tenant()
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
args = parser_rag_run.parse_args()
|
payload = DatasourceNodeRunPayload.model_validate(console_ns.payload or {})
|
||||||
|
|
||||||
inputs = args.get("inputs")
|
|
||||||
if inputs is None:
|
|
||||||
raise ValueError("missing inputs")
|
|
||||||
datasource_type = args.get("datasource_type")
|
|
||||||
if datasource_type is None:
|
|
||||||
raise ValueError("missing datasource_type")
|
|
||||||
|
|
||||||
rag_pipeline_service = RagPipelineService()
|
rag_pipeline_service = RagPipelineService()
|
||||||
return helper.compact_generate_response(
|
return helper.compact_generate_response(
|
||||||
|
|
@ -425,11 +470,11 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
|
||||||
rag_pipeline_service.run_datasource_workflow_node(
|
rag_pipeline_service.run_datasource_workflow_node(
|
||||||
pipeline=pipeline,
|
pipeline=pipeline,
|
||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
user_inputs=inputs,
|
user_inputs=payload.inputs,
|
||||||
account=current_user,
|
account=current_user,
|
||||||
datasource_type=datasource_type,
|
datasource_type=payload.datasource_type,
|
||||||
is_published=False,
|
is_published=False,
|
||||||
credential_id=args.get("credential_id"),
|
credential_id=payload.credential_id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
@ -437,7 +482,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
|
||||||
|
|
||||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/nodes/<string:node_id>/run")
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/nodes/<string:node_id>/run")
|
||||||
class RagPipelineDraftDatasourceNodeRunApi(Resource):
|
class RagPipelineDraftDatasourceNodeRunApi(Resource):
|
||||||
@console_ns.expect(parser_rag_run)
|
@console_ns.expect(console_ns.models[DatasourceNodeRunPayload.__name__])
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@edit_permission_required
|
@edit_permission_required
|
||||||
|
|
@ -450,14 +495,7 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
current_user, _ = current_account_with_tenant()
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
args = parser_rag_run.parse_args()
|
payload = DatasourceNodeRunPayload.model_validate(console_ns.payload or {})
|
||||||
|
|
||||||
inputs = args.get("inputs")
|
|
||||||
if inputs is None:
|
|
||||||
raise ValueError("missing inputs")
|
|
||||||
datasource_type = args.get("datasource_type")
|
|
||||||
if datasource_type is None:
|
|
||||||
raise ValueError("missing datasource_type")
|
|
||||||
|
|
||||||
rag_pipeline_service = RagPipelineService()
|
rag_pipeline_service = RagPipelineService()
|
||||||
return helper.compact_generate_response(
|
return helper.compact_generate_response(
|
||||||
|
|
@ -465,24 +503,19 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
|
||||||
rag_pipeline_service.run_datasource_workflow_node(
|
rag_pipeline_service.run_datasource_workflow_node(
|
||||||
pipeline=pipeline,
|
pipeline=pipeline,
|
||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
user_inputs=inputs,
|
user_inputs=payload.inputs,
|
||||||
account=current_user,
|
account=current_user,
|
||||||
datasource_type=datasource_type,
|
datasource_type=payload.datasource_type,
|
||||||
is_published=False,
|
is_published=False,
|
||||||
credential_id=args.get("credential_id"),
|
credential_id=payload.credential_id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
parser_run_api = reqparse.RequestParser().add_argument(
|
|
||||||
"inputs", type=dict, required=True, nullable=False, location="json"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/run")
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/run")
|
||||||
class RagPipelineDraftNodeRunApi(Resource):
|
class RagPipelineDraftNodeRunApi(Resource):
|
||||||
@console_ns.expect(parser_run_api)
|
@console_ns.expect(console_ns.models[NodeRunRequiredPayload.__name__])
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@edit_permission_required
|
@edit_permission_required
|
||||||
|
|
@ -496,11 +529,8 @@ class RagPipelineDraftNodeRunApi(Resource):
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
current_user, _ = current_account_with_tenant()
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
args = parser_run_api.parse_args()
|
payload = NodeRunRequiredPayload.model_validate(console_ns.payload or {})
|
||||||
|
inputs = payload.inputs
|
||||||
inputs = args.get("inputs")
|
|
||||||
if inputs == None:
|
|
||||||
raise ValueError("missing inputs")
|
|
||||||
|
|
||||||
rag_pipeline_service = RagPipelineService()
|
rag_pipeline_service = RagPipelineService()
|
||||||
workflow_node_execution = rag_pipeline_service.run_draft_workflow_node(
|
workflow_node_execution = rag_pipeline_service.run_draft_workflow_node(
|
||||||
|
|
@ -602,12 +632,8 @@ class DefaultRagPipelineBlockConfigsApi(Resource):
|
||||||
return rag_pipeline_service.get_default_block_configs()
|
return rag_pipeline_service.get_default_block_configs()
|
||||||
|
|
||||||
|
|
||||||
parser_default = reqparse.RequestParser().add_argument("q", type=str, location="args")
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs/<string:block_type>")
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs/<string:block_type>")
|
||||||
class DefaultRagPipelineBlockConfigApi(Resource):
|
class DefaultRagPipelineBlockConfigApi(Resource):
|
||||||
@console_ns.expect(parser_default)
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
|
@ -617,14 +643,12 @@ class DefaultRagPipelineBlockConfigApi(Resource):
|
||||||
"""
|
"""
|
||||||
Get default block config
|
Get default block config
|
||||||
"""
|
"""
|
||||||
args = parser_default.parse_args()
|
query = DefaultBlockConfigQuery.model_validate(request.args.to_dict())
|
||||||
|
|
||||||
q = args.get("q")
|
|
||||||
|
|
||||||
filters = None
|
filters = None
|
||||||
if q:
|
if query.q:
|
||||||
try:
|
try:
|
||||||
filters = json.loads(args.get("q", ""))
|
filters = json.loads(query.q)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
raise ValueError("Invalid filters")
|
raise ValueError("Invalid filters")
|
||||||
|
|
||||||
|
|
@ -633,18 +657,8 @@ class DefaultRagPipelineBlockConfigApi(Resource):
|
||||||
return rag_pipeline_service.get_default_block_config(node_type=block_type, filters=filters)
|
return rag_pipeline_service.get_default_block_config(node_type=block_type, filters=filters)
|
||||||
|
|
||||||
|
|
||||||
parser_wf = (
|
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
|
|
||||||
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=10, location="args")
|
|
||||||
.add_argument("user_id", type=str, required=False, location="args")
|
|
||||||
.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows")
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows")
|
||||||
class PublishedAllRagPipelineApi(Resource):
|
class PublishedAllRagPipelineApi(Resource):
|
||||||
@console_ns.expect(parser_wf)
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
|
@ -657,16 +671,16 @@ class PublishedAllRagPipelineApi(Resource):
|
||||||
"""
|
"""
|
||||||
current_user, _ = current_account_with_tenant()
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
args = parser_wf.parse_args()
|
query = WorkflowListQuery.model_validate(request.args.to_dict())
|
||||||
page = args["page"]
|
|
||||||
limit = args["limit"]
|
page = query.page
|
||||||
user_id = args.get("user_id")
|
limit = query.limit
|
||||||
named_only = args.get("named_only", False)
|
user_id = query.user_id
|
||||||
|
named_only = query.named_only
|
||||||
|
|
||||||
if user_id:
|
if user_id:
|
||||||
if user_id != current_user.id:
|
if user_id != current_user.id:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
user_id = cast(str, user_id)
|
|
||||||
|
|
||||||
rag_pipeline_service = RagPipelineService()
|
rag_pipeline_service = RagPipelineService()
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
|
|
@ -687,16 +701,8 @@ class PublishedAllRagPipelineApi(Resource):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
parser_wf_id = (
|
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("marked_name", type=str, required=False, location="json")
|
|
||||||
.add_argument("marked_comment", type=str, required=False, location="json")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>")
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>")
|
||||||
class RagPipelineByIdApi(Resource):
|
class RagPipelineByIdApi(Resource):
|
||||||
@console_ns.expect(parser_wf_id)
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
|
@ -710,20 +716,8 @@ class RagPipelineByIdApi(Resource):
|
||||||
# Check permission
|
# Check permission
|
||||||
current_user, _ = current_account_with_tenant()
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
args = parser_wf_id.parse_args()
|
payload = WorkflowUpdatePayload.model_validate(console_ns.payload or {})
|
||||||
|
update_data = payload.model_dump(exclude_unset=True)
|
||||||
# Validate name and comment length
|
|
||||||
if args.marked_name and len(args.marked_name) > 20:
|
|
||||||
raise ValueError("Marked name cannot exceed 20 characters")
|
|
||||||
if args.marked_comment and len(args.marked_comment) > 100:
|
|
||||||
raise ValueError("Marked comment cannot exceed 100 characters")
|
|
||||||
|
|
||||||
# Prepare update data
|
|
||||||
update_data = {}
|
|
||||||
if args.get("marked_name") is not None:
|
|
||||||
update_data["marked_name"] = args["marked_name"]
|
|
||||||
if args.get("marked_comment") is not None:
|
|
||||||
update_data["marked_comment"] = args["marked_comment"]
|
|
||||||
|
|
||||||
if not update_data:
|
if not update_data:
|
||||||
return {"message": "No valid fields to update"}, 400
|
return {"message": "No valid fields to update"}, 400
|
||||||
|
|
@ -749,12 +743,8 @@ class RagPipelineByIdApi(Resource):
|
||||||
return workflow
|
return workflow
|
||||||
|
|
||||||
|
|
||||||
parser_parameters = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/parameters")
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/parameters")
|
||||||
class PublishedRagPipelineSecondStepApi(Resource):
|
class PublishedRagPipelineSecondStepApi(Resource):
|
||||||
@console_ns.expect(parser_parameters)
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
|
@ -764,10 +754,8 @@ class PublishedRagPipelineSecondStepApi(Resource):
|
||||||
"""
|
"""
|
||||||
Get second step parameters of rag pipeline
|
Get second step parameters of rag pipeline
|
||||||
"""
|
"""
|
||||||
args = parser_parameters.parse_args()
|
query = NodeIdQuery.model_validate(request.args.to_dict())
|
||||||
node_id = args.get("node_id")
|
node_id = query.node_id
|
||||||
if not node_id:
|
|
||||||
raise ValueError("Node ID is required")
|
|
||||||
rag_pipeline_service = RagPipelineService()
|
rag_pipeline_service = RagPipelineService()
|
||||||
variables = rag_pipeline_service.get_second_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=False)
|
variables = rag_pipeline_service.get_second_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=False)
|
||||||
return {
|
return {
|
||||||
|
|
@ -777,7 +765,6 @@ class PublishedRagPipelineSecondStepApi(Resource):
|
||||||
|
|
||||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/pre-processing/parameters")
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/pre-processing/parameters")
|
||||||
class PublishedRagPipelineFirstStepApi(Resource):
|
class PublishedRagPipelineFirstStepApi(Resource):
|
||||||
@console_ns.expect(parser_parameters)
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
|
@ -787,10 +774,8 @@ class PublishedRagPipelineFirstStepApi(Resource):
|
||||||
"""
|
"""
|
||||||
Get first step parameters of rag pipeline
|
Get first step parameters of rag pipeline
|
||||||
"""
|
"""
|
||||||
args = parser_parameters.parse_args()
|
query = NodeIdQuery.model_validate(request.args.to_dict())
|
||||||
node_id = args.get("node_id")
|
node_id = query.node_id
|
||||||
if not node_id:
|
|
||||||
raise ValueError("Node ID is required")
|
|
||||||
rag_pipeline_service = RagPipelineService()
|
rag_pipeline_service = RagPipelineService()
|
||||||
variables = rag_pipeline_service.get_first_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=False)
|
variables = rag_pipeline_service.get_first_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=False)
|
||||||
return {
|
return {
|
||||||
|
|
@ -800,7 +785,6 @@ class PublishedRagPipelineFirstStepApi(Resource):
|
||||||
|
|
||||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/pre-processing/parameters")
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/pre-processing/parameters")
|
||||||
class DraftRagPipelineFirstStepApi(Resource):
|
class DraftRagPipelineFirstStepApi(Resource):
|
||||||
@console_ns.expect(parser_parameters)
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
|
@ -810,10 +794,8 @@ class DraftRagPipelineFirstStepApi(Resource):
|
||||||
"""
|
"""
|
||||||
Get first step parameters of rag pipeline
|
Get first step parameters of rag pipeline
|
||||||
"""
|
"""
|
||||||
args = parser_parameters.parse_args()
|
query = NodeIdQuery.model_validate(request.args.to_dict())
|
||||||
node_id = args.get("node_id")
|
node_id = query.node_id
|
||||||
if not node_id:
|
|
||||||
raise ValueError("Node ID is required")
|
|
||||||
rag_pipeline_service = RagPipelineService()
|
rag_pipeline_service = RagPipelineService()
|
||||||
variables = rag_pipeline_service.get_first_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=True)
|
variables = rag_pipeline_service.get_first_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=True)
|
||||||
return {
|
return {
|
||||||
|
|
@ -823,7 +805,6 @@ class DraftRagPipelineFirstStepApi(Resource):
|
||||||
|
|
||||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/parameters")
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/parameters")
|
||||||
class DraftRagPipelineSecondStepApi(Resource):
|
class DraftRagPipelineSecondStepApi(Resource):
|
||||||
@console_ns.expect(parser_parameters)
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
|
@ -833,10 +814,8 @@ class DraftRagPipelineSecondStepApi(Resource):
|
||||||
"""
|
"""
|
||||||
Get second step parameters of rag pipeline
|
Get second step parameters of rag pipeline
|
||||||
"""
|
"""
|
||||||
args = parser_parameters.parse_args()
|
query = NodeIdQuery.model_validate(request.args.to_dict())
|
||||||
node_id = args.get("node_id")
|
node_id = query.node_id
|
||||||
if not node_id:
|
|
||||||
raise ValueError("Node ID is required")
|
|
||||||
|
|
||||||
rag_pipeline_service = RagPipelineService()
|
rag_pipeline_service = RagPipelineService()
|
||||||
variables = rag_pipeline_service.get_second_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=True)
|
variables = rag_pipeline_service.get_second_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=True)
|
||||||
|
|
@ -845,16 +824,8 @@ class DraftRagPipelineSecondStepApi(Resource):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
parser_wf_run = (
|
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("last_id", type=uuid_value, location="args")
|
|
||||||
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs")
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs")
|
||||||
class RagPipelineWorkflowRunListApi(Resource):
|
class RagPipelineWorkflowRunListApi(Resource):
|
||||||
@console_ns.expect(parser_wf_run)
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
|
@ -864,7 +835,16 @@ class RagPipelineWorkflowRunListApi(Resource):
|
||||||
"""
|
"""
|
||||||
Get workflow run list
|
Get workflow run list
|
||||||
"""
|
"""
|
||||||
args = parser_wf_run.parse_args()
|
query = WorkflowRunQuery.model_validate(
|
||||||
|
{
|
||||||
|
"last_id": request.args.get("last_id"),
|
||||||
|
"limit": request.args.get("limit", type=int, default=20),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
args = {
|
||||||
|
"last_id": str(query.last_id) if query.last_id else None,
|
||||||
|
"limit": query.limit,
|
||||||
|
}
|
||||||
|
|
||||||
rag_pipeline_service = RagPipelineService()
|
rag_pipeline_service = RagPipelineService()
|
||||||
result = rag_pipeline_service.get_rag_pipeline_paginate_workflow_runs(pipeline=pipeline, args=args)
|
result = rag_pipeline_service.get_rag_pipeline_paginate_workflow_runs(pipeline=pipeline, args=args)
|
||||||
|
|
@ -964,18 +944,9 @@ class RagPipelineTransformApi(Resource):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
parser_var = (
|
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("datasource_type", type=str, required=True, location="json")
|
|
||||||
.add_argument("datasource_info", type=dict, required=True, location="json")
|
|
||||||
.add_argument("start_node_id", type=str, required=True, location="json")
|
|
||||||
.add_argument("start_node_title", type=str, required=True, location="json")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/variables-inspect")
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/variables-inspect")
|
||||||
class RagPipelineDatasourceVariableApi(Resource):
|
class RagPipelineDatasourceVariableApi(Resource):
|
||||||
@console_ns.expect(parser_var)
|
@console_ns.expect(console_ns.models[DatasourceVariablesPayload.__name__])
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
|
@ -987,7 +958,7 @@ class RagPipelineDatasourceVariableApi(Resource):
|
||||||
Set datasource variables
|
Set datasource variables
|
||||||
"""
|
"""
|
||||||
current_user, _ = current_account_with_tenant()
|
current_user, _ = current_account_with_tenant()
|
||||||
args = parser_var.parse_args()
|
args = DatasourceVariablesPayload.model_validate(console_ns.payload or {}).model_dump()
|
||||||
|
|
||||||
rag_pipeline_service = RagPipelineService()
|
rag_pipeline_service = RagPipelineService()
|
||||||
workflow_node_execution = rag_pipeline_service.set_datasource_variables(
|
workflow_node_execution = rag_pipeline_service.set_datasource_variables(
|
||||||
|
|
@ -1004,6 +975,11 @@ class RagPipelineRecommendedPluginApi(Resource):
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("type", type=str, location="args", required=False, default="all")
|
||||||
|
args = parser.parse_args()
|
||||||
|
type = args["type"]
|
||||||
|
|
||||||
rag_pipeline_service = RagPipelineService()
|
rag_pipeline_service = RagPipelineService()
|
||||||
recommended_plugins = rag_pipeline_service.get_recommended_plugins()
|
recommended_plugins = rag_pipeline_service.get_recommended_plugins(type)
|
||||||
return recommended_plugins
|
return recommended_plugins
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,10 @@
|
||||||
from flask_restx import Resource, fields, reqparse
|
from typing import Literal
|
||||||
|
|
||||||
|
from flask import request
|
||||||
|
from flask_restx import Resource
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from controllers.common.schema import register_schema_models
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
from controllers.console.datasets.error import WebsiteCrawlError
|
from controllers.console.datasets.error import WebsiteCrawlError
|
||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
|
|
@ -7,48 +12,35 @@ from libs.login import login_required
|
||||||
from services.website_service import WebsiteCrawlApiRequest, WebsiteCrawlStatusApiRequest, WebsiteService
|
from services.website_service import WebsiteCrawlApiRequest, WebsiteCrawlStatusApiRequest, WebsiteService
|
||||||
|
|
||||||
|
|
||||||
|
class WebsiteCrawlPayload(BaseModel):
|
||||||
|
provider: Literal["firecrawl", "watercrawl", "jinareader"]
|
||||||
|
url: str
|
||||||
|
options: dict[str, object]
|
||||||
|
|
||||||
|
|
||||||
|
class WebsiteCrawlStatusQuery(BaseModel):
|
||||||
|
provider: Literal["firecrawl", "watercrawl", "jinareader"]
|
||||||
|
|
||||||
|
|
||||||
|
register_schema_models(console_ns, WebsiteCrawlPayload, WebsiteCrawlStatusQuery)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/website/crawl")
|
@console_ns.route("/website/crawl")
|
||||||
class WebsiteCrawlApi(Resource):
|
class WebsiteCrawlApi(Resource):
|
||||||
@console_ns.doc("crawl_website")
|
@console_ns.doc("crawl_website")
|
||||||
@console_ns.doc(description="Crawl website content")
|
@console_ns.doc(description="Crawl website content")
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[WebsiteCrawlPayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"WebsiteCrawlRequest",
|
|
||||||
{
|
|
||||||
"provider": fields.String(
|
|
||||||
required=True,
|
|
||||||
description="Crawl provider (firecrawl/watercrawl/jinareader)",
|
|
||||||
enum=["firecrawl", "watercrawl", "jinareader"],
|
|
||||||
),
|
|
||||||
"url": fields.String(required=True, description="URL to crawl"),
|
|
||||||
"options": fields.Raw(required=True, description="Crawl options"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Website crawl initiated successfully")
|
@console_ns.response(200, "Website crawl initiated successfully")
|
||||||
@console_ns.response(400, "Invalid crawl parameters")
|
@console_ns.response(400, "Invalid crawl parameters")
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = (
|
payload = WebsiteCrawlPayload.model_validate(console_ns.payload or {})
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument(
|
|
||||||
"provider",
|
|
||||||
type=str,
|
|
||||||
choices=["firecrawl", "watercrawl", "jinareader"],
|
|
||||||
required=True,
|
|
||||||
nullable=True,
|
|
||||||
location="json",
|
|
||||||
)
|
|
||||||
.add_argument("url", type=str, required=True, nullable=True, location="json")
|
|
||||||
.add_argument("options", type=dict, required=True, nullable=True, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# Create typed request and validate
|
# Create typed request and validate
|
||||||
try:
|
try:
|
||||||
api_request = WebsiteCrawlApiRequest.from_args(args)
|
api_request = WebsiteCrawlApiRequest.from_args(payload.model_dump())
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise WebsiteCrawlError(str(e))
|
raise WebsiteCrawlError(str(e))
|
||||||
|
|
||||||
|
|
@ -65,6 +57,7 @@ class WebsiteCrawlStatusApi(Resource):
|
||||||
@console_ns.doc("get_crawl_status")
|
@console_ns.doc("get_crawl_status")
|
||||||
@console_ns.doc(description="Get website crawl status")
|
@console_ns.doc(description="Get website crawl status")
|
||||||
@console_ns.doc(params={"job_id": "Crawl job ID", "provider": "Crawl provider (firecrawl/watercrawl/jinareader)"})
|
@console_ns.doc(params={"job_id": "Crawl job ID", "provider": "Crawl provider (firecrawl/watercrawl/jinareader)"})
|
||||||
|
@console_ns.expect(console_ns.models[WebsiteCrawlStatusQuery.__name__])
|
||||||
@console_ns.response(200, "Crawl status retrieved successfully")
|
@console_ns.response(200, "Crawl status retrieved successfully")
|
||||||
@console_ns.response(404, "Crawl job not found")
|
@console_ns.response(404, "Crawl job not found")
|
||||||
@console_ns.response(400, "Invalid provider")
|
@console_ns.response(400, "Invalid provider")
|
||||||
|
|
@ -72,14 +65,11 @@ class WebsiteCrawlStatusApi(Resource):
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, job_id: str):
|
def get(self, job_id: str):
|
||||||
parser = reqparse.RequestParser().add_argument(
|
args = WebsiteCrawlStatusQuery.model_validate(request.args.to_dict())
|
||||||
"provider", type=str, choices=["firecrawl", "watercrawl", "jinareader"], required=True, location="args"
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# Create typed request and validate
|
# Create typed request and validate
|
||||||
try:
|
try:
|
||||||
api_request = WebsiteCrawlStatusApiRequest.from_args(args, job_id)
|
api_request = WebsiteCrawlStatusApiRequest.from_args(args.model_dump(), job_id)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise WebsiteCrawlError(str(e))
|
raise WebsiteCrawlError(str(e))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,11 @@
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
from werkzeug.exceptions import InternalServerError
|
from werkzeug.exceptions import InternalServerError
|
||||||
|
|
||||||
import services
|
import services
|
||||||
|
from controllers.common.schema import register_schema_model
|
||||||
from controllers.console.app.error import (
|
from controllers.console.app.error import (
|
||||||
AppUnavailableError,
|
AppUnavailableError,
|
||||||
AudioTooLargeError,
|
AudioTooLargeError,
|
||||||
|
|
@ -31,6 +33,16 @@ from .. import console_ns
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TextToAudioPayload(BaseModel):
|
||||||
|
message_id: str | None = None
|
||||||
|
voice: str | None = None
|
||||||
|
text: str | None = None
|
||||||
|
streaming: bool | None = Field(default=None, description="Enable streaming response")
|
||||||
|
|
||||||
|
|
||||||
|
register_schema_model(console_ns, TextToAudioPayload)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route(
|
@console_ns.route(
|
||||||
"/installed-apps/<uuid:installed_app_id>/audio-to-text",
|
"/installed-apps/<uuid:installed_app_id>/audio-to-text",
|
||||||
endpoint="installed_app_audio",
|
endpoint="installed_app_audio",
|
||||||
|
|
@ -76,23 +88,15 @@ class ChatAudioApi(InstalledAppResource):
|
||||||
endpoint="installed_app_text",
|
endpoint="installed_app_text",
|
||||||
)
|
)
|
||||||
class ChatTextApi(InstalledAppResource):
|
class ChatTextApi(InstalledAppResource):
|
||||||
|
@console_ns.expect(console_ns.models[TextToAudioPayload.__name__])
|
||||||
def post(self, installed_app):
|
def post(self, installed_app):
|
||||||
from flask_restx import reqparse
|
|
||||||
|
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
try:
|
try:
|
||||||
parser = (
|
payload = TextToAudioPayload.model_validate(console_ns.payload or {})
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("message_id", type=str, required=False, location="json")
|
|
||||||
.add_argument("voice", type=str, location="json")
|
|
||||||
.add_argument("text", type=str, location="json")
|
|
||||||
.add_argument("streaming", type=bool, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
message_id = args.get("message_id", None)
|
message_id = payload.message_id
|
||||||
text = args.get("text", None)
|
text = payload.text
|
||||||
voice = args.get("voice", None)
|
voice = payload.voice
|
||||||
|
|
||||||
response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id)
|
response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id)
|
||||||
return response
|
return response
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,12 @@
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Any, Literal
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
from flask_restx import reqparse
|
from pydantic import BaseModel, Field, field_validator
|
||||||
from werkzeug.exceptions import InternalServerError, NotFound
|
from werkzeug.exceptions import InternalServerError, NotFound
|
||||||
|
|
||||||
import services
|
import services
|
||||||
|
from controllers.common.schema import register_schema_models
|
||||||
from controllers.console.app.error import (
|
from controllers.console.app.error import (
|
||||||
AppUnavailableError,
|
AppUnavailableError,
|
||||||
CompletionRequestError,
|
CompletionRequestError,
|
||||||
|
|
@ -25,7 +28,6 @@ from core.model_runtime.errors.invoke import InvokeError
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs import helper
|
from libs import helper
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from libs.helper import uuid_value
|
|
||||||
from libs.login import current_user
|
from libs.login import current_user
|
||||||
from models import Account
|
from models import Account
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
|
|
@ -38,28 +40,56 @@ from .. import console_ns
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CompletionMessageExplorePayload(BaseModel):
|
||||||
|
inputs: dict[str, Any]
|
||||||
|
query: str = ""
|
||||||
|
files: list[dict[str, Any]] | None = None
|
||||||
|
response_mode: Literal["blocking", "streaming"] | None = None
|
||||||
|
retriever_from: str = Field(default="explore_app")
|
||||||
|
|
||||||
|
|
||||||
|
class ChatMessagePayload(BaseModel):
|
||||||
|
inputs: dict[str, Any]
|
||||||
|
query: str
|
||||||
|
files: list[dict[str, Any]] | None = None
|
||||||
|
conversation_id: str | None = None
|
||||||
|
parent_message_id: str | None = None
|
||||||
|
retriever_from: str = Field(default="explore_app")
|
||||||
|
|
||||||
|
@field_validator("conversation_id", "parent_message_id", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def normalize_uuid(cls, value: str | UUID | None) -> str | None:
|
||||||
|
"""
|
||||||
|
Accept blank IDs and validate UUID format when provided.
|
||||||
|
"""
|
||||||
|
if not value:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
return helper.uuid_value(value)
|
||||||
|
except ValueError as exc:
|
||||||
|
raise ValueError("must be a valid UUID") from exc
|
||||||
|
|
||||||
|
|
||||||
|
register_schema_models(console_ns, CompletionMessageExplorePayload, ChatMessagePayload)
|
||||||
|
|
||||||
|
|
||||||
# define completion api for user
|
# define completion api for user
|
||||||
@console_ns.route(
|
@console_ns.route(
|
||||||
"/installed-apps/<uuid:installed_app_id>/completion-messages",
|
"/installed-apps/<uuid:installed_app_id>/completion-messages",
|
||||||
endpoint="installed_app_completion",
|
endpoint="installed_app_completion",
|
||||||
)
|
)
|
||||||
class CompletionApi(InstalledAppResource):
|
class CompletionApi(InstalledAppResource):
|
||||||
|
@console_ns.expect(console_ns.models[CompletionMessageExplorePayload.__name__])
|
||||||
def post(self, installed_app):
|
def post(self, installed_app):
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
if app_model.mode != AppMode.COMPLETION:
|
if app_model.mode != AppMode.COMPLETION:
|
||||||
raise NotCompletionAppError()
|
raise NotCompletionAppError()
|
||||||
|
|
||||||
parser = (
|
payload = CompletionMessageExplorePayload.model_validate(console_ns.payload or {})
|
||||||
reqparse.RequestParser()
|
args = payload.model_dump(exclude_none=True)
|
||||||
.add_argument("inputs", type=dict, required=True, location="json")
|
|
||||||
.add_argument("query", type=str, location="json", default="")
|
|
||||||
.add_argument("files", type=list, required=False, location="json")
|
|
||||||
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
|
||||||
.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
streaming = args["response_mode"] == "streaming"
|
streaming = payload.response_mode == "streaming"
|
||||||
args["auto_generate_name"] = False
|
args["auto_generate_name"] = False
|
||||||
|
|
||||||
installed_app.last_used_at = naive_utc_now()
|
installed_app.last_used_at = naive_utc_now()
|
||||||
|
|
@ -123,22 +153,15 @@ class CompletionStopApi(InstalledAppResource):
|
||||||
endpoint="installed_app_chat_completion",
|
endpoint="installed_app_chat_completion",
|
||||||
)
|
)
|
||||||
class ChatApi(InstalledAppResource):
|
class ChatApi(InstalledAppResource):
|
||||||
|
@console_ns.expect(console_ns.models[ChatMessagePayload.__name__])
|
||||||
def post(self, installed_app):
|
def post(self, installed_app):
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
app_mode = AppMode.value_of(app_model.mode)
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
parser = (
|
payload = ChatMessagePayload.model_validate(console_ns.payload or {})
|
||||||
reqparse.RequestParser()
|
args = payload.model_dump(exclude_none=True)
|
||||||
.add_argument("inputs", type=dict, required=True, location="json")
|
|
||||||
.add_argument("query", type=str, required=True, location="json")
|
|
||||||
.add_argument("files", type=list, required=False, location="json")
|
|
||||||
.add_argument("conversation_id", type=uuid_value, location="json")
|
|
||||||
.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
|
||||||
.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
args["auto_generate_name"] = False
|
args["auto_generate_name"] = False
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,18 @@
|
||||||
from flask_restx import marshal_with, reqparse
|
from typing import Any
|
||||||
from flask_restx.inputs import int_range
|
|
||||||
|
from flask import request
|
||||||
|
from flask_restx import marshal_with
|
||||||
|
from pydantic import BaseModel, Field, model_validator
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from werkzeug.exceptions import NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
|
from controllers.common.schema import register_schema_models
|
||||||
from controllers.console.explore.error import NotChatAppError
|
from controllers.console.explore.error import NotChatAppError
|
||||||
from controllers.console.explore.wraps import InstalledAppResource
|
from controllers.console.explore.wraps import InstalledAppResource
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
|
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
|
||||||
from libs.helper import uuid_value
|
from libs.helper import UUIDStrOrEmpty
|
||||||
from libs.login import current_user
|
from libs.login import current_user
|
||||||
from models import Account
|
from models import Account
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
|
|
@ -19,29 +23,51 @@ from services.web_conversation_service import WebConversationService
|
||||||
from .. import console_ns
|
from .. import console_ns
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationListQuery(BaseModel):
|
||||||
|
last_id: UUIDStrOrEmpty | None = None
|
||||||
|
limit: int = Field(default=20, ge=1, le=100)
|
||||||
|
pinned: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationRenamePayload(BaseModel):
|
||||||
|
name: str | None = None
|
||||||
|
auto_generate: bool = False
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def validate_name_requirement(self):
|
||||||
|
if not self.auto_generate:
|
||||||
|
if self.name is None or not self.name.strip():
|
||||||
|
raise ValueError("name is required when auto_generate is false")
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
register_schema_models(console_ns, ConversationListQuery, ConversationRenamePayload)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route(
|
@console_ns.route(
|
||||||
"/installed-apps/<uuid:installed_app_id>/conversations",
|
"/installed-apps/<uuid:installed_app_id>/conversations",
|
||||||
endpoint="installed_app_conversations",
|
endpoint="installed_app_conversations",
|
||||||
)
|
)
|
||||||
class ConversationListApi(InstalledAppResource):
|
class ConversationListApi(InstalledAppResource):
|
||||||
@marshal_with(conversation_infinite_scroll_pagination_fields)
|
@marshal_with(conversation_infinite_scroll_pagination_fields)
|
||||||
|
@console_ns.expect(console_ns.models[ConversationListQuery.__name__])
|
||||||
def get(self, installed_app):
|
def get(self, installed_app):
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
app_mode = AppMode.value_of(app_model.mode)
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
parser = (
|
raw_args: dict[str, Any] = {
|
||||||
reqparse.RequestParser()
|
"last_id": request.args.get("last_id"),
|
||||||
.add_argument("last_id", type=uuid_value, location="args")
|
"limit": request.args.get("limit", default=20, type=int),
|
||||||
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
"pinned": request.args.get("pinned"),
|
||||||
.add_argument("pinned", type=str, choices=["true", "false", None], location="args")
|
}
|
||||||
)
|
if raw_args["last_id"] is None:
|
||||||
args = parser.parse_args()
|
raw_args["last_id"] = None
|
||||||
|
pinned_value = raw_args["pinned"]
|
||||||
pinned = None
|
if isinstance(pinned_value, str):
|
||||||
if "pinned" in args and args["pinned"] is not None:
|
raw_args["pinned"] = pinned_value == "true"
|
||||||
pinned = args["pinned"] == "true"
|
args = ConversationListQuery.model_validate(raw_args)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not isinstance(current_user, Account):
|
if not isinstance(current_user, Account):
|
||||||
|
|
@ -51,10 +77,10 @@ class ConversationListApi(InstalledAppResource):
|
||||||
session=session,
|
session=session,
|
||||||
app_model=app_model,
|
app_model=app_model,
|
||||||
user=current_user,
|
user=current_user,
|
||||||
last_id=args["last_id"],
|
last_id=str(args.last_id) if args.last_id else None,
|
||||||
limit=args["limit"],
|
limit=args.limit,
|
||||||
invoke_from=InvokeFrom.EXPLORE,
|
invoke_from=InvokeFrom.EXPLORE,
|
||||||
pinned=pinned,
|
pinned=args.pinned,
|
||||||
)
|
)
|
||||||
except LastConversationNotExistsError:
|
except LastConversationNotExistsError:
|
||||||
raise NotFound("Last Conversation Not Exists.")
|
raise NotFound("Last Conversation Not Exists.")
|
||||||
|
|
@ -88,6 +114,7 @@ class ConversationApi(InstalledAppResource):
|
||||||
)
|
)
|
||||||
class ConversationRenameApi(InstalledAppResource):
|
class ConversationRenameApi(InstalledAppResource):
|
||||||
@marshal_with(simple_conversation_fields)
|
@marshal_with(simple_conversation_fields)
|
||||||
|
@console_ns.expect(console_ns.models[ConversationRenamePayload.__name__])
|
||||||
def post(self, installed_app, c_id):
|
def post(self, installed_app, c_id):
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
app_mode = AppMode.value_of(app_model.mode)
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
|
|
@ -96,18 +123,13 @@ class ConversationRenameApi(InstalledAppResource):
|
||||||
|
|
||||||
conversation_id = str(c_id)
|
conversation_id = str(c_id)
|
||||||
|
|
||||||
parser = (
|
payload = ConversationRenamePayload.model_validate(console_ns.payload or {})
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("name", type=str, required=False, location="json")
|
|
||||||
.add_argument("auto_generate", type=bool, required=False, default=False, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not isinstance(current_user, Account):
|
if not isinstance(current_user, Account):
|
||||||
raise ValueError("current_user must be an Account instance")
|
raise ValueError("current_user must be an Account instance")
|
||||||
return ConversationService.rename(
|
return ConversationService.rename(
|
||||||
app_model, conversation_id, current_user, args["name"], args["auto_generate"]
|
app_model, conversation_id, current_user, payload.name, payload.auto_generate
|
||||||
)
|
)
|
||||||
except ConversationNotExistsError:
|
except ConversationNotExistsError:
|
||||||
raise NotFound("Conversation Not Exists.")
|
raise NotFound("Conversation Not Exists.")
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,8 @@ import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource, inputs, marshal_with, reqparse
|
from flask_restx import Resource, marshal_with
|
||||||
|
from pydantic import BaseModel
|
||||||
from sqlalchemy import and_, select
|
from sqlalchemy import and_, select
|
||||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
|
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
|
||||||
|
|
||||||
|
|
@ -18,6 +19,15 @@ from services.account_service import TenantService
|
||||||
from services.enterprise.enterprise_service import EnterpriseService
|
from services.enterprise.enterprise_service import EnterpriseService
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
|
|
||||||
|
|
||||||
|
class InstalledAppCreatePayload(BaseModel):
|
||||||
|
app_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class InstalledAppUpdatePayload(BaseModel):
|
||||||
|
is_pinned: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -105,26 +115,25 @@ class InstalledAppsListApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_resource_check("apps")
|
@cloud_edition_billing_resource_check("apps")
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser().add_argument("app_id", type=str, required=True, help="Invalid app_id")
|
payload = InstalledAppCreatePayload.model_validate(console_ns.payload or {})
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == args["app_id"]).first()
|
recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == payload.app_id).first()
|
||||||
if recommended_app is None:
|
if recommended_app is None:
|
||||||
raise NotFound("App not found")
|
raise NotFound("Recommended app not found")
|
||||||
|
|
||||||
_, current_tenant_id = current_account_with_tenant()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
app = db.session.query(App).where(App.id == args["app_id"]).first()
|
app = db.session.query(App).where(App.id == payload.app_id).first()
|
||||||
|
|
||||||
if app is None:
|
if app is None:
|
||||||
raise NotFound("App not found")
|
raise NotFound("App entity not found")
|
||||||
|
|
||||||
if not app.is_public:
|
if not app.is_public:
|
||||||
raise Forbidden("You can't install a non-public app")
|
raise Forbidden("You can't install a non-public app")
|
||||||
|
|
||||||
installed_app = (
|
installed_app = (
|
||||||
db.session.query(InstalledApp)
|
db.session.query(InstalledApp)
|
||||||
.where(and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id))
|
.where(and_(InstalledApp.app_id == payload.app_id, InstalledApp.tenant_id == current_tenant_id))
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -133,7 +142,7 @@ class InstalledAppsListApi(Resource):
|
||||||
recommended_app.install_count += 1
|
recommended_app.install_count += 1
|
||||||
|
|
||||||
new_installed_app = InstalledApp(
|
new_installed_app = InstalledApp(
|
||||||
app_id=args["app_id"],
|
app_id=payload.app_id,
|
||||||
tenant_id=current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
app_owner_tenant_id=app.tenant_id,
|
app_owner_tenant_id=app.tenant_id,
|
||||||
is_pinned=False,
|
is_pinned=False,
|
||||||
|
|
@ -163,12 +172,11 @@ class InstalledAppApi(InstalledAppResource):
|
||||||
return {"result": "success", "message": "App uninstalled successfully"}, 204
|
return {"result": "success", "message": "App uninstalled successfully"}, 204
|
||||||
|
|
||||||
def patch(self, installed_app):
|
def patch(self, installed_app):
|
||||||
parser = reqparse.RequestParser().add_argument("is_pinned", type=inputs.boolean)
|
payload = InstalledAppUpdatePayload.model_validate(console_ns.payload or {})
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
commit_args = False
|
commit_args = False
|
||||||
if "is_pinned" in args:
|
if payload.is_pinned is not None:
|
||||||
installed_app.is_pinned = args["is_pinned"]
|
installed_app.is_pinned = payload.is_pinned
|
||||||
commit_args = True
|
commit_args = True
|
||||||
|
|
||||||
if commit_args:
|
if commit_args:
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,12 @@
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
from flask_restx import marshal_with, reqparse
|
from flask import request
|
||||||
from flask_restx.inputs import int_range
|
from flask_restx import marshal_with
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
from werkzeug.exceptions import InternalServerError, NotFound
|
from werkzeug.exceptions import InternalServerError, NotFound
|
||||||
|
|
||||||
|
from controllers.common.schema import register_schema_models
|
||||||
from controllers.console.app.error import (
|
from controllers.console.app.error import (
|
||||||
AppMoreLikeThisDisabledError,
|
AppMoreLikeThisDisabledError,
|
||||||
CompletionRequestError,
|
CompletionRequestError,
|
||||||
|
|
@ -22,7 +25,7 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
|
||||||
from core.model_runtime.errors.invoke import InvokeError
|
from core.model_runtime.errors.invoke import InvokeError
|
||||||
from fields.message_fields import message_infinite_scroll_pagination_fields
|
from fields.message_fields import message_infinite_scroll_pagination_fields
|
||||||
from libs import helper
|
from libs import helper
|
||||||
from libs.helper import uuid_value
|
from libs.helper import UUIDStrOrEmpty
|
||||||
from libs.login import current_account_with_tenant
|
from libs.login import current_account_with_tenant
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
from services.app_generate_service import AppGenerateService
|
from services.app_generate_service import AppGenerateService
|
||||||
|
|
@ -40,12 +43,31 @@ from .. import console_ns
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MessageListQuery(BaseModel):
|
||||||
|
conversation_id: UUIDStrOrEmpty
|
||||||
|
first_id: UUIDStrOrEmpty | None = None
|
||||||
|
limit: int = Field(default=20, ge=1, le=100)
|
||||||
|
|
||||||
|
|
||||||
|
class MessageFeedbackPayload(BaseModel):
|
||||||
|
rating: Literal["like", "dislike"] | None = None
|
||||||
|
content: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class MoreLikeThisQuery(BaseModel):
|
||||||
|
response_mode: Literal["blocking", "streaming"]
|
||||||
|
|
||||||
|
|
||||||
|
register_schema_models(console_ns, MessageListQuery, MessageFeedbackPayload, MoreLikeThisQuery)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route(
|
@console_ns.route(
|
||||||
"/installed-apps/<uuid:installed_app_id>/messages",
|
"/installed-apps/<uuid:installed_app_id>/messages",
|
||||||
endpoint="installed_app_messages",
|
endpoint="installed_app_messages",
|
||||||
)
|
)
|
||||||
class MessageListApi(InstalledAppResource):
|
class MessageListApi(InstalledAppResource):
|
||||||
@marshal_with(message_infinite_scroll_pagination_fields)
|
@marshal_with(message_infinite_scroll_pagination_fields)
|
||||||
|
@console_ns.expect(console_ns.models[MessageListQuery.__name__])
|
||||||
def get(self, installed_app):
|
def get(self, installed_app):
|
||||||
current_user, _ = current_account_with_tenant()
|
current_user, _ = current_account_with_tenant()
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
|
|
@ -53,18 +75,15 @@ class MessageListApi(InstalledAppResource):
|
||||||
app_mode = AppMode.value_of(app_model.mode)
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
args = MessageListQuery.model_validate(request.args.to_dict())
|
||||||
parser = (
|
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("conversation_id", required=True, type=uuid_value, location="args")
|
|
||||||
.add_argument("first_id", type=uuid_value, location="args")
|
|
||||||
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return MessageService.pagination_by_first_id(
|
return MessageService.pagination_by_first_id(
|
||||||
app_model, current_user, args["conversation_id"], args["first_id"], args["limit"]
|
app_model,
|
||||||
|
current_user,
|
||||||
|
str(args.conversation_id),
|
||||||
|
str(args.first_id) if args.first_id else None,
|
||||||
|
args.limit,
|
||||||
)
|
)
|
||||||
except ConversationNotExistsError:
|
except ConversationNotExistsError:
|
||||||
raise NotFound("Conversation Not Exists.")
|
raise NotFound("Conversation Not Exists.")
|
||||||
|
|
@ -77,26 +96,22 @@ class MessageListApi(InstalledAppResource):
|
||||||
endpoint="installed_app_message_feedback",
|
endpoint="installed_app_message_feedback",
|
||||||
)
|
)
|
||||||
class MessageFeedbackApi(InstalledAppResource):
|
class MessageFeedbackApi(InstalledAppResource):
|
||||||
|
@console_ns.expect(console_ns.models[MessageFeedbackPayload.__name__])
|
||||||
def post(self, installed_app, message_id):
|
def post(self, installed_app, message_id):
|
||||||
current_user, _ = current_account_with_tenant()
|
current_user, _ = current_account_with_tenant()
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
|
|
||||||
message_id = str(message_id)
|
message_id = str(message_id)
|
||||||
|
|
||||||
parser = (
|
payload = MessageFeedbackPayload.model_validate(console_ns.payload or {})
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
|
|
||||||
.add_argument("content", type=str, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
MessageService.create_feedback(
|
MessageService.create_feedback(
|
||||||
app_model=app_model,
|
app_model=app_model,
|
||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
user=current_user,
|
user=current_user,
|
||||||
rating=args.get("rating"),
|
rating=payload.rating,
|
||||||
content=args.get("content"),
|
content=payload.content,
|
||||||
)
|
)
|
||||||
except MessageNotExistsError:
|
except MessageNotExistsError:
|
||||||
raise NotFound("Message Not Exists.")
|
raise NotFound("Message Not Exists.")
|
||||||
|
|
@ -109,6 +124,7 @@ class MessageFeedbackApi(InstalledAppResource):
|
||||||
endpoint="installed_app_more_like_this",
|
endpoint="installed_app_more_like_this",
|
||||||
)
|
)
|
||||||
class MessageMoreLikeThisApi(InstalledAppResource):
|
class MessageMoreLikeThisApi(InstalledAppResource):
|
||||||
|
@console_ns.expect(console_ns.models[MoreLikeThisQuery.__name__])
|
||||||
def get(self, installed_app, message_id):
|
def get(self, installed_app, message_id):
|
||||||
current_user, _ = current_account_with_tenant()
|
current_user, _ = current_account_with_tenant()
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
|
|
@ -117,12 +133,9 @@ class MessageMoreLikeThisApi(InstalledAppResource):
|
||||||
|
|
||||||
message_id = str(message_id)
|
message_id = str(message_id)
|
||||||
|
|
||||||
parser = reqparse.RequestParser().add_argument(
|
args = MoreLikeThisQuery.model_validate(request.args.to_dict())
|
||||||
"response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args"
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
streaming = args["response_mode"] == "streaming"
|
streaming = args.response_mode == "streaming"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = AppGenerateService.generate_more_like_this(
|
response = AppGenerateService.generate_more_like_this(
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,6 @@
|
||||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
from flask import request
|
||||||
|
from flask_restx import Resource, fields, marshal_with
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from constants.languages import languages
|
from constants.languages import languages
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
|
|
@ -35,20 +37,26 @@ recommended_app_list_fields = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
parser_apps = reqparse.RequestParser().add_argument("language", type=str, location="args")
|
class RecommendedAppsQuery(BaseModel):
|
||||||
|
language: str | None = Field(default=None)
|
||||||
|
|
||||||
|
|
||||||
|
console_ns.schema_model(
|
||||||
|
RecommendedAppsQuery.__name__,
|
||||||
|
RecommendedAppsQuery.model_json_schema(ref_template="#/definitions/{model}"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/explore/apps")
|
@console_ns.route("/explore/apps")
|
||||||
class RecommendedAppListApi(Resource):
|
class RecommendedAppListApi(Resource):
|
||||||
@console_ns.expect(parser_apps)
|
@console_ns.expect(console_ns.models[RecommendedAppsQuery.__name__])
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(recommended_app_list_fields)
|
@marshal_with(recommended_app_list_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
# language args
|
# language args
|
||||||
args = parser_apps.parse_args()
|
args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
|
language = args.language
|
||||||
language = args.get("language")
|
|
||||||
if language and language in languages:
|
if language and language in languages:
|
||||||
language_prefix = language
|
language_prefix = language
|
||||||
elif current_user and current_user.interface_language:
|
elif current_user and current_user.interface_language:
|
||||||
|
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue