mirror of https://github.com/langgenius/dify.git
Merge branch 'main' into feat/add-qdrant-to-tidb-migration
commit 9cefbc60f2

@@ -0,0 +1,205 @@
# Test Generation Checklist

Use this checklist when generating or reviewing tests for Dify frontend components.

## Pre-Generation

- [ ] Read the component source code completely
- [ ] Identify component type (component, hook, utility, page)
- [ ] Run `pnpm analyze-component <path>` if available
- [ ] Note complexity score and features detected
- [ ] Check for existing tests in the same directory
- [ ] **Identify ALL files in the directory** that need testing (not just index)

## Testing Strategy

### ⚠️ Incremental Workflow (CRITICAL for Multi-File)

- [ ] **NEVER generate all tests at once** - process one file at a time
- [ ] Order files by complexity: utilities → hooks → simple → complex → integration
- [ ] Create a todo list to track progress before starting
- [ ] For EACH file: write → run test → verify pass → then next
- [ ] **DO NOT proceed** to the next file until the current one passes

### Path-Level Coverage

- [ ] **Test ALL files** in the assigned directory/path
- [ ] List all components, hooks, and utilities that need coverage
- [ ] Decide: single spec file (integration) or multiple spec files (unit)

### Complexity Assessment

- [ ] Run `pnpm analyze-component <path>` for a complexity score
- [ ] **Complexity > 50**: consider refactoring before testing
- [ ] **500+ lines**: consider splitting before testing
- [ ] **Complexity 30-50**: use multiple describe blocks and an organized structure

### Integration vs Mocking

- [ ] **DO NOT mock base components** (`Loading`, `Button`, `Tooltip`, etc.)
- [ ] Import real project components instead of mocking them
- [ ] Only mock: API calls, complex context providers, third-party libs with side effects
- [ ] Prefer integration testing when using a single spec file

## Required Test Sections

### All Components MUST Have

- [ ] **Rendering tests** - component renders without crashing
- [ ] **Props tests** - required props, optional props, default values
- [ ] **Edge cases** - null, undefined, empty values, boundaries

### Conditional Sections (Add When Feature Present)

| Feature | Add Tests For |
|---------|---------------|
| `useState` | Initial state, transitions, cleanup |
| `useEffect` | Execution, dependencies, cleanup |
| Event handlers | onClick, onChange, onSubmit, keyboard |
| API calls | Loading, success, error states |
| Routing | Navigation, params, query strings |
| `useCallback`/`useMemo` | Referential equality |
| Context | Provider values, consumer behavior |
| Forms | Validation, submission, error display |

## Code Quality Checklist

### Structure

- [ ] Uses `describe` blocks to group related tests
- [ ] Test names follow the `should <behavior> when <condition>` pattern
- [ ] AAA pattern (Arrange-Act-Assert) is clear
- [ ] Comments explain complex test scenarios

### Mocks

- [ ] **DO NOT mock base components** (`@/app/components/base/*`)
- [ ] `jest.clearAllMocks()` in `beforeEach` (not `afterEach`)
- [ ] Shared mock state reset in `beforeEach`
- [ ] i18n mock returns keys, not empty strings (see the sketch below)
- [ ] Router mocks match the actual Next.js API
- [ ] Mocks reflect the component's actual conditional behavior
- [ ] Only mock: API services, complex context providers, third-party libs
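
A minimal sketch of an i18n mock that returns keys, assuming the app uses `react-i18next` (adjust the module path to whatever the codebase actually imports):

```typescript
// Hedged sketch: return the translation key itself so assertions can
// match on stable keys instead of translated copy.
jest.mock('react-i18next', () => ({
  useTranslation: () => ({
    t: (key: string) => key,
    i18n: { language: 'en' },
  }),
}))
```
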
### Queries

- [ ] Prefer semantic queries (`getByRole`, `getByLabelText`)
- [ ] Use `queryBy*` for absence assertions
- [ ] Use `findBy*` for async elements
- [ ] `getByTestId` only as a last resort (see the combined sketch below)
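
A compact sketch of the three query families in one test (component and copy are illustrative):

```typescript
import { fireEvent, render, screen } from '@testing-library/react'
import SavePanel from './save-panel' // illustrative component

it('should report save status', async () => {
  render(<SavePanel />)

  // getBy*: element expected to be present right now
  const saveButton = screen.getByRole('button', { name: /save/i })
  expect(saveButton).toBeEnabled()

  // queryBy*: returns null instead of throwing, so safe for absence checks
  expect(screen.queryByText(/error/i)).not.toBeInTheDocument()

  // findBy*: awaits async appearance after the action
  fireEvent.click(saveButton)
  expect(await screen.findByText(/saved/i)).toBeInTheDocument()
})
```
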
### Async

- [ ] All async tests use `async/await`
- [ ] `waitFor` wraps async assertions
- [ ] Fake timers are properly set up and torn down
- [ ] No floating promises

### TypeScript

- [ ] No `any` types without justification
- [ ] Mock data uses actual types from source
- [ ] Factory functions have proper return types (see the sketch below)
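
A sketch of a typed mock factory; the type import path and fields are illustrative, so substitute the real type from the source under test:

```typescript
import type { AppConfig } from '@/types/app' // illustrative path

// Explicit return type, so drifting mock data fails to compile.
const createMockAppConfig = (overrides: Partial<AppConfig> = {}): AppConfig => ({
  name: 'Test App',       // illustrative fields; mirror the actual type
  description: '',
  icon: 'default',
  ...overrides,
})
```
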
## Coverage Goals (Per File)

For the current file being tested:

- [ ] 100% function coverage
- [ ] 100% statement coverage
- [ ] >95% branch coverage
- [ ] >95% line coverage

## Post-Generation (Per File)

**Run these checks after EACH test file, not just at the end:**

- [ ] Run `pnpm test -- path/to/file.spec.tsx` - **MUST PASS before next file**
- [ ] Fix any failures immediately
- [ ] Mark file as complete in todo list
- [ ] Only then proceed to next file

### After All Files Complete

- [ ] Run full directory test: `pnpm test -- path/to/directory/`
- [ ] Check coverage report: `pnpm test -- --coverage`
- [ ] Run `pnpm lint:fix` on all test files
- [ ] Run `pnpm type-check:tsgo`

## Common Issues to Watch

### False Positives

```typescript
// ❌ Mock doesn't match actual behavior
jest.mock('./Component', () => () => <div>Mocked</div>)

// ✅ Mock matches actual conditional logic
jest.mock('./Component', () => ({ isOpen }: any) =>
  isOpen ? <div>Content</div> : null
)
```

### State Leakage

```typescript
// ❌ Shared state not reset
let mockState = false
jest.mock('./useHook', () => () => mockState)

// ✅ Reset in beforeEach
beforeEach(() => {
  mockState = false
})
```

### Async Race Conditions

```typescript
// ❌ Not awaited
it('loads data', () => {
  render(<Component />)
  expect(screen.getByText('Data')).toBeInTheDocument()
})

// ✅ Properly awaited
it('loads data', async () => {
  render(<Component />)
  await waitFor(() => {
    expect(screen.getByText('Data')).toBeInTheDocument()
  })
})
```

### Missing Edge Cases

Always test these scenarios:

- `null` / `undefined` inputs
- Empty strings / arrays / objects
- Boundary values (0, -1, MAX_INT)
- Error states
- Loading states
- Disabled states

## Quick Commands

```bash
# Run specific test
pnpm test -- path/to/file.spec.tsx

# Run with coverage
pnpm test -- --coverage path/to/file.spec.tsx

# Watch mode
pnpm test -- --watch path/to/file.spec.tsx

# Update snapshots (use sparingly)
pnpm test -- -u path/to/file.spec.tsx

# Analyze component
pnpm analyze-component path/to/component.tsx

# Review existing test
pnpm analyze-component path/to/component.tsx --review
```

@@ -0,0 +1,320 @@
---
name: Dify Frontend Testing
description: Generate Jest + React Testing Library tests for Dify frontend components, hooks, and utilities. Triggers on testing, spec files, coverage, Jest, RTL, unit tests, integration tests, or write/review test requests.
---

# Dify Frontend Testing Skill

This skill enables Claude to generate high-quality, comprehensive frontend tests for the Dify project following established conventions and best practices.

> **⚠️ Authoritative Source**: This skill is derived from `web/testing/testing.md`. When in doubt, always refer to that document as the canonical specification.

## When to Apply This Skill

Apply this skill when the user:

- Asks to **write tests** for a component, hook, or utility
- Asks to **review existing tests** for completeness
- Mentions **Jest**, **React Testing Library**, **RTL**, or **spec files**
- Requests **test coverage** improvement
- Uses `pnpm analyze-component` output as context
- Mentions **testing**, **unit tests**, or **integration tests** for frontend code
- Wants to understand **testing patterns** in the Dify codebase

**Do NOT apply** when:

- User is asking about backend/API tests (Python/pytest)
- User is asking about E2E tests (Playwright/Cypress)
- User is only asking conceptual questions without code context

## Quick Reference

### Tech Stack

| Tool | Version | Purpose |
|------|---------|---------|
| Jest | 29.7 | Test runner |
| React Testing Library | 16.0 | Component testing |
| happy-dom | - | Test environment |
| nock | 14.0 | HTTP mocking |
| TypeScript | 5.x | Type safety |

### Key Commands

```bash
# Run all tests
pnpm test

# Watch mode
pnpm test -- --watch

# Run specific file
pnpm test -- path/to/file.spec.tsx

# Generate coverage report
pnpm test -- --coverage

# Analyze component complexity
pnpm analyze-component <path>

# Review existing test
pnpm analyze-component <path> --review
```

### File Naming

- Test files: `ComponentName.spec.tsx` (same directory as component)
- Integration tests: `web/__tests__/` directory
## Test Structure Template

```typescript
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
import Component from './index'

// ✅ Import real project components (DO NOT mock these)
// import Loading from '@/app/components/base/loading'
// import { ChildComponent } from './child-component'

// ✅ Mock external dependencies only
jest.mock('@/service/api')
jest.mock('next/navigation', () => ({
  useRouter: () => ({ push: jest.fn() }),
  usePathname: () => '/test',
}))

// Shared state for mocks (if needed)
let mockSharedState = false

describe('ComponentName', () => {
  beforeEach(() => {
    jest.clearAllMocks() // ✅ Reset mocks BEFORE each test
    mockSharedState = false // ✅ Reset shared state
  })

  // Rendering tests (REQUIRED)
  describe('Rendering', () => {
    it('should render without crashing', () => {
      // Arrange
      const props = { title: 'Test' }

      // Act
      render(<Component {...props} />)

      // Assert
      expect(screen.getByText('Test')).toBeInTheDocument()
    })
  })

  // Props tests (REQUIRED)
  describe('Props', () => {
    it('should apply custom className', () => {
      render(<Component className="custom" />)
      expect(screen.getByRole('button')).toHaveClass('custom')
    })
  })

  // User Interactions
  describe('User Interactions', () => {
    it('should handle click events', () => {
      const handleClick = jest.fn()
      render(<Component onClick={handleClick} />)

      fireEvent.click(screen.getByRole('button'))

      expect(handleClick).toHaveBeenCalledTimes(1)
    })
  })

  // Edge Cases (REQUIRED)
  describe('Edge Cases', () => {
    it('should handle null data', () => {
      render(<Component data={null} />)
      expect(screen.getByText(/no data/i)).toBeInTheDocument()
    })

    it('should handle empty array', () => {
      render(<Component items={[]} />)
      expect(screen.getByText(/empty/i)).toBeInTheDocument()
    })
  })
})
```

## Testing Workflow (CRITICAL)

### ⚠️ Incremental Approach Required

**NEVER generate all test files at once.** For complex components or multi-file directories:

1. **Analyze & Plan**: List all files, order by complexity (simple → complex)
1. **Process ONE at a time**: Write test → Run test → Fix if needed → Next
1. **Verify before proceeding**: Do NOT continue to the next file until the current one passes

```
For each file:
┌────────────────────────────────────────┐
│ 1. Write test                          │
│ 2. Run: pnpm test -- <file>.spec.tsx   │
│ 3. PASS? → Mark complete, next file    │
│    FAIL? → Fix first, then continue    │
└────────────────────────────────────────┘
```

### Complexity-Based Order

Process in this order for multi-file testing:

1. 🟢 Utility functions (simplest)
1. 🟢 Custom hooks
1. 🟡 Simple components (presentational)
1. 🟡 Medium components (state, effects)
1. 🔴 Complex components (API, routing)
1. 🔴 Integration tests (index files - last)

### When to Refactor First

- **Complexity > 50**: Break into smaller pieces before testing
- **500+ lines**: Consider splitting before testing
- **Many dependencies**: Extract logic into hooks first (see the sketch below)
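
A minimal sketch of that extraction, with illustrative names; the service import is assumed, not a confirmed Dify API:

```typescript
import { useEffect, useState } from 'react'
import { fetchDocumentCount } from '@/service/datasets' // illustrative import

// Logic pulled out of a large component so it can be tested
// on its own with renderHook.
export function useDocumentCount(datasetId: string) {
  const [count, setCount] = useState<number | null>(null)

  useEffect(() => {
    let cancelled = false
    fetchDocumentCount(datasetId).then((n) => {
      if (!cancelled) setCount(n)
    })
    return () => { cancelled = true } // avoid setState after unmount
  }, [datasetId])

  return count
}
```
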
> 📖 See `guides/workflow.md` for complete workflow details and the todo list format.

## Testing Strategy

### Path-Level Testing (Directory Testing)

When assigned to test a directory/path, test **ALL content** within that path:

- Test all components, hooks, and utilities in the directory (not just the `index` file)
- Use the incremental approach: one file at a time, verifying each before proceeding
- Goal: 100% coverage of ALL files in the directory

### Integration Testing First

**Prefer integration testing** when writing tests for a directory:

- ✅ **Import real project components** directly (including base components and siblings)
- ✅ **Only mock**: API services (`@/service/*`), `next/navigation`, complex context providers
- ❌ **DO NOT mock** base components (`@/app/components/base/*`)
- ❌ **DO NOT mock** sibling/child components in the same directory

> See [Test Structure Template](#test-structure-template) for correct import/mock patterns.

## Core Principles

### 1. AAA Pattern (Arrange-Act-Assert)

Every test should clearly separate:

- **Arrange**: Set up test data and render the component
- **Act**: Perform user actions
- **Assert**: Verify expected outcomes

### 2. Black-Box Testing

- Test observable behavior, not implementation details
- Use semantic queries (`getByRole`, `getByLabelText`)
- Avoid testing internal state directly
- **Prefer pattern matching over hardcoded strings** in assertions:

```typescript
// ❌ Avoid: hardcoded text assertions
expect(screen.getByText('Loading...')).toBeInTheDocument()

// ✅ Better: role-based queries
expect(screen.getByRole('status')).toBeInTheDocument()

// ✅ Better: pattern matching
expect(screen.getByText(/loading/i)).toBeInTheDocument()
```

### 3. Single Behavior Per Test

Each test verifies ONE user-observable behavior:

```typescript
// ✅ Good: One behavior
it('should disable button when loading', () => {
  render(<Button loading />)
  expect(screen.getByRole('button')).toBeDisabled()
})

// ❌ Bad: Multiple behaviors
it('should handle loading state', () => {
  render(<Button loading />)
  expect(screen.getByRole('button')).toBeDisabled()
  expect(screen.getByText('Loading...')).toBeInTheDocument()
  expect(screen.getByRole('button')).toHaveClass('loading')
})
```

### 4. Semantic Naming

Use `should <behavior> when <condition>`:

```typescript
it('should show error message when validation fails')
it('should call onSubmit when form is valid')
it('should disable input when isReadOnly is true')
```

## Required Test Scenarios

### Always Required (All Components)

1. **Rendering**: Component renders without crashing
1. **Props**: Required props, optional props, default values
1. **Edge Cases**: null, undefined, empty values, boundary conditions

### Conditional (When Present)

| Feature | Test Focus |
|---------|-----------|
| `useState` | Initial state, transitions, cleanup |
| `useEffect` | Execution, dependencies, cleanup |
| Event handlers | All onClick, onChange, onSubmit, keyboard |
| API calls | Loading, success, error states |
| Routing | Navigation, params, query strings |
| `useCallback`/`useMemo` | Referential equality |
| Context | Provider values, consumer behavior |
| Forms | Validation, submission, error display |
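
The `useCallback`/`useMemo` row above is the one scenario the templates in this skill do not illustrate; a minimal sketch with `renderHook` (the hook name is illustrative):

```typescript
import { renderHook } from '@testing-library/react'
import { useItemHandlers } from './hooks' // illustrative hook

it('should keep the same handler reference across re-renders', () => {
  const { result, rerender } = renderHook(() => useItemHandlers())
  const firstHandler = result.current.onSelect

  rerender()

  // useCallback with stable deps should return the identical function
  expect(result.current.onSelect).toBe(firstHandler)
})
```
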
## Coverage Goals (Per File)

For each test file generated, aim for:

- ✅ **100%** function coverage
- ✅ **100%** statement coverage
- ✅ **>95%** branch coverage
- ✅ **>95%** line coverage

> **Note**: For multi-file directories, process one file at a time with full coverage each. See `guides/workflow.md`.

## Detailed Guides

For more detailed information, refer to:

- `guides/workflow.md` - **Incremental testing workflow** (MUST READ for multi-file testing)
- `guides/mocking.md` - Mock patterns and best practices
- `guides/async-testing.md` - Async operations and API calls
- `guides/domain-components.md` - Workflow, Dataset, Configuration testing
- `guides/common-patterns.md` - Frequently used testing patterns

## Authoritative References

### Primary Specification (MUST follow)

- **`web/testing/testing.md`** - The canonical testing specification. This skill is derived from this document.

### Reference Examples in Codebase

- `web/utils/classnames.spec.ts` - Utility function tests
- `web/app/components/base/button/index.spec.tsx` - Component tests
- `web/__mocks__/provider-context.ts` - Mock factory example

### Project Configuration

- `web/jest.config.ts` - Jest configuration
- `web/jest.setup.ts` - Test environment setup
- `web/testing/analyze-component.js` - Component analysis tool

@@ -0,0 +1,345 @@
# Async Testing Guide

## Core Async Patterns

### 1. waitFor - Wait for Condition

```typescript
import { render, screen, waitFor } from '@testing-library/react'

it('should load and display data', async () => {
  render(<DataComponent />)

  // Wait for element to appear
  await waitFor(() => {
    expect(screen.getByText('Loaded Data')).toBeInTheDocument()
  })
})

it('should hide loading spinner after load', async () => {
  render(<DataComponent />)

  // Wait for element to disappear
  await waitFor(() => {
    expect(screen.queryByText('Loading...')).not.toBeInTheDocument()
  })
})
```

### 2. findBy\* - Async Queries

```typescript
it('should show user name after fetch', async () => {
  render(<UserProfile />)

  // findBy returns a promise, auto-waits up to 1000ms by default
  const userName = await screen.findByText('John Doe')
  expect(userName).toBeInTheDocument()

  // findByRole with options
  const button = await screen.findByRole('button', { name: /submit/i })
  expect(button).toBeEnabled()
})
```

### 3. userEvent for Async Interactions

```typescript
import userEvent from '@testing-library/user-event'

it('should submit form', async () => {
  const user = userEvent.setup()
  const onSubmit = jest.fn()

  render(<Form onSubmit={onSubmit} />)

  // userEvent methods are async
  await user.type(screen.getByLabelText('Email'), 'test@example.com')
  await user.click(screen.getByRole('button', { name: /submit/i }))

  await waitFor(() => {
    expect(onSubmit).toHaveBeenCalledWith({ email: 'test@example.com' })
  })
})
```

## Fake Timers

### When to Use Fake Timers

- Testing components with `setTimeout`/`setInterval`
- Testing debounce/throttle behavior
- Testing animations or delayed transitions
- Testing polling or retry logic

### Basic Fake Timer Setup

```typescript
describe('Debounced Search', () => {
  beforeEach(() => {
    jest.useFakeTimers()
  })

  afterEach(() => {
    jest.useRealTimers()
  })

  it('should debounce search input', async () => {
    const onSearch = jest.fn()
    render(<SearchInput onSearch={onSearch} debounceMs={300} />)

    // Type in the input
    fireEvent.change(screen.getByRole('textbox'), { target: { value: 'query' } })

    // Search not called immediately
    expect(onSearch).not.toHaveBeenCalled()

    // Advance timers
    jest.advanceTimersByTime(300)

    // Now search is called
    expect(onSearch).toHaveBeenCalledWith('query')
  })
})
```

### Fake Timers with Async Code

```typescript
it('should retry on failure', async () => {
  jest.useFakeTimers()
  const fetchData = jest.fn()
    .mockRejectedValueOnce(new Error('Network error'))
    .mockResolvedValueOnce({ data: 'success' })

  render(<RetryComponent fetchData={fetchData} retryDelayMs={1000} />)

  // First call fails
  await waitFor(() => {
    expect(fetchData).toHaveBeenCalledTimes(1)
  })

  // Advance timer for retry
  jest.advanceTimersByTime(1000)

  // Second call succeeds
  await waitFor(() => {
    expect(fetchData).toHaveBeenCalledTimes(2)
    expect(screen.getByText('success')).toBeInTheDocument()
  })

  jest.useRealTimers()
})
```

### Common Fake Timer Utilities

```typescript
// Run all pending timers
jest.runAllTimers()

// Run only pending timers (not new ones created during execution)
jest.runOnlyPendingTimers()

// Advance by specific time
jest.advanceTimersByTime(1000)

// Get current fake time
jest.now()

// Clear all timers
jest.clearAllTimers()
```
## API Testing Patterns

### Loading → Success → Error States

```typescript
// Assumes the API module is already mocked, e.g.:
// jest.mock('@/service/api'); const mockedApi = jest.mocked(api)
describe('DataFetcher', () => {
  beforeEach(() => {
    jest.clearAllMocks()
  })

  it('should show loading state', () => {
    mockedApi.fetchData.mockImplementation(() => new Promise(() => {})) // Never resolves

    render(<DataFetcher />)

    expect(screen.getByTestId('loading-spinner')).toBeInTheDocument()
  })

  it('should show data on success', async () => {
    mockedApi.fetchData.mockResolvedValue({ items: ['Item 1', 'Item 2'] })

    render(<DataFetcher />)

    // Use findBy* for multiple async elements (better error messages
    // than waitFor with multiple assertions)
    const item1 = await screen.findByText('Item 1')
    const item2 = await screen.findByText('Item 2')
    expect(item1).toBeInTheDocument()
    expect(item2).toBeInTheDocument()

    expect(screen.queryByTestId('loading-spinner')).not.toBeInTheDocument()
  })

  it('should show error on failure', async () => {
    mockedApi.fetchData.mockRejectedValue(new Error('Failed to fetch'))

    render(<DataFetcher />)

    await waitFor(() => {
      expect(screen.getByText(/failed to fetch/i)).toBeInTheDocument()
    })
  })

  it('should retry on error', async () => {
    mockedApi.fetchData.mockRejectedValue(new Error('Network error'))

    render(<DataFetcher />)

    await waitFor(() => {
      expect(screen.getByRole('button', { name: /retry/i })).toBeInTheDocument()
    })

    mockedApi.fetchData.mockResolvedValue({ items: ['Item 1'] })
    fireEvent.click(screen.getByRole('button', { name: /retry/i }))

    await waitFor(() => {
      expect(screen.getByText('Item 1')).toBeInTheDocument()
    })
  })
})
```
### Testing Mutations

```typescript
it('should submit form and show success', async () => {
  const user = userEvent.setup()
  mockedApi.createItem.mockResolvedValue({ id: '1', name: 'New Item' })

  render(<CreateItemForm />)

  await user.type(screen.getByLabelText('Name'), 'New Item')
  await user.click(screen.getByRole('button', { name: /create/i }))

  // Button should be disabled during submission
  expect(screen.getByRole('button', { name: /creating/i })).toBeDisabled()

  await waitFor(() => {
    expect(screen.getByText(/created successfully/i)).toBeInTheDocument()
  })

  expect(mockedApi.createItem).toHaveBeenCalledWith({ name: 'New Item' })
})
```

## useEffect Testing

### Testing Effect Execution

```typescript
it('should fetch data on mount', async () => {
  const fetchData = jest.fn().mockResolvedValue({ data: 'test' })

  render(<ComponentWithEffect fetchData={fetchData} />)

  await waitFor(() => {
    expect(fetchData).toHaveBeenCalledTimes(1)
  })
})
```

### Testing Effect Dependencies

```typescript
it('should refetch when id changes', async () => {
  const fetchData = jest.fn().mockResolvedValue({ data: 'test' })

  const { rerender } = render(<ComponentWithEffect id="1" fetchData={fetchData} />)

  await waitFor(() => {
    expect(fetchData).toHaveBeenCalledWith('1')
  })

  rerender(<ComponentWithEffect id="2" fetchData={fetchData} />)

  await waitFor(() => {
    expect(fetchData).toHaveBeenCalledWith('2')
    expect(fetchData).toHaveBeenCalledTimes(2)
  })
})
```

### Testing Effect Cleanup

```typescript
it('should cleanup subscription on unmount', () => {
  const subscribe = jest.fn()
  const unsubscribe = jest.fn()
  subscribe.mockReturnValue(unsubscribe)

  const { unmount } = render(<SubscriptionComponent subscribe={subscribe} />)

  expect(subscribe).toHaveBeenCalledTimes(1)

  unmount()

  expect(unsubscribe).toHaveBeenCalledTimes(1)
})
```

## Common Async Pitfalls

### ❌ Don't: Forget to await

```typescript
// Bad - test may pass even if assertion fails
it('should load data', () => {
  render(<Component />)
  waitFor(() => {
    expect(screen.getByText('Data')).toBeInTheDocument()
  })
})

// Good - properly awaited
it('should load data', async () => {
  render(<Component />)
  await waitFor(() => {
    expect(screen.getByText('Data')).toBeInTheDocument()
  })
})
```

### ❌ Don't: Use multiple assertions in single waitFor

```typescript
// Bad - if first assertion fails, won't know about second
await waitFor(() => {
  expect(screen.getByText('Title')).toBeInTheDocument()
  expect(screen.getByText('Description')).toBeInTheDocument()
})

// Good - separate waitFor or use findBy
const title = await screen.findByText('Title')
const description = await screen.findByText('Description')
expect(title).toBeInTheDocument()
expect(description).toBeInTheDocument()
```

### ❌ Don't: Mix fake timers with real async

```typescript
// Bad - fake timers don't work well with real Promises
jest.useFakeTimers()
await waitFor(() => {
  expect(screen.getByText('Data')).toBeInTheDocument()
}) // May timeout!

// Good - use runAllTimers or advanceTimersByTime
jest.useFakeTimers()
render(<Component />)
jest.runAllTimers()
expect(screen.getByText('Data')).toBeInTheDocument()
```
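
When the timer callbacks themselves resolve promises, the async timer APIs added in Jest 29 let fake timers and awaited assertions coexist; a hedged sketch:

```typescript
jest.useFakeTimers()
render(<Component />)

// Flushes timers AND the microtasks their callbacks schedule
await jest.runAllTimersAsync()

expect(screen.getByText('Data')).toBeInTheDocument()
jest.useRealTimers()
```
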
@@ -0,0 +1,449 @@
# Common Testing Patterns

## Query Priority

Use queries in this order (most to least preferred):

```typescript
// 1. getByRole - Most recommended (accessibility)
screen.getByRole('button', { name: /submit/i })
screen.getByRole('textbox', { name: /email/i })
screen.getByRole('heading', { level: 1 })

// 2. getByLabelText - Form fields
screen.getByLabelText('Email address')
screen.getByLabelText(/password/i)

// 3. getByPlaceholderText - When no label
screen.getByPlaceholderText('Search...')

// 4. getByText - Non-interactive elements
screen.getByText('Welcome to Dify')
screen.getByText(/loading/i)

// 5. getByDisplayValue - Current input value
screen.getByDisplayValue('current value')

// 6. getByAltText - Images
screen.getByAltText('Company logo')

// 7. getByTitle - Tooltip elements
screen.getByTitle('Close')

// 8. getByTestId - Last resort only!
screen.getByTestId('custom-element')
```

## Event Handling Patterns

### Click Events

```typescript
// Basic click
fireEvent.click(screen.getByRole('button'))

// With userEvent (preferred for realistic interaction)
const user = userEvent.setup()
await user.click(screen.getByRole('button'))

// Double click
await user.dblClick(screen.getByRole('button'))

// Right click
await user.pointer({ keys: '[MouseRight]', target: screen.getByRole('button') })
```

### Form Input

```typescript
const user = userEvent.setup()

// Type in input
await user.type(screen.getByRole('textbox'), 'Hello World')

// Clear and type
await user.clear(screen.getByRole('textbox'))
await user.type(screen.getByRole('textbox'), 'New value')

// Select option
await user.selectOptions(screen.getByRole('combobox'), 'option-value')

// Check checkbox
await user.click(screen.getByRole('checkbox'))

// Upload file
const file = new File(['content'], 'test.pdf', { type: 'application/pdf' })
await user.upload(screen.getByLabelText(/upload/i), file)
```

### Keyboard Events

```typescript
const user = userEvent.setup()

// Press Enter
await user.keyboard('{Enter}')

// Press Escape
await user.keyboard('{Escape}')

// Keyboard shortcut
await user.keyboard('{Control>}a{/Control}') // Ctrl+A

// Tab navigation
await user.tab()

// Arrow keys
await user.keyboard('{ArrowDown}')
await user.keyboard('{ArrowUp}')
```

## Component State Testing

### Testing State Transitions

```typescript
describe('Counter', () => {
  it('should increment count', async () => {
    const user = userEvent.setup()
    render(<Counter initialCount={0} />)

    // Initial state
    expect(screen.getByText('Count: 0')).toBeInTheDocument()

    // Trigger transition
    await user.click(screen.getByRole('button', { name: /increment/i }))

    // New state
    expect(screen.getByText('Count: 1')).toBeInTheDocument()
  })
})
```

### Testing Controlled Components

```typescript
describe('ControlledInput', () => {
  it('should call onChange with new value', async () => {
    const user = userEvent.setup()
    const handleChange = jest.fn()

    render(<ControlledInput value="" onChange={handleChange} />)

    await user.type(screen.getByRole('textbox'), 'a')

    expect(handleChange).toHaveBeenCalledWith('a')
  })

  it('should display controlled value', () => {
    render(<ControlledInput value="controlled" onChange={jest.fn()} />)

    expect(screen.getByRole('textbox')).toHaveValue('controlled')
  })
})
```
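
## Context Testing

The conditional-feature tables in this skill also call for context tests; a minimal sketch (the context here is illustrative; in real tests, import the app's actual provider):

```typescript
import { render, screen } from '@testing-library/react'
import { createContext, useContext } from 'react'

// Illustrative context and consumer, standing in for a real provider
const ThemeContext = createContext<'light' | 'dark'>('light')
const ThemeLabel = () => <div>{useContext(ThemeContext)}</div>

it('should expose the provider value to consumers', () => {
  render(
    <ThemeContext.Provider value="dark">
      <ThemeLabel />
    </ThemeContext.Provider>,
  )

  expect(screen.getByText('dark')).toBeInTheDocument()
})

it('should fall back to the default value without a provider', () => {
  render(<ThemeLabel />)

  expect(screen.getByText('light')).toBeInTheDocument()
})
```
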
## Conditional Rendering Testing

```typescript
describe('ConditionalComponent', () => {
  it('should show loading state', () => {
    render(<DataDisplay isLoading={true} data={null} />)

    expect(screen.getByText(/loading/i)).toBeInTheDocument()
    expect(screen.queryByTestId('data-content')).not.toBeInTheDocument()
  })

  it('should show error state', () => {
    render(<DataDisplay isLoading={false} data={null} error="Failed to load" />)

    expect(screen.getByText(/failed to load/i)).toBeInTheDocument()
  })

  it('should show data when loaded', () => {
    render(<DataDisplay isLoading={false} data={{ name: 'Test' }} />)

    expect(screen.getByText('Test')).toBeInTheDocument()
  })

  it('should show empty state when no data', () => {
    render(<DataDisplay isLoading={false} data={[]} />)

    expect(screen.getByText(/no data/i)).toBeInTheDocument()
  })
})
```

## List Rendering Testing

```typescript
describe('ItemList', () => {
  const items = [
    { id: '1', name: 'Item 1' },
    { id: '2', name: 'Item 2' },
    { id: '3', name: 'Item 3' },
  ]

  it('should render all items', () => {
    render(<ItemList items={items} />)

    expect(screen.getAllByRole('listitem')).toHaveLength(3)
    items.forEach(item => {
      expect(screen.getByText(item.name)).toBeInTheDocument()
    })
  })

  it('should handle item selection', async () => {
    const user = userEvent.setup()
    const onSelect = jest.fn()

    render(<ItemList items={items} onSelect={onSelect} />)

    await user.click(screen.getByText('Item 2'))

    expect(onSelect).toHaveBeenCalledWith(items[1])
  })

  it('should handle empty list', () => {
    render(<ItemList items={[]} />)

    expect(screen.getByText(/no items/i)).toBeInTheDocument()
  })
})
```

## Modal/Dialog Testing

```typescript
describe('Modal', () => {
  it('should not render when closed', () => {
    render(<Modal isOpen={false} onClose={jest.fn()} />)

    expect(screen.queryByRole('dialog')).not.toBeInTheDocument()
  })

  it('should render when open', () => {
    render(<Modal isOpen={true} onClose={jest.fn()} />)

    expect(screen.getByRole('dialog')).toBeInTheDocument()
  })

  it('should call onClose when clicking overlay', async () => {
    const user = userEvent.setup()
    const handleClose = jest.fn()

    render(<Modal isOpen={true} onClose={handleClose} />)

    await user.click(screen.getByTestId('modal-overlay'))

    expect(handleClose).toHaveBeenCalled()
  })

  it('should call onClose when pressing Escape', async () => {
    const user = userEvent.setup()
    const handleClose = jest.fn()

    render(<Modal isOpen={true} onClose={handleClose} />)

    await user.keyboard('{Escape}')

    expect(handleClose).toHaveBeenCalled()
  })

  it('should trap focus inside modal', async () => {
    const user = userEvent.setup()

    render(
      <Modal isOpen={true} onClose={jest.fn()}>
        <button>First</button>
        <button>Second</button>
      </Modal>
    )

    // Focus should cycle within modal
    await user.tab()
    expect(screen.getByText('First')).toHaveFocus()

    await user.tab()
    expect(screen.getByText('Second')).toHaveFocus()

    await user.tab()
    expect(screen.getByText('First')).toHaveFocus() // Cycles back
  })
})
```

## Form Testing

```typescript
describe('LoginForm', () => {
  it('should submit valid form', async () => {
    const user = userEvent.setup()
    const onSubmit = jest.fn()

    render(<LoginForm onSubmit={onSubmit} />)

    await user.type(screen.getByLabelText(/email/i), 'test@example.com')
    await user.type(screen.getByLabelText(/password/i), 'password123')
    await user.click(screen.getByRole('button', { name: /sign in/i }))

    expect(onSubmit).toHaveBeenCalledWith({
      email: 'test@example.com',
      password: 'password123',
    })
  })

  it('should show validation errors', async () => {
    const user = userEvent.setup()

    render(<LoginForm onSubmit={jest.fn()} />)

    // Submit empty form
    await user.click(screen.getByRole('button', { name: /sign in/i }))

    expect(screen.getByText(/email is required/i)).toBeInTheDocument()
    expect(screen.getByText(/password is required/i)).toBeInTheDocument()
  })

  it('should validate email format', async () => {
    const user = userEvent.setup()

    render(<LoginForm onSubmit={jest.fn()} />)

    await user.type(screen.getByLabelText(/email/i), 'invalid-email')
    await user.click(screen.getByRole('button', { name: /sign in/i }))

    expect(screen.getByText(/invalid email/i)).toBeInTheDocument()
  })

  it('should disable submit button while submitting', async () => {
    const user = userEvent.setup()
    const onSubmit = jest.fn(() => new Promise(resolve => setTimeout(resolve, 100)))

    render(<LoginForm onSubmit={onSubmit} />)

    await user.type(screen.getByLabelText(/email/i), 'test@example.com')
    await user.type(screen.getByLabelText(/password/i), 'password123')
    await user.click(screen.getByRole('button', { name: /sign in/i }))

    expect(screen.getByRole('button', { name: /signing in/i })).toBeDisabled()

    await waitFor(() => {
      expect(screen.getByRole('button', { name: /sign in/i })).toBeEnabled()
    })
  })
})
```

## Data-Driven Tests with test.each

```typescript
describe('StatusBadge', () => {
  test.each([
    ['success', 'bg-green-500'],
    ['warning', 'bg-yellow-500'],
    ['error', 'bg-red-500'],
    ['info', 'bg-blue-500'],
  ])('should apply correct class for %s status', (status, expectedClass) => {
    render(<StatusBadge status={status} />)

    expect(screen.getByTestId('status-badge')).toHaveClass(expectedClass)
  })

  test.each([
    { input: null, expected: 'Unknown' },
    { input: undefined, expected: 'Unknown' },
    { input: '', expected: 'Unknown' },
    { input: 'invalid', expected: 'Unknown' },
  ])('should show "Unknown" for invalid input: $input', ({ input, expected }) => {
    render(<StatusBadge status={input} />)

    expect(screen.getByText(expected)).toBeInTheDocument()
  })
})
```

## Debugging Tips

```typescript
// Print entire DOM
screen.debug()

// Print specific element
screen.debug(screen.getByRole('button'))

// Log testing playground URL
screen.logTestingPlaygroundURL()

// Pretty print DOM
import { prettyDOM } from '@testing-library/react'
console.log(prettyDOM(screen.getByRole('dialog')))

// Check available roles
import { getRoles } from '@testing-library/react'
console.log(getRoles(container))
```

## Common Mistakes to Avoid

### ❌ Don't Use Implementation Details

```typescript
// Bad - testing implementation
expect(component.state.isOpen).toBe(true)
expect(wrapper.find('.internal-class').length).toBe(1)

// Good - testing behavior
expect(screen.getByRole('dialog')).toBeInTheDocument()
```

### ❌ Don't Forget Cleanup

```typescript
// Bad - may leak state between tests
it('test 1', () => {
  render(<Component />)
})

// Good - cleanup is automatic with RTL, but reset mocks
beforeEach(() => {
  jest.clearAllMocks()
})
```

### ❌ Don't Use Exact String Matching (Prefer Black-Box Assertions)

```typescript
// ❌ Bad - hardcoded strings are brittle
expect(screen.getByText('Submit Form')).toBeInTheDocument()
expect(screen.getByText('Loading...')).toBeInTheDocument()

// ✅ Good - role-based queries (most semantic)
expect(screen.getByRole('button', { name: /submit/i })).toBeInTheDocument()
expect(screen.getByRole('status')).toBeInTheDocument()

// ✅ Good - pattern matching (flexible)
expect(screen.getByText(/submit/i)).toBeInTheDocument()
expect(screen.getByText(/loading/i)).toBeInTheDocument()

// ✅ Good - test behavior, not exact UI text
expect(screen.getByRole('button')).toBeDisabled()
expect(screen.getByRole('alert')).toBeInTheDocument()
```

**Why prefer black-box assertions?**

- Text content may change (i18n, copy updates)
- Role-based queries test accessibility
- Pattern matching is resilient to minor changes
- Tests focus on behavior, not implementation details

### ❌ Don't Assert on Absence Without Query

```typescript
// Bad - throws if not found
expect(screen.getByText('Error')).not.toBeInTheDocument() // Error!

// Good - use queryBy for absence assertions
expect(screen.queryByText('Error')).not.toBeInTheDocument()
```

@@ -0,0 +1,523 @@
# Domain-Specific Component Testing

This guide covers testing patterns for Dify's domain-specific components.

## Workflow Components (`workflow/`)

Workflow components handle node configuration, data flow, and graph operations.

### Key Test Areas

1. **Node Configuration**
1. **Data Validation**
1. **Variable Passing**
1. **Edge Connections**
1. **Error Handling**

### Example: Node Configuration Panel

```typescript
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import NodeConfigPanel from './node-config-panel'
import { createMockNode, createMockWorkflowContext } from '@/__mocks__/workflow'

// Mock workflow context
jest.mock('@/app/components/workflow/hooks', () => ({
  useWorkflowStore: () => mockWorkflowStore,
  useNodesInteractions: () => mockNodesInteractions,
}))

let mockWorkflowStore = {
  nodes: [],
  edges: [],
  updateNode: jest.fn(),
}

let mockNodesInteractions = {
  handleNodeSelect: jest.fn(),
  handleNodeDelete: jest.fn(),
}

describe('NodeConfigPanel', () => {
  beforeEach(() => {
    jest.clearAllMocks()
    mockWorkflowStore = {
      nodes: [],
      edges: [],
      updateNode: jest.fn(),
    }
  })

  describe('Node Configuration', () => {
    it('should render node type selector', () => {
      const node = createMockNode({ type: 'llm' })
      render(<NodeConfigPanel node={node} />)

      expect(screen.getByLabelText(/model/i)).toBeInTheDocument()
    })

    it('should update node config on change', async () => {
      const user = userEvent.setup()
      const node = createMockNode({ type: 'llm' })

      render(<NodeConfigPanel node={node} />)

      await user.selectOptions(screen.getByLabelText(/model/i), 'gpt-4')

      expect(mockWorkflowStore.updateNode).toHaveBeenCalledWith(
        node.id,
        expect.objectContaining({ model: 'gpt-4' })
      )
    })
  })

  describe('Data Validation', () => {
    it('should show error for invalid input', async () => {
      const user = userEvent.setup()
      const node = createMockNode({ type: 'code' })

      render(<NodeConfigPanel node={node} />)

      // Enter invalid code
      const codeInput = screen.getByLabelText(/code/i)
      await user.clear(codeInput)
      await user.type(codeInput, 'invalid syntax {{{')

      await waitFor(() => {
        expect(screen.getByText(/syntax error/i)).toBeInTheDocument()
      })
    })

    it('should validate required fields', async () => {
      const node = createMockNode({ type: 'http', data: { url: '' } })

      render(<NodeConfigPanel node={node} />)

      fireEvent.click(screen.getByRole('button', { name: /save/i }))

      await waitFor(() => {
        expect(screen.getByText(/url is required/i)).toBeInTheDocument()
      })
    })
  })

  describe('Variable Passing', () => {
    it('should display available variables from upstream nodes', () => {
      const upstreamNode = createMockNode({
        id: 'node-1',
        type: 'start',
        data: { outputs: [{ name: 'user_input', type: 'string' }] },
      })
      const currentNode = createMockNode({
        id: 'node-2',
        type: 'llm',
      })

      mockWorkflowStore.nodes = [upstreamNode, currentNode]
      mockWorkflowStore.edges = [{ source: 'node-1', target: 'node-2' }]

      render(<NodeConfigPanel node={currentNode} />)

      // Variable selector should show upstream variables
      fireEvent.click(screen.getByRole('button', { name: /add variable/i }))

      expect(screen.getByText('user_input')).toBeInTheDocument()
    })

    it('should insert variable into prompt template', async () => {
      const user = userEvent.setup()
      const node = createMockNode({ type: 'llm' })

      render(<NodeConfigPanel node={node} />)

      // Click variable button
      await user.click(screen.getByRole('button', { name: /insert variable/i }))
      await user.click(screen.getByText('user_input'))

      const promptInput = screen.getByLabelText(/prompt/i)
      expect(promptInput).toHaveValue(expect.stringContaining('{{user_input}}'))
    })
  })
})
```
## Dataset Components (`dataset/`)

Dataset components handle file uploads, data display, and search/filter operations.

### Key Test Areas

1. **File Upload**
1. **File Type Validation**
1. **Pagination**
1. **Search & Filtering**
1. **Data Format Handling**

### Example: Document Uploader

```typescript
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import DocumentUploader from './document-uploader'

jest.mock('@/service/datasets', () => ({
  uploadDocument: jest.fn(),
  parseDocument: jest.fn(),
}))

import * as datasetService from '@/service/datasets'
const mockedService = datasetService as jest.Mocked<typeof datasetService>

describe('DocumentUploader', () => {
  beforeEach(() => {
    jest.clearAllMocks()
  })

  describe('File Upload', () => {
    it('should accept valid file types', async () => {
      const user = userEvent.setup()
      const onUpload = jest.fn()
      mockedService.uploadDocument.mockResolvedValue({ id: 'doc-1' })

      render(<DocumentUploader onUpload={onUpload} />)

      const file = new File(['content'], 'test.pdf', { type: 'application/pdf' })
      const input = screen.getByLabelText(/upload/i)

      await user.upload(input, file)

      await waitFor(() => {
        expect(mockedService.uploadDocument).toHaveBeenCalledWith(
          expect.any(FormData)
        )
      })
    })

    it('should reject invalid file types', async () => {
      const user = userEvent.setup()

      render(<DocumentUploader />)

      const file = new File(['content'], 'test.exe', { type: 'application/x-msdownload' })
      const input = screen.getByLabelText(/upload/i)

      await user.upload(input, file)

      expect(screen.getByText(/unsupported file type/i)).toBeInTheDocument()
      expect(mockedService.uploadDocument).not.toHaveBeenCalled()
    })

    it('should show upload progress', async () => {
      const user = userEvent.setup()

      // Mock upload with progress
      mockedService.uploadDocument.mockImplementation(() => {
        return new Promise((resolve) => {
          setTimeout(() => resolve({ id: 'doc-1' }), 100)
        })
      })

      render(<DocumentUploader />)

      const file = new File(['content'], 'test.pdf', { type: 'application/pdf' })
      await user.upload(screen.getByLabelText(/upload/i), file)

      expect(screen.getByRole('progressbar')).toBeInTheDocument()

      await waitFor(() => {
        expect(screen.queryByRole('progressbar')).not.toBeInTheDocument()
      })
    })
  })

  describe('Error Handling', () => {
    it('should handle upload failure', async () => {
      const user = userEvent.setup()
      mockedService.uploadDocument.mockRejectedValue(new Error('Upload failed'))

      render(<DocumentUploader />)

      const file = new File(['content'], 'test.pdf', { type: 'application/pdf' })
      await user.upload(screen.getByLabelText(/upload/i), file)

      await waitFor(() => {
        expect(screen.getByText(/upload failed/i)).toBeInTheDocument()
      })
    })

    it('should allow retry after failure', async () => {
      const user = userEvent.setup()
      mockedService.uploadDocument
        .mockRejectedValueOnce(new Error('Network error'))
        .mockResolvedValueOnce({ id: 'doc-1' })

      render(<DocumentUploader />)

      const file = new File(['content'], 'test.pdf', { type: 'application/pdf' })
      await user.upload(screen.getByLabelText(/upload/i), file)

      await waitFor(() => {
        expect(screen.getByRole('button', { name: /retry/i })).toBeInTheDocument()
      })

      await user.click(screen.getByRole('button', { name: /retry/i }))

      await waitFor(() => {
        expect(screen.getByText(/uploaded successfully/i)).toBeInTheDocument()
      })
    })
  })
})
```

### Example: Document List with Pagination

```typescript
// Assumes getDocuments is mocked alongside the other dataset service calls
describe('DocumentList', () => {
  describe('Pagination', () => {
    it('should load first page on mount', async () => {
      mockedService.getDocuments.mockResolvedValue({
        data: [{ id: '1', name: 'Doc 1' }],
        total: 50,
        page: 1,
        pageSize: 10,
      })

      render(<DocumentList datasetId="ds-1" />)

      await waitFor(() => {
        expect(screen.getByText('Doc 1')).toBeInTheDocument()
      })

      expect(mockedService.getDocuments).toHaveBeenCalledWith('ds-1', { page: 1 })
    })

    it('should navigate to next page', async () => {
      const user = userEvent.setup()
      mockedService.getDocuments.mockResolvedValue({
        data: [{ id: '1', name: 'Doc 1' }],
        total: 50,
        page: 1,
        pageSize: 10,
      })

      render(<DocumentList datasetId="ds-1" />)

      await waitFor(() => {
        expect(screen.getByText('Doc 1')).toBeInTheDocument()
      })

      mockedService.getDocuments.mockResolvedValue({
        data: [{ id: '11', name: 'Doc 11' }],
        total: 50,
        page: 2,
        pageSize: 10,
      })

      await user.click(screen.getByRole('button', { name: /next/i }))

      await waitFor(() => {
        expect(screen.getByText('Doc 11')).toBeInTheDocument()
      })
    })
  })

  describe('Search & Filtering', () => {
    it('should filter by search query', async () => {
      jest.useFakeTimers()
      // userEvent must be told to advance fake timers,
      // otherwise user.type() hangs waiting on real timers
      const user = userEvent.setup({ advanceTimers: jest.advanceTimersByTime })

      render(<DocumentList datasetId="ds-1" />)

      await user.type(screen.getByPlaceholderText(/search/i), 'test query')

      // Debounce
      jest.advanceTimersByTime(300)

      await waitFor(() => {
        expect(mockedService.getDocuments).toHaveBeenCalledWith(
          'ds-1',
          expect.objectContaining({ search: 'test query' })
        )
      })

      jest.useRealTimers()
    })
  })
})
```
## Configuration Components (`app/configuration/`, `config/`)
|
||||
|
||||
Configuration components handle forms, validation, and data persistence.
|
||||
|
||||
### Key Test Areas
|
||||
|
||||
1. **Form Validation**
|
||||
1. **Save/Reset**
|
||||
1. **Required vs Optional Fields**
|
||||
1. **Configuration Persistence**
|
||||
1. **Error Feedback**
|
||||
|
||||
### Example: App Configuration Form
|
||||
|
||||
```typescript
import { render, screen, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import AppConfigForm from './app-config-form'

jest.mock('@/service/apps', () => ({
  updateAppConfig: jest.fn(),
  getAppConfig: jest.fn(),
}))

import * as appService from '@/service/apps'
const mockedService = appService as jest.Mocked<typeof appService>

describe('AppConfigForm', () => {
  const defaultConfig = {
    name: 'My App',
    description: '',
    icon: 'default',
    openingStatement: '',
  }

  beforeEach(() => {
    jest.clearAllMocks()
    mockedService.getAppConfig.mockResolvedValue(defaultConfig)
  })

  describe('Form Validation', () => {
    it('should require app name', async () => {
      const user = userEvent.setup()

      render(<AppConfigForm appId="app-1" />)

      await waitFor(() => {
        expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
      })

      // Clear name field
      await user.clear(screen.getByLabelText(/name/i))
      await user.click(screen.getByRole('button', { name: /save/i }))

      expect(screen.getByText(/name is required/i)).toBeInTheDocument()
      expect(mockedService.updateAppConfig).not.toHaveBeenCalled()
    })

    it('should validate name length', async () => {
      const user = userEvent.setup()

      render(<AppConfigForm appId="app-1" />)

      await waitFor(() => {
        expect(screen.getByLabelText(/name/i)).toBeInTheDocument()
      })

      // Enter very long name
      await user.clear(screen.getByLabelText(/name/i))
      await user.type(screen.getByLabelText(/name/i), 'a'.repeat(101))

      expect(screen.getByText(/name must be less than 100 characters/i)).toBeInTheDocument()
    })

    it('should allow empty optional fields', async () => {
      const user = userEvent.setup()
      mockedService.updateAppConfig.mockResolvedValue({ success: true })

      render(<AppConfigForm appId="app-1" />)

      await waitFor(() => {
        expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
      })

      // Leave description empty (optional)
      await user.click(screen.getByRole('button', { name: /save/i }))

      await waitFor(() => {
        expect(mockedService.updateAppConfig).toHaveBeenCalled()
      })
    })
  })

  describe('Save/Reset Functionality', () => {
    it('should save configuration', async () => {
      const user = userEvent.setup()
      mockedService.updateAppConfig.mockResolvedValue({ success: true })

      render(<AppConfigForm appId="app-1" />)

      await waitFor(() => {
        expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
      })

      await user.clear(screen.getByLabelText(/name/i))
      await user.type(screen.getByLabelText(/name/i), 'Updated App')
      await user.click(screen.getByRole('button', { name: /save/i }))

      await waitFor(() => {
        expect(mockedService.updateAppConfig).toHaveBeenCalledWith(
          'app-1',
          expect.objectContaining({ name: 'Updated App' })
        )
      })

      expect(screen.getByText(/saved successfully/i)).toBeInTheDocument()
    })

    it('should reset to default values', async () => {
      const user = userEvent.setup()

      render(<AppConfigForm appId="app-1" />)

      await waitFor(() => {
        expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
      })

      // Make changes
      await user.clear(screen.getByLabelText(/name/i))
      await user.type(screen.getByLabelText(/name/i), 'Changed Name')

      // Reset
      await user.click(screen.getByRole('button', { name: /reset/i }))

      expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
    })

    it('should show unsaved changes warning', async () => {
      const user = userEvent.setup()

      render(<AppConfigForm appId="app-1" />)

      await waitFor(() => {
        expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
      })

      // Make changes
      await user.type(screen.getByLabelText(/name/i), ' Updated')

      expect(screen.getByText(/unsaved changes/i)).toBeInTheDocument()
    })
  })

  describe('Error Handling', () => {
    it('should show error on save failure', async () => {
      const user = userEvent.setup()
      mockedService.updateAppConfig.mockRejectedValue(new Error('Server error'))

      render(<AppConfigForm appId="app-1" />)

      await waitFor(() => {
        expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
      })

      await user.click(screen.getByRole('button', { name: /save/i }))

      await waitFor(() => {
        expect(screen.getByText(/failed to save/i)).toBeInTheDocument()
      })
    })
  })
})
```

@@ -0,0 +1,353 @@
# Mocking Guide for Dify Frontend Tests

## ⚠️ Important: What NOT to Mock

### DO NOT Mock Base Components

**Never mock components from `@/app/components/base/`** such as:

- `Loading`, `Spinner`
- `Button`, `Input`, `Select`
- `Tooltip`, `Modal`, `Dropdown`
- `Icon`, `Badge`, `Tag`

**Why?**

- Base components will have their own dedicated tests
- Mocking them creates false positives (tests pass but real integration fails)
- Using real components tests actual integration behavior

```typescript
// ❌ WRONG: Don't mock base components
jest.mock('@/app/components/base/loading', () => () => <div>Loading</div>)
jest.mock('@/app/components/base/button', () => ({ children }: any) => <button>{children}</button>)

// ✅ CORRECT: Import and use real base components
import Loading from '@/app/components/base/loading'
import Button from '@/app/components/base/button'
// They will render normally in tests
```

### What TO Mock

Only mock these categories:

1. **API services** (`@/service/*`) - Network calls
1. **Complex context providers** - When setup is too difficult
1. **Third-party libraries with side effects** - `next/navigation`, external SDKs
1. **i18n** - Always mock to return keys

## Mock Placement

| Location | Purpose |
|----------|---------|
| `web/__mocks__/` | Reusable mocks shared across multiple test files |
| Test file | Test-specific mocks, inline with `jest.mock()` |
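
For the `web/__mocks__/` row, a sketch of what such a shared module can look like, modeled on the `@/__mocks__/provider-context` helpers imported in the Context Providers example below (the fields shown are illustrative assumptions, not the real `ProviderContextValue` shape):

```typescript
// web/__mocks__/provider-context.ts (sketch; align fields with the real ProviderContextValue)
export const createMockProviderContextValue = (overrides: Record<string, unknown> = {}) => ({
  plan: 'sandbox',     // illustrative field
  isFetchedPlan: true, // illustrative field
  ...overrides,
})

// Convenience helper for plan-specific scenarios
export const createMockPlan = (plan: string) =>
  createMockProviderContextValue({ plan })
```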

## Essential Mocks

### 1. i18n (Always Required)

```typescript
jest.mock('react-i18next', () => ({
  useTranslation: () => ({
    t: (key: string) => key,
  }),
}))
```
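
If a component also destructures `i18n` from `useTranslation` (react-i18next exposes `{ t, i18n, ready }`), the key-echo mock can be extended; a minimal sketch:

```typescript
jest.mock('react-i18next', () => ({
  useTranslation: () => ({
    t: (key: string) => key,
    // Provide the i18n object some components destructure
    i18n: { language: 'en', changeLanguage: jest.fn() },
    ready: true,
  }),
}))
```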

### 2. Next.js Router

```typescript
const mockPush = jest.fn()
const mockReplace = jest.fn()

jest.mock('next/navigation', () => ({
  useRouter: () => ({
    push: mockPush,
    replace: mockReplace,
    back: jest.fn(),
    prefetch: jest.fn(),
  }),
  usePathname: () => '/current-path',
  useSearchParams: () => new URLSearchParams('?key=value'),
}))

describe('Component', () => {
  beforeEach(() => {
    jest.clearAllMocks()
  })

  it('should navigate on click', () => {
    render(<Component />)
    fireEvent.click(screen.getByRole('button'))
    expect(mockPush).toHaveBeenCalledWith('/expected-path')
  })
})
```

### 3. Portal Components (with Shared State)

```typescript
// ⚠️ Important: Use shared state for components that depend on each other
let mockPortalOpenState = false

jest.mock('@/app/components/base/portal-to-follow-elem', () => ({
  PortalToFollowElem: ({ children, open, ...props }: any) => {
    mockPortalOpenState = open || false // Update shared state
    return <div data-testid="portal" data-open={open}>{children}</div>
  },
  PortalToFollowElemContent: ({ children }: any) => {
    // ✅ Matches actual: returns null when portal is closed
    if (!mockPortalOpenState) return null
    return <div data-testid="portal-content">{children}</div>
  },
  PortalToFollowElemTrigger: ({ children }: any) => (
    <div data-testid="portal-trigger">{children}</div>
  ),
}))

describe('Component', () => {
  beforeEach(() => {
    jest.clearAllMocks()
    mockPortalOpenState = false // ✅ Reset shared state
  })
})
```

### 4. API Service Mocks

```typescript
import * as api from '@/service/api'

jest.mock('@/service/api')

const mockedApi = api as jest.Mocked<typeof api>

describe('Component', () => {
  beforeEach(() => {
    jest.clearAllMocks()

    // Setup default mock implementation
    mockedApi.fetchData.mockResolvedValue({ data: [] })
  })

  it('should show data on success', async () => {
    mockedApi.fetchData.mockResolvedValue({ data: [{ id: 1 }] })

    render(<Component />)

    await waitFor(() => {
      expect(screen.getByText('1')).toBeInTheDocument()
    })
  })

  it('should show error on failure', async () => {
    mockedApi.fetchData.mockRejectedValue(new Error('Network error'))

    render(<Component />)

    await waitFor(() => {
      expect(screen.getByText(/error/i)).toBeInTheDocument()
    })
  })
})
```

### 5. HTTP Mocking with Nock

```typescript
import nock from 'nock'

const GITHUB_HOST = 'https://api.github.com'
const GITHUB_PATH = '/repos/owner/repo'

const mockGithubApi = (status: number, body: Record<string, unknown>, delayMs = 0) => {
  return nock(GITHUB_HOST)
    .get(GITHUB_PATH)
    .delay(delayMs)
    .reply(status, body)
}

describe('GithubComponent', () => {
  afterEach(() => {
    nock.cleanAll()
  })

  it('should display repo info', async () => {
    mockGithubApi(200, { name: 'dify', stars: 1000 })

    render(<GithubComponent />)

    await waitFor(() => {
      expect(screen.getByText('dify')).toBeInTheDocument()
    })
  })

  it('should handle API error', async () => {
    mockGithubApi(500, { message: 'Server error' })

    render(<GithubComponent />)

    await waitFor(() => {
      expect(screen.getByText(/error/i)).toBeInTheDocument()
    })
  })
})
```

### 6. Context Providers

```typescript
import { ProviderContext } from '@/context/provider-context'
import { createMockProviderContextValue, createMockPlan } from '@/__mocks__/provider-context'

describe('Component with Context', () => {
  it('should render for free plan', () => {
    const mockContext = createMockPlan('sandbox')

    render(
      <ProviderContext.Provider value={mockContext}>
        <Component />
      </ProviderContext.Provider>
    )

    expect(screen.getByText('Upgrade')).toBeInTheDocument()
  })

  it('should render for pro plan', () => {
    const mockContext = createMockPlan('professional')

    render(
      <ProviderContext.Provider value={mockContext}>
        <Component />
      </ProviderContext.Provider>
    )

    expect(screen.queryByText('Upgrade')).not.toBeInTheDocument()
  })
})
```

### 7. SWR / React Query

```typescript
// SWR
jest.mock('swr', () => ({
  __esModule: true,
  default: jest.fn(),
}))

import useSWR from 'swr'
const mockedUseSWR = useSWR as jest.Mock

describe('Component with SWR', () => {
  it('should show loading state', () => {
    mockedUseSWR.mockReturnValue({
      data: undefined,
      error: undefined,
      isLoading: true,
    })

    render(<Component />)
    expect(screen.getByText(/loading/i)).toBeInTheDocument()
  })
})

// React Query
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'

const createTestQueryClient = () => new QueryClient({
  defaultOptions: {
    queries: { retry: false },
    mutations: { retry: false },
  },
})

const renderWithQueryClient = (ui: React.ReactElement) => {
  const queryClient = createTestQueryClient()
  return render(
    <QueryClientProvider client={queryClient}>
      {ui}
    </QueryClientProvider>
  )
}
```
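
`renderWithQueryClient` is only defined above; a usage sketch (the `ItemList` component and its rendered text are illustrative):

```typescript
it('should render items fetched through React Query', async () => {
  // ItemList is assumed to fetch via useQuery internally
  renderWithQueryClient(<ItemList />)

  await waitFor(() => {
    expect(screen.getByText('Item 1')).toBeInTheDocument()
  })
})
```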

## Mock Best Practices

### ✅ DO

1. **Use real base components** - Import from `@/app/components/base/` directly
1. **Use real project components** - Prefer importing over mocking
1. **Reset mocks in `beforeEach`**, not `afterEach`
1. **Match actual component behavior** in mocks (when mocking is necessary)
1. **Use factory functions** for complex mock data
1. **Import actual types** for type safety
1. **Reset shared mock state** in `beforeEach`

### ❌ DON'T

1. **Don't mock base components** (`Loading`, `Button`, `Tooltip`, etc.)
1. Don't mock components you can import directly
1. Don't create overly simplified mocks that miss conditional logic
1. Don't forget to clean up nock after each test
1. Don't use `any` types in mocks without necessity

### Mock Decision Tree

```
Need to use a component in test?
│
├─ Is it from @/app/components/base/*?
│   └─ YES → Import real component, DO NOT mock
│
├─ Is it a project component?
│   └─ YES → Prefer importing real component
│             Only mock if setup is extremely complex
│
├─ Is it an API service (@/service/*)?
│   └─ YES → Mock it
│
├─ Is it a third-party lib with side effects?
│   └─ YES → Mock it (next/navigation, external SDKs)
│
└─ Is it i18n?
    └─ YES → Mock to return keys
```

## Factory Function Pattern

```typescript
// __mocks__/data-factories.ts
import type { User, Project } from '@/types'

export const createMockUser = (overrides: Partial<User> = {}): User => ({
  id: 'user-1',
  name: 'Test User',
  email: 'test@example.com',
  role: 'member',
  createdAt: new Date().toISOString(),
  ...overrides,
})

export const createMockProject = (overrides: Partial<Project> = {}): Project => ({
  id: 'project-1',
  name: 'Test Project',
  description: 'A test project',
  owner: createMockUser(),
  members: [],
  createdAt: new Date().toISOString(),
  ...overrides,
})

// Usage in tests
it('should display project owner', () => {
  const project = createMockProject({
    owner: createMockUser({ name: 'John Doe' }),
  })

  render(<ProjectCard project={project} />)
  expect(screen.getByText('John Doe')).toBeInTheDocument()
})
```

@@ -0,0 +1,269 @@
# Testing Workflow Guide

This guide defines the workflow for generating tests, especially for complex components or directories with multiple files.

## Scope Clarification

This guide addresses **multi-file workflow** (how to process multiple test files). For coverage requirements within a single test file, see `web/testing/testing.md` § Coverage Goals.

| Scope | Rule |
|-------|------|
| **Single file** | Complete coverage in one generation (100% function, >95% branch) |
| **Multi-file directory** | Process one file at a time, verify each before proceeding |

## ⚠️ Critical Rule: Incremental Approach for Multi-File Testing

When testing a **directory with multiple files**, **NEVER generate all test files at once.** Use an incremental, verify-as-you-go approach.

### Why Incremental?

| Batch Approach (❌) | Incremental Approach (✅) |
|---------------------|---------------------------|
| Generate 5+ tests at once | Generate 1 test at a time |
| Run tests only at the end | Run test immediately after each file |
| Multiple failures compound | Single point of failure, easy to debug |
| Hard to identify root cause | Clear cause-effect relationship |
| Mock issues affect many files | Mock issues caught early |
| Messy git history | Clean, atomic commits possible |

## Single File Workflow

When testing a **single component, hook, or utility**:

```
1. Read source code completely
2. Run `pnpm analyze-component <path>` (if available)
3. Check complexity score and features detected
4. Write the test file
5. Run test: `pnpm test -- <file>.spec.tsx`
6. Fix any failures
7. Verify coverage meets goals (100% function, >95% branch)
```

## Directory/Multi-File Workflow (MUST FOLLOW)

When testing a **directory or multiple files**, follow this strict workflow:

### Step 1: Analyze and Plan

1. **List all files** that need tests in the directory
1. **Categorize by complexity**:
   - 🟢 **Simple**: Utility functions, simple hooks, presentational components
   - 🟡 **Medium**: Components with state, effects, or event handlers
   - 🔴 **Complex**: Components with API calls, routing, or many dependencies
1. **Order by dependency**: Test dependencies before dependents
1. **Create a todo list** to track progress

### Step 2: Determine Processing Order

Process files in this recommended order:

```
1. Utility functions (simplest, no React)
2. Custom hooks (isolated logic)
3. Simple presentational components (few/no props)
4. Medium complexity components (state, effects)
5. Complex components (API, routing, many deps)
6. Container/index components (integration tests - last)
```

**Rationale**:

- Simpler files help establish mock patterns
- Hooks used by components should be tested first
- Integration tests (index files) depend on child components working

### Step 3: Process Each File Incrementally

**For EACH file in the ordered list:**

```
┌─────────────────────────────────────────────┐
│ 1. Write test file                          │
│ 2. Run: pnpm test -- <file>.spec.tsx        │
│ 3. If FAIL → Fix immediately, re-run        │
│ 4. If PASS → Mark complete in todo list     │
│ 5. ONLY THEN proceed to next file           │
└─────────────────────────────────────────────┘
```

**DO NOT proceed to the next file until the current one passes.**

### Step 4: Final Verification

After all individual tests pass:

```bash
# Run all tests in the directory together
pnpm test -- path/to/directory/

# Check coverage
pnpm test -- --coverage path/to/directory/
```

## Component Complexity Guidelines

Use `pnpm analyze-component <path>` to assess complexity before testing.

### 🔴 Very Complex Components (Complexity > 50)

**Consider refactoring BEFORE testing:**

- Break component into smaller, testable pieces
- Extract complex logic into custom hooks
- Separate container and presentational layers

**If testing as-is:**

- Use integration tests for complex workflows
- Use `test.each()` for data-driven testing
- Multiple `describe` blocks for organization
- Consider testing major sections separately

### 🟡 Medium Complexity (Complexity 30-50)

- Group related tests in `describe` blocks
- Test integration scenarios between internal parts
- Focus on state transitions and side effects
- Use helper functions to reduce test complexity

### 🟢 Simple Components (Complexity < 30)

- Standard test structure
- Focus on props, rendering, and edge cases
- Usually straightforward to test

### 📏 Large Files (500+ lines)

Regardless of complexity score:

- **Strongly consider refactoring** before testing
- If testing as-is, test major sections separately
- Create helper functions for test setup
- May need multiple test files

## Todo List Format

When testing multiple files, use a todo list like this:

```
Testing: path/to/directory/

Ordered by complexity (simple → complex):

☐ utils/helper.ts          [utility, simple]
☐ hooks/use-custom-hook.ts [hook, simple]
☐ empty-state.tsx          [component, simple]
☐ item-card.tsx            [component, medium]
☐ list.tsx                 [component, complex]
☐ index.tsx                [integration]

Progress: 0/6 complete
```

Update status as you complete each:

- ☐ → ⏳ (in progress)
- ⏳ → ✅ (complete and verified)
- ⏳ → ❌ (blocked, needs attention)

## When to Stop and Verify

**Always run tests after:**

- Completing a test file
- Making changes to fix a failure
- Modifying shared mocks
- Updating test utilities or helpers

**Signs you should pause:**

- More than 2 consecutive test failures
- Mock-related errors appearing
- Unclear why a test is failing
- Test passing but coverage unexpectedly low

## Common Pitfalls to Avoid

### ❌ Don't: Generate Everything First

```
# BAD: Writing all files then testing
Write component-a.spec.tsx
Write component-b.spec.tsx
Write component-c.spec.tsx
Write component-d.spec.tsx
Run pnpm test  ← Multiple failures, hard to debug
```

### ✅ Do: Verify Each Step

```
# GOOD: Incremental with verification
Write component-a.spec.tsx
Run pnpm test -- component-a.spec.tsx  ✅
Write component-b.spec.tsx
Run pnpm test -- component-b.spec.tsx  ✅
...continue...
```

### ❌ Don't: Skip Verification for "Simple" Components

Even simple components can have:

- Import errors
- Missing mock setup
- Incorrect assumptions about props

**Always verify, regardless of perceived simplicity.**

### ❌ Don't: Continue When Tests Fail

Failing tests compound:

- A mock issue in file A affects files B, C, D
- Fixing A later requires revisiting all dependent tests
- Time wasted on debugging cascading failures

**Fix failures immediately before proceeding.**

## Integration with Claude's Todo Feature

When using Claude for multi-file testing:

1. **Ask Claude to create a todo list** before starting
1. **Request one file at a time** or ensure Claude processes incrementally
1. **Verify each test passes** before asking for the next
1. **Mark todos complete** as you progress

Example prompt:

```
Test all components in `path/to/directory/`.
First, analyze the directory and create a todo list ordered by complexity.
Then, process ONE file at a time, waiting for my confirmation that tests pass
before proceeding to the next.
```

## Summary Checklist

Before starting multi-file testing:

- [ ] Listed all files needing tests
- [ ] Ordered by complexity (simple → complex)
- [ ] Created todo list for tracking
- [ ] Understand dependencies between files

During testing:

- [ ] Processing ONE file at a time
- [ ] Running tests after EACH file
- [ ] Fixing failures BEFORE proceeding
- [ ] Updating todo list progress

After completion:

- [ ] All individual tests pass
- [ ] Full directory test run passes
- [ ] Coverage goals met
- [ ] Todo list shows all complete

@@ -0,0 +1,289 @@
/**
 * Test Template for React Components
 *
 * WHY THIS STRUCTURE?
 * - Organized sections make tests easy to navigate and maintain
 * - Mocks at top ensure consistent test isolation
 * - Factory functions reduce duplication and improve readability
 * - describe blocks group related scenarios for better debugging
 *
 * INSTRUCTIONS:
 * 1. Replace `ComponentName` with your component name
 * 2. Update import path
 * 3. Add/remove test sections based on component features (use analyze-component)
 * 4. Follow AAA pattern: Arrange → Act → Assert
 *
 * RUN FIRST: pnpm analyze-component <path> to identify required test scenarios
 */

import { render, screen, fireEvent, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
// import ComponentName from './index'

// ============================================================================
// Mocks
// ============================================================================
// WHY: Mocks must be hoisted to top of file (Jest requirement).
// They run BEFORE imports, so keep them before component imports.

// i18n (always required in Dify)
// WHY: Returns key instead of translation so tests don't depend on i18n files
jest.mock('react-i18next', () => ({
  useTranslation: () => ({
    t: (key: string) => key,
  }),
}))

// Router (if component uses useRouter, usePathname, useSearchParams)
// WHY: Isolates tests from Next.js routing, enables testing navigation behavior
// const mockPush = jest.fn()
// jest.mock('next/navigation', () => ({
//   useRouter: () => ({ push: mockPush }),
//   usePathname: () => '/test-path',
// }))

// API services (if component fetches data)
// WHY: Prevents real network calls, enables testing all states (loading/success/error)
// jest.mock('@/service/api')
// import * as api from '@/service/api'
// const mockedApi = api as jest.Mocked<typeof api>

// Shared mock state (for portal/dropdown components)
// WHY: Portal components like PortalToFollowElem need shared state between
// parent and child mocks to correctly simulate open/close behavior
// let mockOpenState = false

// ============================================================================
// Test Data Factories
// ============================================================================
// WHY FACTORIES?
// - Avoid hard-coded test data scattered across tests
// - Easy to create variations with overrides
// - Type-safe when using actual types from source
// - Single source of truth for default test values

// const createMockProps = (overrides = {}) => ({
//   // Default props that make component render successfully
//   ...overrides,
// })

// const createMockItem = (overrides = {}) => ({
//   id: 'item-1',
//   name: 'Test Item',
//   ...overrides,
// })

// ============================================================================
// Test Helpers
// ============================================================================

// const renderComponent = (props = {}) => {
//   return render(<ComponentName {...createMockProps(props)} />)
// }

// ============================================================================
// Tests
// ============================================================================

describe('ComponentName', () => {
  // WHY beforeEach with clearAllMocks?
  // - Ensures each test starts with clean slate
  // - Prevents mock call history from leaking between tests
  // - MUST be beforeEach (not afterEach) to reset BEFORE assertions like toHaveBeenCalledTimes
  beforeEach(() => {
    jest.clearAllMocks()
    // Reset shared mock state if used (CRITICAL for portal/dropdown tests)
    // mockOpenState = false
  })

  // --------------------------------------------------------------------------
  // Rendering Tests (REQUIRED - Every component MUST have these)
  // --------------------------------------------------------------------------
  // WHY: Catches import errors, missing providers, and basic render issues
  describe('Rendering', () => {
    it('should render without crashing', () => {
      // Arrange - Setup data and mocks
      // const props = createMockProps()

      // Act - Render the component
      // render(<ComponentName {...props} />)

      // Assert - Verify expected output
      // Prefer getByRole for accessibility; it's what users "see"
      // expect(screen.getByRole('...')).toBeInTheDocument()
    })

    it('should render with default props', () => {
      // WHY: Verifies component works without optional props
      // render(<ComponentName />)
      // expect(screen.getByText('...')).toBeInTheDocument()
    })
  })

  // --------------------------------------------------------------------------
  // Props Tests (REQUIRED - Every component MUST test prop behavior)
  // --------------------------------------------------------------------------
  // WHY: Props are the component's API contract. Test them thoroughly.
  describe('Props', () => {
    it('should apply custom className', () => {
      // WHY: Common pattern in Dify - components should merge custom classes
      // render(<ComponentName className="custom-class" />)
      // expect(screen.getByTestId('component')).toHaveClass('custom-class')
    })

    it('should use default values for optional props', () => {
      // WHY: Verifies TypeScript defaults work at runtime
      // render(<ComponentName />)
      // expect(screen.getByRole('...')).toHaveAttribute('...', 'default-value')
    })
  })

  // --------------------------------------------------------------------------
  // User Interactions (if component has event handlers - on*, handle*)
  // --------------------------------------------------------------------------
  // WHY: Event handlers are core functionality. Test from user's perspective.
  describe('User Interactions', () => {
    it('should call onClick when clicked', async () => {
      // WHY userEvent over fireEvent?
      // - userEvent simulates real user behavior (focus, hover, then click)
      // - fireEvent is lower-level, doesn't trigger all browser events
      // const user = userEvent.setup()
      // const handleClick = jest.fn()
      // render(<ComponentName onClick={handleClick} />)
      //
      // await user.click(screen.getByRole('button'))
      //
      // expect(handleClick).toHaveBeenCalledTimes(1)
    })

    it('should call onChange when value changes', async () => {
      // const user = userEvent.setup()
      // const handleChange = jest.fn()
      // render(<ComponentName onChange={handleChange} />)
      //
      // await user.type(screen.getByRole('textbox'), 'new value')
      //
      // expect(handleChange).toHaveBeenCalled()
    })
  })

  // --------------------------------------------------------------------------
  // State Management (if component uses useState/useReducer)
  // --------------------------------------------------------------------------
  // WHY: Test state through observable UI changes, not internal state values
  describe('State Management', () => {
    it('should update state on interaction', async () => {
      // WHY test via UI, not state?
      // - State is implementation detail; UI is what users see
      // - If UI works correctly, state must be correct
      // const user = userEvent.setup()
      // render(<ComponentName />)
      //
      // // Initial state - verify what user sees
      // expect(screen.getByText('Initial')).toBeInTheDocument()
      //
      // // Trigger state change via user action
      // await user.click(screen.getByRole('button'))
      //
      // // New state - verify UI updated
      // expect(screen.getByText('Updated')).toBeInTheDocument()
    })
  })

  // --------------------------------------------------------------------------
  // Async Operations (if component fetches data - useSWR, useQuery, fetch)
  // --------------------------------------------------------------------------
  // WHY: Async operations have 3 states users experience: loading, success, error
  describe('Async Operations', () => {
    it('should show loading state', () => {
      // WHY never-resolving promise?
      // - Keeps component in loading state for assertion
      // - Alternative: use fake timers
      // mockedApi.fetchData.mockImplementation(() => new Promise(() => {}))
      // render(<ComponentName />)
      //
      // expect(screen.getByText(/loading/i)).toBeInTheDocument()
    })

    it('should show data on success', async () => {
      // WHY waitFor?
      // - Component updates asynchronously after fetch resolves
      // - waitFor retries assertion until it passes or times out
      // mockedApi.fetchData.mockResolvedValue({ items: ['Item 1'] })
      // render(<ComponentName />)
      //
      // await waitFor(() => {
      //   expect(screen.getByText('Item 1')).toBeInTheDocument()
      // })
    })

    it('should show error on failure', async () => {
      // mockedApi.fetchData.mockRejectedValue(new Error('Network error'))
      // render(<ComponentName />)
      //
      // await waitFor(() => {
      //   expect(screen.getByText(/error/i)).toBeInTheDocument()
      // })
    })
  })

  // --------------------------------------------------------------------------
  // Edge Cases (REQUIRED - Every component MUST handle edge cases)
  // --------------------------------------------------------------------------
  // WHY: Real-world data is messy. Components must handle:
  // - Null/undefined from API failures or optional fields
  // - Empty arrays/strings from user clearing data
  // - Boundary values (0, MAX_INT, special characters)
  describe('Edge Cases', () => {
    it('should handle null value', () => {
      // WHY test null specifically?
      // - API might return null for missing data
      // - Prevents "Cannot read property of null" in production
      // render(<ComponentName value={null} />)
      // expect(screen.getByText(/no data/i)).toBeInTheDocument()
    })

    it('should handle undefined value', () => {
      // WHY test undefined separately from null?
      // - TypeScript treats them differently
      // - Optional props are undefined, not null
      // render(<ComponentName value={undefined} />)
      // expect(screen.getByText(/no data/i)).toBeInTheDocument()
    })

    it('should handle empty array', () => {
      // WHY: Empty state often needs special UI (e.g., "No items yet")
      // render(<ComponentName items={[]} />)
      // expect(screen.getByText(/empty/i)).toBeInTheDocument()
    })

    it('should handle empty string', () => {
      // WHY: Empty strings are falsy in JS and render as visually empty
      // render(<ComponentName text="" />)
      // expect(screen.getByText(/placeholder/i)).toBeInTheDocument()
    })
  })

  // --------------------------------------------------------------------------
  // Accessibility (optional but recommended for Dify's enterprise users)
  // --------------------------------------------------------------------------
  // WHY: Dify has enterprise customers who may require accessibility compliance
  describe('Accessibility', () => {
    it('should have accessible name', () => {
      // WHY getByRole with name?
      // - Tests that screen readers can identify the element
      // - Enforces proper labeling practices
      // render(<ComponentName label="Test Label" />)
      // expect(screen.getByRole('button', { name: /test label/i })).toBeInTheDocument()
    })

    it('should support keyboard navigation', async () => {
      // WHY: Some users can't use a mouse
      // const user = userEvent.setup()
      // render(<ComponentName />)
      //
      // await user.tab()
      // expect(screen.getByRole('button')).toHaveFocus()
    })
  })
})

@@ -0,0 +1,207 @@
/**
 * Test Template for Custom Hooks
 *
 * Instructions:
 * 1. Replace `useHookName` with your hook name
 * 2. Update import path
 * 3. Add/remove test sections based on hook features
 */

import { renderHook, act, waitFor } from '@testing-library/react'
// import { useHookName } from './use-hook-name'

// ============================================================================
// Mocks
// ============================================================================

// API services (if hook fetches data)
// jest.mock('@/service/api')
// import * as api from '@/service/api'
// const mockedApi = api as jest.Mocked<typeof api>

// ============================================================================
// Test Helpers
// ============================================================================

// Wrapper for hooks that need context
// const createWrapper = (contextValue = {}) => {
//   return ({ children }: { children: React.ReactNode }) => (
//     <SomeContext.Provider value={contextValue}>
//       {children}
//     </SomeContext.Provider>
//   )
// }

// ============================================================================
// Tests
// ============================================================================

describe('useHookName', () => {
  beforeEach(() => {
    jest.clearAllMocks()
  })

  // --------------------------------------------------------------------------
  // Initial State
  // --------------------------------------------------------------------------
  describe('Initial State', () => {
    it('should return initial state', () => {
      // const { result } = renderHook(() => useHookName())
      //
      // expect(result.current.value).toBe(initialValue)
      // expect(result.current.isLoading).toBe(false)
    })

    it('should accept initial value from props', () => {
      // const { result } = renderHook(() => useHookName({ initialValue: 'custom' }))
      //
      // expect(result.current.value).toBe('custom')
    })
  })

  // --------------------------------------------------------------------------
  // State Updates
  // --------------------------------------------------------------------------
  describe('State Updates', () => {
    it('should update value when setValue is called', () => {
      // const { result } = renderHook(() => useHookName())
      //
      // act(() => {
      //   result.current.setValue('new value')
      // })
      //
      // expect(result.current.value).toBe('new value')
    })

    it('should reset to initial value', () => {
      // const { result } = renderHook(() => useHookName({ initialValue: 'initial' }))
      //
      // act(() => {
      //   result.current.setValue('changed')
      // })
      // expect(result.current.value).toBe('changed')
      //
      // act(() => {
      //   result.current.reset()
      // })
      // expect(result.current.value).toBe('initial')
    })
  })

  // --------------------------------------------------------------------------
  // Async Operations
  // --------------------------------------------------------------------------
  describe('Async Operations', () => {
    it('should fetch data on mount', async () => {
      // mockedApi.fetchData.mockResolvedValue({ data: 'test' })
      //
      // const { result } = renderHook(() => useHookName())
      //
      // // Initially loading
      // expect(result.current.isLoading).toBe(true)
      //
      // // Wait for data
      // await waitFor(() => {
      //   expect(result.current.isLoading).toBe(false)
      // })
      //
      // expect(result.current.data).toEqual({ data: 'test' })
    })

    it('should handle fetch error', async () => {
      // mockedApi.fetchData.mockRejectedValue(new Error('Network error'))
      //
      // const { result } = renderHook(() => useHookName())
      //
      // await waitFor(() => {
      //   expect(result.current.error).toBeTruthy()
      // })
      //
      // expect(result.current.error?.message).toBe('Network error')
    })

    it('should refetch when dependency changes', async () => {
      // mockedApi.fetchData.mockResolvedValue({ data: 'test' })
      //
      // const { result, rerender } = renderHook(
      //   ({ id }) => useHookName(id),
      //   { initialProps: { id: '1' } }
      // )
      //
      // await waitFor(() => {
      //   expect(mockedApi.fetchData).toHaveBeenCalledWith('1')
      // })
      //
      // rerender({ id: '2' })
      //
      // await waitFor(() => {
      //   expect(mockedApi.fetchData).toHaveBeenCalledWith('2')
      // })
    })
  })

  // --------------------------------------------------------------------------
  // Side Effects
  // --------------------------------------------------------------------------
  describe('Side Effects', () => {
    it('should call callback when value changes', () => {
      // const callback = jest.fn()
      // const { result } = renderHook(() => useHookName({ onChange: callback }))
      //
      // act(() => {
      //   result.current.setValue('new value')
      // })
      //
      // expect(callback).toHaveBeenCalledWith('new value')
    })

    it('should cleanup on unmount', () => {
      // jest.spyOn(window, 'addEventListener')
      // jest.spyOn(window, 'removeEventListener')
      //
      // const { unmount } = renderHook(() => useHookName())
      //
      // expect(window.addEventListener).toHaveBeenCalled()
      //
      // unmount()
      //
      // expect(window.removeEventListener).toHaveBeenCalled()
    })
  })

  // --------------------------------------------------------------------------
  // Edge Cases
  // --------------------------------------------------------------------------
  describe('Edge Cases', () => {
    it('should handle null input', () => {
      // const { result } = renderHook(() => useHookName(null))
      //
      // expect(result.current.value).toBeNull()
    })

    it('should handle rapid updates', () => {
      // const { result } = renderHook(() => useHookName())
      //
      // act(() => {
      //   result.current.setValue('1')
      //   result.current.setValue('2')
      //   result.current.setValue('3')
      // })
      //
      // expect(result.current.value).toBe('3')
    })
  })

  // --------------------------------------------------------------------------
  // With Context (if hook uses context)
  // --------------------------------------------------------------------------
  describe('With Context', () => {
    it('should use context value', () => {
      // const wrapper = createWrapper({ someValue: 'context-value' })
      // const { result } = renderHook(() => useHookName(), { wrapper })
      //
      // expect(result.current.contextValue).toBe('context-value')
    })
  })
})

@@ -0,0 +1,154 @@
/**
 * Test Template for Utility Functions
 *
 * Instructions:
 * 1. Replace `utilityFunction` with your function name
 * 2. Update import path
 * 3. Use test.each for data-driven tests
 */

// import { utilityFunction } from './utility'

// ============================================================================
// Tests
// ============================================================================

describe('utilityFunction', () => {
  // --------------------------------------------------------------------------
  // Basic Functionality
  // --------------------------------------------------------------------------
  describe('Basic Functionality', () => {
    it('should return expected result for valid input', () => {
      // expect(utilityFunction('input')).toBe('expected-output')
    })

    it('should handle multiple arguments', () => {
      // expect(utilityFunction('a', 'b', 'c')).toBe('abc')
    })
  })

  // --------------------------------------------------------------------------
  // Data-Driven Tests
  // --------------------------------------------------------------------------
  describe('Input/Output Mapping', () => {
    test.each([
      // [input, expected]
      ['input1', 'output1'],
      ['input2', 'output2'],
      ['input3', 'output3'],
    ])('should map input %s to %s', (input, expected) => {
      // expect(utilityFunction(input)).toBe(expected)
    })
  })

  // --------------------------------------------------------------------------
  // Edge Cases
  // --------------------------------------------------------------------------
  describe('Edge Cases', () => {
    it('should handle empty string', () => {
      // expect(utilityFunction('')).toBe('')
    })

    it('should handle null', () => {
      // expect(utilityFunction(null)).toBe(null)
      // or
      // expect(() => utilityFunction(null)).toThrow()
    })

    it('should handle undefined', () => {
      // expect(utilityFunction(undefined)).toBe(undefined)
      // or
      // expect(() => utilityFunction(undefined)).toThrow()
    })

    it('should handle empty array', () => {
      // expect(utilityFunction([])).toEqual([])
    })

    it('should handle empty object', () => {
      // expect(utilityFunction({})).toEqual({})
    })
  })

  // --------------------------------------------------------------------------
  // Boundary Conditions
  // --------------------------------------------------------------------------
  describe('Boundary Conditions', () => {
    it('should handle minimum value', () => {
      // expect(utilityFunction(0)).toBe(0)
    })

    it('should handle maximum value', () => {
      // expect(utilityFunction(Number.MAX_SAFE_INTEGER)).toBe(...)
    })

    it('should handle negative numbers', () => {
      // expect(utilityFunction(-1)).toBe(...)
    })
  })

  // --------------------------------------------------------------------------
  // Type Coercion (if applicable)
  // --------------------------------------------------------------------------
  describe('Type Handling', () => {
    it('should handle numeric string', () => {
      // expect(utilityFunction('123')).toBe(123)
    })

    it('should handle boolean', () => {
      // expect(utilityFunction(true)).toBe(...)
    })
  })

  // --------------------------------------------------------------------------
  // Error Cases
  // --------------------------------------------------------------------------
  describe('Error Handling', () => {
    it('should throw for invalid input', () => {
      // expect(() => utilityFunction('invalid')).toThrow('Error message')
    })

    it('should throw with specific error type', () => {
      // expect(() => utilityFunction('invalid')).toThrow(ValidationError)
    })
  })

  // --------------------------------------------------------------------------
  // Complex Objects (if applicable)
  // --------------------------------------------------------------------------
  describe('Object Handling', () => {
    it('should preserve object structure', () => {
      // const input = { a: 1, b: 2 }
      // expect(utilityFunction(input)).toEqual({ a: 1, b: 2 })
    })

    it('should handle nested objects', () => {
      // const input = { nested: { deep: 'value' } }
      // expect(utilityFunction(input)).toEqual({ nested: { deep: 'transformed' } })
    })

    it('should not mutate input', () => {
      // const input = { a: 1 }
      // const inputCopy = { ...input }
      // utilityFunction(input)
      // expect(input).toEqual(inputCopy)
    })
  })

  // --------------------------------------------------------------------------
  // Array Handling (if applicable)
  // --------------------------------------------------------------------------
  describe('Array Handling', () => {
    it('should process all elements', () => {
      // expect(utilityFunction([1, 2, 3])).toEqual([2, 4, 6])
    })

    it('should handle single element array', () => {
      // expect(utilityFunction([1])).toEqual([2])
    })

    it('should preserve order', () => {
      // expect(utilityFunction(['c', 'a', 'b'])).toEqual(['c', 'a', 'b'])
    })
  })
})

@@ -0,0 +1,5 @@
[run]
omit =
    api/tests/*
    api/migrations/*
    api/core/rag/datasource/vdb/*

@@ -9,6 +9,14 @@
 # Backend (default owner, more specific rules below will override)
 api/ @QuantumGhost

+# Backend - MCP
+api/core/mcp/ @Nov1c444
+api/core/entities/mcp_provider.py @Nov1c444
+api/services/tools/mcp_tools_manage_service.py @Nov1c444
+api/controllers/mcp/ @Nov1c444
+api/controllers/console/app/mcp_server.py @Nov1c444
+api/tests/**/*mcp* @Nov1c444
+
 # Backend - Workflow - Engine (Core graph execution engine)
 api/core/workflow/graph_engine/ @laipz8200 @QuantumGhost
 api/core/workflow/runtime/ @laipz8200 @QuantumGhost

@@ -1,8 +1,6 @@
-name: "✨ Refactor"
-description: Refactor existing code for improved readability and maintainability.
-title: "[Chore/Refactor] "
-labels:
-  - refactor
+name: "✨ Refactor or Chore"
+description: Refactor existing code or perform maintenance chores to improve readability and reliability.
+title: "[Refactor/Chore] "
 body:
   - type: checkboxes
     attributes:
@@ -11,7 +9,7 @@ body:
       options:
         - label: I have read the [Contributing Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) and [Language Policy](https://github.com/langgenius/dify/issues/1542).
           required: true
-        - label: This is only for refactoring, if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general).
+        - label: This is only for refactors or chores; if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general).
           required: true
         - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
           required: true
@@ -25,14 +23,14 @@ body:
     id: description
     attributes:
       label: Description
-      placeholder: "Describe the refactor you are proposing."
+      placeholder: "Describe the refactor or chore you are proposing."
     validations:
       required: true
   - type: textarea
     id: motivation
     attributes:
       label: Motivation
-      placeholder: "Explain why this refactor is necessary."
+      placeholder: "Explain why this refactor or chore is necessary."
     validations:
       required: false
   - type: textarea

@@ -1,13 +0,0 @@
name: "👾 Tracker"
description: For inner usages, please do not use this template.
title: "[Tracker] "
labels:
  - tracker
body:
  - type: textarea
    id: content
    attributes:
      label: Blockers
      placeholder: "- [ ] ..."
    validations:
      required: true
@ -1,12 +0,0 @@
|
|||
# Copilot Instructions
|
||||
|
||||
GitHub Copilot must follow the unified frontend testing requirements documented in `web/testing/testing.md`.
|
||||
|
||||
Key reminders:
|
||||
|
||||
- Generate tests using the mandated tech stack, naming, and code style (AAA pattern, `fireEvent`, descriptive test names, cleans up mocks).
|
||||
- Cover rendering, prop combinations, and edge cases by default; extend coverage for hooks, routing, async flows, and domain-specific components when applicable.
|
||||
- Target >95% line and branch coverage and 100% function/statement coverage.
|
||||
- Apply the project's mocking conventions for i18n, toast notifications, and Next.js utilities.
|
||||
|
||||
Any suggestions from Copilot that conflict with `web/testing/testing.md` should be revised before acceptance.
|
||||
|
|

@@ -71,18 +71,18 @@ jobs:
         run: |
           cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env

-      - name: Run Workflow
-        run: uv run --project api bash dev/pytest/pytest_workflow.sh
-
-      - name: Run Tool
-        run: uv run --project api bash dev/pytest/pytest_tools.sh
-
-      - name: Run TestContainers
-        run: uv run --project api bash dev/pytest/pytest_testcontainers.sh
-
-      - name: Run Unit tests
+      - name: Run API Tests
+        env:
+          STORAGE_TYPE: opendal
+          OPENDAL_SCHEME: fs
+          OPENDAL_FS_ROOT: /tmp/dify-storage
         run: |
-          uv run --project api bash dev/pytest/pytest_unit_tests.sh
+          uv run --project api pytest \
+            --timeout "${PYTEST_TIMEOUT:-180}" \
+            api/tests/integration_tests/workflow \
+            api/tests/integration_tests/tools \
+            api/tests/test_containers_integration_tests \
+            api/tests/unit_tests

       - name: Coverage Summary
         run: |
@@ -94,4 +94,3 @@ jobs:
           echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
           echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
           uv run --project api coverage report --format=markdown >> $GITHUB_STEP_SUMMARY
-

@@ -13,11 +13,12 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v4

-      # Use uv to ensure we have the same ruff version in CI and locally.
-      - uses: astral-sh/setup-uv@v6
+      - uses: actions/setup-python@v5
+        with:
+          python-version: "3.11"
+
+      - uses: astral-sh/setup-uv@v6

       - run: |
           cd api
           uv sync --dev
@@ -35,10 +36,11 @@ jobs:
       - name: ast-grep
         run: |
-          uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all
-          uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all
-          uvx --from ast-grep-cli sg -p '$A = db.Column($$$B)' -r '$A = mapped_column($$$B)' -l py --update-all
-          uvx --from ast-grep-cli sg -p '$A : $T = db.Column($$$B)' -r '$A : $T = mapped_column($$$B)' -l py --update-all
+          # ast-grep exits 1 if no matches are found; allow idempotent runs.
+          uvx --from ast-grep-cli ast-grep --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all || true
+          uvx --from ast-grep-cli ast-grep --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all || true
+          uvx --from ast-grep-cli ast-grep -p '$A = db.Column($$$B)' -r '$A = mapped_column($$$B)' -l py --update-all || true
+          uvx --from ast-grep-cli ast-grep -p '$A : $T = db.Column($$$B)' -r '$A : $T = mapped_column($$$B)' -l py --update-all || true
           # Convert Optional[T] to T | None (ignoring quoted types)
           cat > /tmp/optional-rule.yml << 'EOF'
           id: convert-optional-to-union
@@ -56,14 +58,15 @@ jobs:
             pattern: $T
           fix: $T | None
           EOF
-          uvx --from ast-grep-cli sg scan --inline-rules "$(cat /tmp/optional-rule.yml)" --update-all
+          uvx --from ast-grep-cli ast-grep scan . --inline-rules "$(cat /tmp/optional-rule.yml)" --update-all
           # Fix forward references that were incorrectly converted (Python doesn't support "Type" | None syntax)
           find . -name "*.py" -type f -exec sed -i.bak -E 's/"([^"]+)" \| None/Optional["\1"]/g; s/'"'"'([^'"'"']+)'"'"' \| None/Optional['"'"'\1'"'"']/g' {} \;
           find . -name "*.py.bak" -type f -delete

+      # mdformat breaks YAML front matter in markdown files. Add --exclude for directories containing YAML front matter.
       - name: mdformat
         run: |
-          uvx mdformat .
+          uvx --python 3.13 mdformat . --exclude ".claude/skills/**"

       - name: Install pnpm
         uses: pnpm/action-setup@v4
@ -84,7 +87,6 @@ jobs:
|
|||
|
||||
- name: oxlint
|
||||
working-directory: ./web
|
||||
run: |
|
||||
pnpx oxlint --fix
|
||||
run: pnpm exec oxlint --config .oxlintrc.json --fix .
|
||||
|
||||
- uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27
|
||||
|
|
|
|||
|
|
@ -0,0 +1,21 @@
name: Semantic Pull Request

on:
  pull_request:
    types:
      - opened
      - edited
      - reopened
      - synchronize

jobs:
  lint:
    name: Validate PR title
    permissions:
      pull-requests: read
    runs-on: ubuntu-latest
    steps:
      - name: Check title
        uses: amannn/action-semantic-pull-request@v6.1.1
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
@ -189,6 +189,7 @@ docker/volumes/matrixone/*
docker/volumes/mysql/*
docker/volumes/seekdb/*
!docker/volumes/oceanbase/init.d
+docker/volumes/iris/*

docker/nginx/conf.d/default.conf
docker/nginx/ssl/*
@ -1,5 +0,0 @@
-# Windsurf Testing Rules
-
-- Use `web/testing/testing.md` as the single source of truth for frontend automated testing.
-- Honor every requirement in that document when generating or accepting tests.
-- When proposing or saving tests, re-read that document and follow every requirement.
@ -626,7 +626,17 @@ QUEUE_MONITOR_ALERT_EMAILS=
QUEUE_MONITOR_INTERVAL=30

# Swagger UI configuration
-SWAGGER_UI_ENABLED=true
+# SECURITY: Swagger UI is automatically disabled in PRODUCTION environment (DEPLOY_ENV=PRODUCTION)
+# to prevent API information disclosure.
+#
+# Behavior:
+# - DEPLOY_ENV=PRODUCTION + SWAGGER_UI_ENABLED not set -> Swagger DISABLED (secure default)
+# - DEPLOY_ENV=DEVELOPMENT/TESTING + SWAGGER_UI_ENABLED not set -> Swagger ENABLED
+# - SWAGGER_UI_ENABLED=true -> Swagger ENABLED (overrides environment check)
+# - SWAGGER_UI_ENABLED=false -> Swagger DISABLED (explicit disable)
+#
+# For development, you can uncomment below or set DEPLOY_ENV=DEVELOPMENT
+# SWAGGER_UI_ENABLED=false
SWAGGER_UI_PATH=/swagger-ui.html

# Whether to encrypt dataset IDs when exporting DSL files (default: true)
@ -83,6 +83,7 @@ def initialize_extensions(app: DifyApp):
        ext_redis,
        ext_request_logging,
        ext_sentry,
+        ext_session_factory,
        ext_set_secretkey,
        ext_storage,
        ext_timezone,

@ -114,6 +115,7 @@ def initialize_extensions(app: DifyApp):
        ext_commands,
        ext_otel,
        ext_request_logging,
+        ext_session_factory,
    ]
    for ext in extensions:
        short_name = ext.__name__.split(".")[-1]
@ -1221,9 +1221,19 @@ class WorkflowLogConfig(BaseSettings):


class SwaggerUIConfig(BaseSettings):
-    SWAGGER_UI_ENABLED: bool = Field(
-        description="Whether to enable Swagger UI in api module",
-        default=True,
+    """
+    Configuration for Swagger UI documentation.
+
+    Security Note: Swagger UI is automatically disabled in PRODUCTION environment
+    to prevent API information disclosure. Set SWAGGER_UI_ENABLED=true explicitly
+    to enable in production if needed.
+    """
+
+    SWAGGER_UI_ENABLED: bool | None = Field(
+        description="Whether to enable Swagger UI in api module. "
+        "Automatically disabled in PRODUCTION environment for security. "
+        "Set to true explicitly to enable in production.",
+        default=None,
    )

    SWAGGER_UI_PATH: str = Field(

@ -1231,6 +1241,23 @@ class SwaggerUIConfig(BaseSettings):
        default="/swagger-ui.html",
    )

+    @property
+    def swagger_ui_enabled(self) -> bool:
+        """
+        Compute whether Swagger UI should be enabled.
+
+        If SWAGGER_UI_ENABLED is explicitly set, use that value.
+        Otherwise, disable in PRODUCTION environment for security.
+        """
+        if self.SWAGGER_UI_ENABLED is not None:
+            return self.SWAGGER_UI_ENABLED
+
+        # Auto-disable in production environment
+        import os
+
+        deploy_env = os.environ.get("DEPLOY_ENV", "PRODUCTION")
+        return deploy_env.upper() != "PRODUCTION"
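The tri-state default is the point of the change: `None` now means "decide from `DEPLOY_ENV`". A minimal standalone sketch of the same resolution order (plain Python, nothing beyond the standard library, not the Dify class itself):

```python
import os


def swagger_ui_enabled(explicit: bool | None) -> bool:
    """Explicit flag wins; otherwise PRODUCTION disables Swagger UI."""
    if explicit is not None:
        return explicit
    return os.environ.get("DEPLOY_ENV", "PRODUCTION").upper() != "PRODUCTION"


os.environ["DEPLOY_ENV"] = "PRODUCTION"
assert swagger_ui_enabled(None) is False  # unset flag: secure default
assert swagger_ui_enabled(True) is True   # explicit opt-in overrides the environment
```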


class TenantIsolatedTaskQueueConfig(BaseSettings):
    TENANT_ISOLATED_TASK_CONCURRENCY: int = Field(
@ -26,6 +26,7 @@ from .vdb.clickzetta_config import ClickzettaConfig
from .vdb.couchbase_config import CouchbaseConfig
from .vdb.elasticsearch_config import ElasticsearchConfig
from .vdb.huawei_cloud_config import HuaweiCloudConfig
+from .vdb.iris_config import IrisVectorConfig
from .vdb.lindorm_config import LindormConfig
from .vdb.matrixone_config import MatrixoneConfig
from .vdb.milvus_config import MilvusConfig

@ -336,6 +337,7 @@ class MiddlewareConfig(
    ChromaConfig,
    ClickzettaConfig,
    HuaweiCloudConfig,
+    IrisVectorConfig,
    MilvusConfig,
    AlibabaCloudMySQLConfig,
    MyScaleConfig,
@ -0,0 +1,91 @@
"""Configuration for InterSystems IRIS vector database."""

from pydantic import Field, PositiveInt, model_validator
from pydantic_settings import BaseSettings


class IrisVectorConfig(BaseSettings):
    """Configuration settings for IRIS vector database connection and pooling."""

    IRIS_HOST: str | None = Field(
        description="Hostname or IP address of the IRIS server.",
        default="localhost",
    )

    IRIS_SUPER_SERVER_PORT: PositiveInt | None = Field(
        description="Port number for IRIS connection.",
        default=1972,
    )

    IRIS_USER: str | None = Field(
        description="Username for IRIS authentication.",
        default="_SYSTEM",
    )

    IRIS_PASSWORD: str | None = Field(
        description="Password for IRIS authentication.",
        default="Dify@1234",
    )

    IRIS_SCHEMA: str | None = Field(
        description="Schema name for IRIS tables.",
        default="dify",
    )

    IRIS_DATABASE: str | None = Field(
        description="Database namespace for IRIS connection.",
        default="USER",
    )

    IRIS_CONNECTION_URL: str | None = Field(
        description="Full connection URL for IRIS (overrides individual fields if provided).",
        default=None,
    )

    IRIS_MIN_CONNECTION: PositiveInt = Field(
        description="Minimum number of connections in the pool.",
        default=1,
    )

    IRIS_MAX_CONNECTION: PositiveInt = Field(
        description="Maximum number of connections in the pool.",
        default=3,
    )

    IRIS_TEXT_INDEX: bool = Field(
        description="Enable full-text search index using %iFind.Index.Basic.",
        default=True,
    )

    IRIS_TEXT_INDEX_LANGUAGE: str = Field(
        description="Language for full-text search index (e.g., 'en', 'ja', 'zh', 'de').",
        default="en",
    )

    @model_validator(mode="before")
    @classmethod
    def validate_config(cls, values: dict) -> dict:
        """Validate IRIS configuration values.

        Args:
            values: Configuration dictionary

        Returns:
            Validated configuration dictionary

        Raises:
            ValueError: If required fields are missing or pool settings are invalid
        """
        # Only validate required fields if IRIS is being used as the vector store
        # This allows the config to be loaded even when IRIS is not in use

        # vector_store = os.environ.get("VECTOR_STORE", "")
        # We rely on Pydantic defaults for required fields if they are missing from env.
        # Strict existence check is removed to allow defaults to work.

        min_conn = values.get("IRIS_MIN_CONNECTION", 1)
        max_conn = values.get("IRIS_MAX_CONNECTION", 3)
        if min_conn > max_conn:
            raise ValueError("IRIS_MIN_CONNECTION must be less than or equal to IRIS_MAX_CONNECTION")

        return values
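A hedged usage sketch for the config above. Values normally arrive via environment variables; instantiating directly just exercises the defaults and the pool-bounds validator, whose `ValueError` pydantic surfaces as a `ValidationError`:

```python
from pydantic import ValidationError

cfg = IrisVectorConfig()
assert cfg.IRIS_SUPER_SERVER_PORT == 1972  # unset fields keep their defaults

try:
    IrisVectorConfig(IRIS_MIN_CONNECTION=5, IRIS_MAX_CONNECTION=3)
except (ValidationError, ValueError) as exc:
    print(exc)  # IRIS_MIN_CONNECTION must be less than or equal to IRIS_MAX_CONNECTION
```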
@ -20,6 +20,7 @@ language_timezone_mapping = {
    "sl-SI": "Europe/Ljubljana",
    "th-TH": "Asia/Bangkok",
    "id-ID": "Asia/Jakarta",
+    "ar-TN": "Africa/Tunis",
}

languages = list(language_timezone_mapping.keys())
@ -6,19 +6,20 @@ from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound, Unauthorized

-P = ParamSpec("P")
-R = TypeVar("R")
from configs import dify_config
from constants.languages import supported_language
from controllers.console import console_ns
from controllers.console.wraps import only_edition_cloud
+from core.db.session_factory import session_factory
from extensions.ext_database import db
from libs.token import extract_access_token
from models.model import App, InstalledApp, RecommendedApp

+P = ParamSpec("P")
+R = TypeVar("R")

DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"

@ -90,7 +91,7 @@ class InsertExploreAppListApi(Resource):
        privacy_policy = site.privacy_policy or payload.privacy_policy or ""
        custom_disclaimer = site.custom_disclaimer or payload.custom_disclaimer or ""

-        with Session(db.engine) as session:
+        with session_factory.create_session() as session:
            recommended_app = session.execute(
                select(RecommendedApp).where(RecommendedApp.app_id == payload.app_id)
            ).scalar_one_or_none()

@ -138,7 +139,7 @@ class InsertExploreAppApi(Resource):
    @only_edition_cloud
    @admin_required
    def delete(self, app_id):
-        with Session(db.engine) as session:
+        with session_factory.create_session() as session:
            recommended_app = session.execute(
                select(RecommendedApp).where(RecommendedApp.app_id == str(app_id))
            ).scalar_one_or_none()

@ -146,13 +147,13 @@ class InsertExploreAppApi(Resource):
        if not recommended_app:
            return {"result": "success"}, 204

-        with Session(db.engine) as session:
+        with session_factory.create_session() as session:
            app = session.execute(select(App).where(App.id == recommended_app.app_id)).scalar_one_or_none()

            if app:
                app.is_public = False

-        with Session(db.engine) as session:
+        with session_factory.create_session() as session:
            installed_apps = (
                session.execute(
                    select(InstalledApp).where(
@ -61,6 +61,7 @@ class ChatMessagesQuery(BaseModel):
class MessageFeedbackPayload(BaseModel):
    message_id: str = Field(..., description="Message ID")
    rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
    content: str | None = Field(default=None, description="Feedback content")

    @field_validator("message_id")
    @classmethod

@ -324,6 +325,7 @@ class MessageFeedbackApi(Resource):
            db.session.delete(feedback)
        elif args.rating and feedback:
            feedback.rating = args.rating
            feedback.content = args.content
        elif not args.rating and not feedback:
            raise ValueError("rating cannot be None when feedback not exists")
        else:

@ -335,6 +337,7 @@ class MessageFeedbackApi(Resource):
            conversation_id=message.conversation_id,
            message_id=message.id,
            rating=rating_value,
            content=args.content,
            from_source="admin",
            from_account_id=current_user.id,
        )
@ -114,7 +114,7 @@ class AppTriggersApi(Resource):

@console_ns.route("/apps/<uuid:app_id>/trigger-enable")
class AppTriggerEnableApi(Resource):
-    @console_ns.expect(console_ns.models[ParserEnable.__name__], validate=True)
+    @console_ns.expect(console_ns.models[ParserEnable.__name__])
    @setup_required
    @login_required
    @account_initialization_required
@ -230,6 +230,7 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool
        VectorType.CLICKZETTA,
        VectorType.BAIDU,
        VectorType.ALIBABACLOUD_MYSQL,
+        VectorType.IRIS,
    }

    semantic_methods = {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}

@ -422,7 +423,6 @@ class DatasetApi(Resource):
            raise NotFound("Dataset not found.")

        payload = DatasetUpdatePayload.model_validate(console_ns.payload or {})
-        payload_data = payload.model_dump(exclude_unset=True)
        current_user, current_tenant_id = current_account_with_tenant()
        # check embedding model setting
        if (

@ -434,6 +434,7 @@ class DatasetApi(Resource):
                dataset.tenant_id, payload.embedding_model_provider, payload.embedding_model
            )
            payload.is_multimodal = is_multimodal
+        payload_data = payload.model_dump(exclude_unset=True)
        # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
        DatasetPermissionService.check_permission(
            current_user, dataset, payload.permission, payload.partial_member_list
@ -26,7 +26,7 @@ console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=D

@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview")
class DataSourceContentPreviewApi(Resource):
-    @console_ns.expect(console_ns.models[Parser.__name__], validate=True)
+    @console_ns.expect(console_ns.models[Parser.__name__])
    @setup_required
    @login_required
    @account_initialization_required
@ -2,7 +2,7 @@ import logging
from typing import Any, Literal
from uuid import UUID

-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import InternalServerError, NotFound

import services

@ -52,10 +52,24 @@ class ChatMessagePayload(BaseModel):
    inputs: dict[str, Any]
    query: str
    files: list[dict[str, Any]] | None = None
-    conversation_id: UUID | None = None
-    parent_message_id: UUID | None = None
+    conversation_id: str | None = None
+    parent_message_id: str | None = None
    retriever_from: str = Field(default="explore_app")

+    @field_validator("conversation_id", "parent_message_id", mode="before")
+    @classmethod
+    def normalize_uuid(cls, value: str | UUID | None) -> str | None:
+        """
+        Accept blank IDs and validate UUID format when provided.
+        """
+        if not value:
+            return None
+
+        try:
+            return helper.uuid_value(value)
+        except ValueError as exc:
+            raise ValueError("must be a valid UUID") from exc


register_schema_models(console_ns, CompletionMessagePayload, ChatMessagePayload)
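Because the validator runs in `mode="before"`, it sees the raw input, which is what lets blank strings, `None`, and `UUID` instances all normalize to one representation. A hedged sketch, assuming `helper.uuid_value` returns the canonical string form or raises `ValueError`:

```python
import uuid

from pydantic import ValidationError

p = ChatMessagePayload(inputs={}, query="hi", conversation_id="")
assert p.conversation_id is None  # blank IDs normalize to None

p = ChatMessagePayload(inputs={}, query="hi", conversation_id=uuid.uuid4())
assert isinstance(p.conversation_id, str)  # UUID objects come back as strings

try:
    ChatMessagePayload(inputs={}, query="hi", conversation_id="not-a-uuid")
except ValidationError as exc:
    print(exc)  # must be a valid UUID
```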
@ -3,7 +3,7 @@ from uuid import UUID

from flask import request
from flask_restx import marshal_with
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, model_validator
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound

@ -30,9 +30,16 @@ class ConversationListQuery(BaseModel):


class ConversationRenamePayload(BaseModel):
-    name: str
+    name: str | None = None
    auto_generate: bool = False

+    @model_validator(mode="after")
+    def validate_name_requirement(self):
+        if not self.auto_generate:
+            if self.name is None or not self.name.strip():
+                raise ValueError("name is required when auto_generate is false")
+        return self


register_schema_models(console_ns, ConversationListQuery, ConversationRenamePayload)
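What the validator above accepts and rejects, as a quick sketch (assuming the model is importable as defined):

```python
from pydantic import ValidationError

ConversationRenamePayload(name="Weekly sync")  # explicit name: ok
ConversationRenamePayload(auto_generate=True)  # name omitted: ok

try:
    ConversationRenamePayload(name="   ", auto_generate=False)
except ValidationError as exc:
    print(exc)  # name is required when auto_generate is false
```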
@ -230,7 +230,7 @@ class ModelProviderModelApi(Resource):

        return {"result": "success"}, 200

-    @console_ns.expect(console_ns.models[ParserDeleteModels.__name__], validate=True)
+    @console_ns.expect(console_ns.models[ParserDeleteModels.__name__])
    @setup_required
    @login_required
    @is_admin_or_owner_required

@ -282,9 +282,10 @@ class ModelProviderModelCredentialApi(Resource):
                tenant_id=tenant_id, provider_name=provider
            )
        else:
-            model_type = args.model_type
+            # Normalize model_type to the origin value stored in DB (e.g., "text-generation" for LLM)
+            normalized_model_type = args.model_type.to_origin_model_type()
            available_credentials = model_provider_service.provider_manager.get_provider_model_available_credentials(
-                tenant_id=tenant_id, provider_name=provider, model_type=model_type, model_name=args.model
+                tenant_id=tenant_id, provider_name=provider, model_type=normalized_model_type, model_name=args.model
            )

        return jsonable_encoder(
@ -46,8 +46,8 @@ class PluginDebuggingKeyApi(Resource):


class ParserList(BaseModel):
-    page: int = Field(default=1)
-    page_size: int = Field(default=256)
+    page: int = Field(default=1, ge=1, description="Page number")
+    page_size: int = Field(default=256, ge=1, le=256, description="Page size (1-256)")


reg(ParserList)

@ -106,8 +106,8 @@ class ParserPluginIdentifierQuery(BaseModel):


class ParserTasks(BaseModel):
-    page: int
-    page_size: int
+    page: int = Field(default=1, ge=1, description="Page number")
+    page_size: int = Field(default=256, ge=1, le=256, description="Page size (1-256)")

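The `ge`/`le` bounds above push validation into the schema, so out-of-range paging is rejected before any handler code runs. A minimal sketch with the same constraint shape:

```python
from pydantic import BaseModel, Field, ValidationError


class Page(BaseModel):
    page: int = Field(default=1, ge=1)
    page_size: int = Field(default=256, ge=1, le=256)


assert Page().page_size == 256  # defaults apply when fields are omitted

try:
    Page(page=0, page_size=500)
except ValidationError as exc:
    print(exc)  # both out-of-range fields reported in one error
```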

class ParserMarketplaceUpgrade(BaseModel):
@ -4,7 +4,7 @@ from uuid import UUID

from flask import request
from flask_restx import Resource
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound

import services

@ -52,11 +52,23 @@ class ChatRequestPayload(BaseModel):
    query: str
    files: list[dict[str, Any]] | None = None
    response_mode: Literal["blocking", "streaming"] | None = None
-    conversation_id: UUID | None = None
+    conversation_id: str | None = Field(default=None, description="Conversation UUID")
    retriever_from: str = Field(default="dev")
    auto_generate_name: bool = Field(default=True, description="Auto generate conversation name")
    workflow_id: str | None = Field(default=None, description="Workflow ID for advanced chat")

+    @field_validator("conversation_id", mode="before")
+    @classmethod
+    def normalize_conversation_id(cls, value: str | UUID | None) -> str | None:
+        """Allow missing or blank conversation IDs; enforce UUID format when provided."""
+        if not value:
+            return None
+
+        try:
+            return helper.uuid_value(value)
+        except ValueError as exc:
+            raise ValueError("conversation_id must be a valid UUID") from exc


register_schema_models(service_api_ns, CompletionRequestPayload, ChatRequestPayload)
@ -4,7 +4,7 @@ from uuid import UUID
from flask import request
from flask_restx import Resource
from flask_restx._http import HTTPStatus
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, model_validator
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, NotFound

@ -37,9 +37,16 @@ class ConversationListQuery(BaseModel):


class ConversationRenamePayload(BaseModel):
-    name: str = Field(description="New conversation name")
+    name: str | None = Field(default=None, description="New conversation name (required if auto_generate is false)")
    auto_generate: bool = Field(default=False, description="Auto-generate conversation name")

+    @model_validator(mode="after")
+    def validate_name_requirement(self):
+        if not self.auto_generate:
+            if self.name is None or not self.name.strip():
+                raise ValueError("name is required when auto_generate is false")
+        return self


class ConversationVariablesQuery(BaseModel):
    last_id: UUID | None = Field(default=None, description="Last variable ID for pagination")
@ -33,7 +33,7 @@ def trigger_endpoint(endpoint_id: str):
            if response:
                break
        if not response:
-            logger.error("Endpoint not found for {endpoint_id}")
+            logger.info("Endpoint not found for %s", endpoint_id)
            return jsonify({"error": "Endpoint not found"}), 404
        return response
    except ValueError as e:
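Two fixes in one line there: the original string was an f-string without the `f` prefix, so `{endpoint_id}` was logged literally, and a missing endpoint is an expected condition rather than an error. The `%s` style also defers formatting to the logging framework:

```python
import logging

logger = logging.getLogger(__name__)
endpoint_id = "ep-123"  # hypothetical value

logger.error("Endpoint not found for {endpoint_id}")   # bug: logs the braces verbatim
logger.info("Endpoint not found for %s", endpoint_id)  # interpolated only when emitted
```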
@ -62,8 +62,7 @@ from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
-from core.ops.entities.trace_entity import TraceTaskName
-from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
+from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.nodes import NodeType
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory

@ -73,7 +72,7 @@ from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import Account, Conversation, EndUser, Message, MessageFile
from models.enums import CreatorUserRole
-from models.workflow import Workflow, WorkflowNodeExecutionModel
+from models.workflow import Workflow

logger = logging.getLogger(__name__)

@ -581,7 +580,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):

            with self._database_session() as session:
                # Save message
-                self._save_message(session=session, graph_runtime_state=resolved_state, trace_manager=trace_manager)
+                self._save_message(session=session, graph_runtime_state=resolved_state)

            yield workflow_finish_resp
        elif event.stopped_by in (

@ -591,7 +590,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
            # When hitting input-moderation or annotation-reply, the workflow will not start
            with self._database_session() as session:
                # Save message
-                self._save_message(session=session, trace_manager=trace_manager)
+                self._save_message(session=session)

            yield self._message_end_to_stream_response()

@ -600,7 +599,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
        event: QueueAdvancedChatMessageEndEvent,
        *,
        graph_runtime_state: GraphRuntimeState | None = None,
-        trace_manager: TraceQueueManager | None = None,
        **kwargs,
    ) -> Generator[StreamResponse, None, None]:
        """Handle advanced chat message end events."""

@ -618,7 +616,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):

        # Save message
        with self._database_session() as session:
-            self._save_message(session=session, graph_runtime_state=resolved_state, trace_manager=trace_manager)
+            self._save_message(session=session, graph_runtime_state=resolved_state)

        yield self._message_end_to_stream_response()

@ -772,13 +770,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
        if self._conversation_name_generate_thread:
            logger.debug("Conversation name generation running as daemon thread")

-    def _save_message(
-        self,
-        *,
-        session: Session,
-        graph_runtime_state: GraphRuntimeState | None = None,
-        trace_manager: TraceQueueManager | None = None,
-    ):
+    def _save_message(self, *, session: Session, graph_runtime_state: GraphRuntimeState | None = None):
        message = self._get_message(session=session)

        # If there are assistant files, remove markdown image links from answer

@ -817,14 +809,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):

        metadata = self._task_state.metadata.model_dump()
        message.message_metadata = json.dumps(jsonable_encoder(metadata))

-        # Extract model provider and model_id from workflow node executions for tracing
-        if message.workflow_run_id:
-            model_info = self._extract_model_info_from_workflow(session, message.workflow_run_id)
-            if model_info:
-                message.model_provider = model_info.get("provider")
-                message.model_id = model_info.get("model")
-
        message_files = [
            MessageFile(
                message_id=message.id,

@ -842,68 +826,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
        ]
        session.add_all(message_files)

-        # Trigger MESSAGE_TRACE for tracing integrations
-        if trace_manager:
-            trace_manager.add_trace_task(
-                TraceTask(
-                    TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation_id, message_id=self._message_id
-                )
-            )
-
-    def _extract_model_info_from_workflow(self, session: Session, workflow_run_id: str) -> dict[str, str] | None:
-        """
-        Extract model provider and model_id from workflow node executions.
-        Returns dict with 'provider' and 'model' keys, or None if not found.
-        """
-        try:
-            # Query workflow node executions for LLM or Agent nodes
-            stmt = (
-                select(WorkflowNodeExecutionModel)
-                .where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
-                .where(WorkflowNodeExecutionModel.node_type.in_(["llm", "agent"]))
-                .order_by(WorkflowNodeExecutionModel.created_at.desc())
-                .limit(1)
-            )
-            node_execution = session.scalar(stmt)
-
-            if not node_execution:
-                return None
-
-            # Try to extract from execution_metadata for agent nodes
-            if node_execution.execution_metadata:
-                try:
-                    metadata = json.loads(node_execution.execution_metadata)
-                    agent_log = metadata.get("agent_log", [])
-                    # Look for the first agent thought with provider info
-                    for log_entry in agent_log:
-                        entry_metadata = log_entry.get("metadata", {})
-                        provider_str = entry_metadata.get("provider")
-                        if provider_str:
-                            # Parse format like "langgenius/deepseek/deepseek"
-                            parts = provider_str.split("/")
-                            if len(parts) >= 3:
-                                return {"provider": parts[1], "model": parts[2]}
-                            elif len(parts) == 2:
-                                return {"provider": parts[0], "model": parts[1]}
-                except (json.JSONDecodeError, KeyError, AttributeError) as e:
-                    logger.debug("Failed to parse execution_metadata: %s", e)
-
-            # Try to extract from process_data for llm nodes
-            if node_execution.process_data:
-                try:
-                    process_data = json.loads(node_execution.process_data)
-                    provider = process_data.get("model_provider")
-                    model = process_data.get("model_name")
-                    if provider and model:
-                        return {"provider": provider, "model": model}
-                except (json.JSONDecodeError, KeyError) as e:
-                    logger.debug("Failed to parse process_data: %s", e)
-
-            return None
-        except Exception as e:
-            logger.warning("Failed to extract model info from workflow: %s", e)
-            return None

    def _seed_graph_runtime_state_from_queue_manager(self) -> None:
        """Bootstrap the cached runtime state from the queue manager when present."""
        candidate = self._base_task_pipeline.queue_manager.graph_runtime_state
@ -40,9 +40,6 @@ class EasyUITaskState(TaskState):
    """

    llm_result: LLMResult
-    first_token_time: float | None = None
-    last_token_time: float | None = None
-    is_streaming_response: bool = False


class WorkflowTaskState(TaskState):
@ -332,12 +332,6 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
                if not self._task_state.llm_result.prompt_messages:
                    self._task_state.llm_result.prompt_messages = chunk.prompt_messages

-                # Track streaming response times
-                if self._task_state.first_token_time is None:
-                    self._task_state.first_token_time = time.perf_counter()
-                    self._task_state.is_streaming_response = True
-                self._task_state.last_token_time = time.perf_counter()
-
                # handle output moderation chunk
                should_direct_answer = self._handle_output_moderation_chunk(cast(str, delta_text))
                if should_direct_answer:

@ -404,18 +398,6 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
            message.total_price = usage.total_price
            message.currency = usage.currency
            self._task_state.llm_result.usage.latency = message.provider_response_latency

-            # Add streaming metrics to usage if available
-            if self._task_state.is_streaming_response and self._task_state.first_token_time:
-                start_time = self.start_at
-                first_token_time = self._task_state.first_token_time
-                last_token_time = self._task_state.last_token_time or first_token_time
-                usage.time_to_first_token = round(first_token_time - start_time, 3)
-                usage.time_to_generate = round(last_token_time - first_token_time, 3)
-
-            # Update metadata with the complete usage info
-            self._task_state.metadata.usage = usage

            message.message_metadata = self._task_state.metadata.model_dump_json()

            if trace_manager:
@ -0,0 +1,38 @@
from sqlalchemy import Engine
from sqlalchemy.orm import Session, sessionmaker

_session_maker: sessionmaker | None = None


def configure_session_factory(engine: Engine, expire_on_commit: bool = False):
    """Configure the global session factory"""
    global _session_maker
    _session_maker = sessionmaker(bind=engine, expire_on_commit=expire_on_commit)


def get_session_maker() -> sessionmaker:
    if _session_maker is None:
        raise RuntimeError("Session factory not configured. Call configure_session_factory() first.")
    return _session_maker


def create_session() -> Session:
    return get_session_maker()()


# Class wrapper for convenience
class SessionFactory:
    @staticmethod
    def configure(engine: Engine, expire_on_commit: bool = False):
        configure_session_factory(engine, expire_on_commit)

    @staticmethod
    def get_session_maker() -> sessionmaker:
        return get_session_maker()

    @staticmethod
    def create_session() -> Session:
        return create_session()


session_factory = SessionFactory()
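A usage sketch for the new module. The engine below is a stand-in for illustration; in the app the factory is wired up by the `ext_session_factory` extension registered earlier in this commit:

```python
from sqlalchemy import create_engine, text

from core.db.session_factory import session_factory

engine = create_engine("sqlite:///:memory:")  # stand-in engine for the sketch
session_factory.configure(engine)

# create_session() returns a plain SQLAlchemy Session, so the context-manager
# pattern used by the migrated controllers applies unchanged:
with session_factory.create_session() as session:
    session.execute(text("SELECT 1"))
```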
@ -1,4 +1,4 @@
-from pydantic import BaseModel
+from pydantic import BaseModel, Field


class PreviewDetail(BaseModel):

@ -20,7 +20,7 @@ class IndexingEstimate(BaseModel):
class PipelineDataset(BaseModel):
    id: str
    name: str
-    description: str
+    description: str | None = Field(default="", description="knowledge dataset description")
    chunk_structure: str
@ -213,12 +213,23 @@ class MCPProviderEntity(BaseModel):
        return None

    def retrieve_tokens(self) -> OAuthTokens | None:
-        """OAuth tokens if available"""
+        """Retrieve OAuth tokens if authentication is complete.
+
+        Returns:
+            OAuthTokens if the provider has been authenticated, None otherwise.
+        """
        if not self.credentials:
            return None
        credentials = self.decrypt_credentials()
+        access_token = credentials.get("access_token", "")
+        # Return None if access_token is empty to avoid generating invalid "Authorization: Bearer " header.
+        # Note: We don't check for whitespace-only strings here because:
+        # 1. OAuth servers don't return whitespace-only access tokens in practice
+        # 2. Even if they did, the server would return 401, triggering the OAuth flow correctly
+        if not access_token:
+            return None
        return OAuthTokens(
-            access_token=credentials.get("access_token", ""),
+            access_token=access_token,
            token_type=credentials.get("token_type", DEFAULT_TOKEN_TYPE),
            expires_in=int(credentials.get("expires_in", str(DEFAULT_EXPIRES_IN)) or DEFAULT_EXPIRES_IN),
            refresh_token=credentials.get("refresh_token", ""),
@ -72,15 +72,22 @@ class LLMGenerator:
            prompt_messages=list(prompts), model_parameters={"max_tokens": 500, "temperature": 1}, stream=False
        )
        answer = cast(str, response.message.content)
-        cleaned_answer = re.sub(r"^.*(\{.*\}).*$", r"\1", answer, flags=re.DOTALL)
-        if cleaned_answer is None:
+        if answer is None:
            return ""
        try:
-            result_dict = json.loads(cleaned_answer)
-            answer = result_dict["Your Output"]
+            result_dict = json.loads(answer)
        except json.JSONDecodeError:
-            logger.exception("Failed to generate name after answer, use query instead")
+            result_dict = json_repair.loads(answer)
+
+        if not isinstance(result_dict, dict):
            answer = query
+        else:
+            output = result_dict.get("Your Output")
+            if isinstance(output, str) and output.strip():
+                answer = output.strip()
+            else:
+                answer = query

        name = answer.strip()

        if len(name) > 75:

@ -554,11 +561,16 @@ class LLMGenerator:
                prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
            )

-            generated_raw = cast(str, response.message.content)
+            generated_raw = response.message.get_text_content()
            first_brace = generated_raw.find("{")
            last_brace = generated_raw.rfind("}")
-            return {**json.loads(generated_raw[first_brace : last_brace + 1])}
+
+            if first_brace == -1 or last_brace == -1 or last_brace < first_brace:
+                raise ValueError(f"Could not find a valid JSON object in response: {generated_raw}")
+            json_str = generated_raw[first_brace : last_brace + 1]
+            data = json_repair.loads(json_str)
+            if not isinstance(data, dict):
+                raise TypeError(f"Expected a JSON object, but got {type(data).__name__}")
+            return data
        except InvokeError as e:
            error = str(e)
            return {"error": f"Failed to generate code. Error: {error}"}
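The interesting move above is from regex extraction to strict parsing with a lenient fallback: `json.loads` first, `json_repair.loads` only on failure. A small self-contained sketch of that fallback:

```python
import json

import json_repair

raw = '{"Your Output": "Quarterly revenue recap",}'  # trailing comma: strict JSON rejects it

try:
    result = json.loads(raw)
except json.JSONDecodeError:
    # json_repair tolerates common LLM output defects (trailing commas,
    # single quotes, unquoted keys) and still returns a Python object.
    result = json_repair.loads(raw)

assert result["Your Output"] == "Quarterly revenue recap"
```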
@ -6,7 +6,13 @@ from datetime import datetime, timedelta
from typing import Any, Union, cast
from urllib.parse import urlparse

-from openinference.semconv.trace import OpenInferenceMimeTypeValues, OpenInferenceSpanKindValues, SpanAttributes
+from openinference.semconv.trace import (
+    MessageAttributes,
+    OpenInferenceMimeTypeValues,
+    OpenInferenceSpanKindValues,
+    SpanAttributes,
+    ToolCallAttributes,
+)
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GrpcOTLPSpanExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HttpOTLPSpanExporter
from opentelemetry.sdk import trace as trace_sdk

@ -95,14 +101,14 @@ def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[tra


def datetime_to_nanos(dt: datetime | None) -> int:
-    """Convert datetime to nanoseconds since epoch. If None, use current time."""
+    """Convert datetime to nanoseconds since epoch for Arize/Phoenix."""
    if dt is None:
        dt = datetime.now()
    return int(dt.timestamp() * 1_000_000_000)


def error_to_string(error: Exception | str | None) -> str:
-    """Convert an error to a string with traceback information."""
+    """Convert an error to a string with traceback information for Arize/Phoenix."""
    error_message = "Empty Stack Trace"
    if error:
        if isinstance(error, Exception):

@ -114,7 +120,7 @@ def error_to_string(error: Exception | str | None) -> str:


def set_span_status(current_span: Span, error: Exception | str | None = None):
-    """Set the status of the current span based on the presence of an error."""
+    """Set the status of the current span based on the presence of an error for Arize/Phoenix."""
    if error:
        error_string = error_to_string(error)
        current_span.set_status(Status(StatusCode.ERROR, error_string))

@ -138,10 +144,17 @@ def set_span_status(current_span: Span, error: Exception | str | None = None):


def safe_json_dumps(obj: Any) -> str:
-    """A convenience wrapper around `json.dumps` that ensures that any object can be safely encoded."""
+    """A convenience wrapper to ensure that any object can be safely encoded for Arize/Phoenix."""
    return json.dumps(obj, default=str, ensure_ascii=False)


+def wrap_span_metadata(metadata, **kwargs):
+    """Add common metadata to all trace entity types for Arize/Phoenix."""
+    metadata["created_from"] = "Dify"
+    metadata.update(kwargs)
+    return metadata
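These helpers exist because OpenTelemetry expects span timestamps as integer nanoseconds and span attributes as strings, while Dify trace payloads routinely carry datetimes and Decimals that plain `json.dumps` rejects. A self-contained sketch of the `default=str` behavior:

```python
import json
from datetime import datetime
from decimal import Decimal


def safe_json_dumps(obj) -> str:
    # default=str stringifies otherwise unserializable values instead of raising.
    return json.dumps(obj, default=str, ensure_ascii=False)


payload = {"started_at": datetime(2024, 1, 1, 12, 0), "cost": Decimal("0.25")}
print(safe_json_dumps(payload))
# {"started_at": "2024-01-01 12:00:00", "cost": "0.25"}
# json.dumps(payload) alone would raise TypeError on both values.
```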

class ArizePhoenixDataTrace(BaseTraceInstance):
    def __init__(
        self,
@ -183,16 +196,27 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
            raise

    def workflow_trace(self, trace_info: WorkflowTraceInfo):
-        workflow_metadata = {
-            "workflow_run_id": trace_info.workflow_run_id or "",
-            "message_id": trace_info.message_id or "",
-            "workflow_app_log_id": trace_info.workflow_app_log_id or "",
-            "status": trace_info.workflow_run_status or "",
-            "status_message": trace_info.error or "",
-            "level": "ERROR" if trace_info.error else "DEFAULT",
-            "total_tokens": trace_info.total_tokens or 0,
-        }
-        workflow_metadata.update(trace_info.metadata)
+        file_list = trace_info.file_list if isinstance(trace_info.file_list, list) else []
+
+        metadata = wrap_span_metadata(
+            trace_info.metadata,
+            trace_id=trace_info.trace_id or "",
+            message_id=trace_info.message_id or "",
+            status=trace_info.workflow_run_status or "",
+            status_message=trace_info.error or "",
+            level="ERROR" if trace_info.error else "DEFAULT",
+            trace_entity_type="workflow",
+            conversation_id=trace_info.conversation_id or "",
+            workflow_app_log_id=trace_info.workflow_app_log_id or "",
+            workflow_id=trace_info.workflow_id or "",
+            tenant_id=trace_info.tenant_id or "",
+            workflow_run_id=trace_info.workflow_run_id or "",
+            workflow_run_elapsed_time=trace_info.workflow_run_elapsed_time or 0,
+            workflow_run_version=trace_info.workflow_run_version or "",
+            total_tokens=trace_info.total_tokens or 0,
+            file_list=safe_json_dumps(file_list),
+            query=trace_info.query or "",
+        )

        dify_trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id
        self.ensure_root_span(dify_trace_id)

@ -201,10 +225,12 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
        workflow_span = self.tracer.start_span(
            name=TraceTaskName.WORKFLOW_TRACE.value,
            attributes={
-                SpanAttributes.INPUT_VALUE: json.dumps(trace_info.workflow_run_inputs, ensure_ascii=False),
-                SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
                SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
-                SpanAttributes.METADATA: json.dumps(workflow_metadata, ensure_ascii=False),
+                SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.workflow_run_inputs),
+                SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
+                SpanAttributes.OUTPUT_VALUE: safe_json_dumps(trace_info.workflow_run_outputs),
+                SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
+                SpanAttributes.METADATA: safe_json_dumps(metadata),
                SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
            },
            start_time=datetime_to_nanos(trace_info.start_time),

@ -257,6 +283,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
                    "app_id": app_id,
                    "app_name": node_execution.title,
                    "status": node_execution.status,
+                    "status_message": node_execution.error or "",
                    "level": "ERROR" if node_execution.status == "failed" else "DEFAULT",
                }
            )

@ -290,11 +317,11 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
            node_span = self.tracer.start_span(
                name=node_execution.node_type,
                attributes={
-                    SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind.value,
                    SpanAttributes.INPUT_VALUE: safe_json_dumps(inputs_value),
                    SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
                    SpanAttributes.OUTPUT_VALUE: safe_json_dumps(outputs_value),
                    SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
+                    SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind.value,
                    SpanAttributes.METADATA: safe_json_dumps(node_metadata),
                    SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
                },
@ -339,30 +366,37 @@ class ArizePhoenixDataTrace(BaseTraceInstance):

    def message_trace(self, trace_info: MessageTraceInfo):
        if trace_info.message_data is None:
            logger.warning("[Arize/Phoenix] Message data is None, skipping message trace.")
            return

-        file_list = cast(list[str], trace_info.file_list) or []
+        file_list = trace_info.file_list if isinstance(trace_info.file_list, list) else []
        message_file_data: MessageFile | None = trace_info.message_file_data

        if message_file_data is not None:
            file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
            file_list.append(file_url)

-        message_metadata = {
-            "message_id": trace_info.message_id or "",
-            "conversation_mode": str(trace_info.conversation_mode or ""),
-            "user_id": trace_info.message_data.from_account_id or "",
-            "file_list": json.dumps(file_list),
-            "status": trace_info.message_data.status or "",
-            "status_message": trace_info.error or "",
-            "level": "ERROR" if trace_info.error else "DEFAULT",
-            "total_tokens": trace_info.total_tokens or 0,
-            "prompt_tokens": trace_info.message_tokens or 0,
-            "completion_tokens": trace_info.answer_tokens or 0,
-            "ls_provider": trace_info.message_data.model_provider or "",
-            "ls_model_name": trace_info.message_data.model_id or "",
-        }
-        message_metadata.update(trace_info.metadata)
+        metadata = wrap_span_metadata(
+            trace_info.metadata,
+            trace_id=trace_info.trace_id or "",
+            message_id=trace_info.message_id or "",
+            status=trace_info.message_data.status or "",
+            status_message=trace_info.error or "",
+            level="ERROR" if trace_info.error else "DEFAULT",
+            trace_entity_type="message",
+            conversation_model=trace_info.conversation_model or "",
+            message_tokens=trace_info.message_tokens or 0,
+            answer_tokens=trace_info.answer_tokens or 0,
+            total_tokens=trace_info.total_tokens or 0,
+            conversation_mode=trace_info.conversation_mode or "",
+            gen_ai_server_time_to_first_token=trace_info.gen_ai_server_time_to_first_token or 0,
+            llm_streaming_time_to_generate=trace_info.llm_streaming_time_to_generate or 0,
+            is_streaming_request=trace_info.is_streaming_request or False,
+            user_id=trace_info.message_data.from_account_id or "",
+            file_list=safe_json_dumps(file_list),
+            model_provider=trace_info.message_data.model_provider or "",
+            model_id=trace_info.message_data.model_id or "",
+        )

        # Add end user data if available
        if trace_info.message_data.from_end_user_id:

@ -370,14 +404,16 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
                db.session.query(EndUser).where(EndUser.id == trace_info.message_data.from_end_user_id).first()
            )
            if end_user_data is not None:
-                message_metadata["end_user_id"] = end_user_data.session_id
+                metadata["end_user_id"] = end_user_data.session_id

        attributes = {
-            SpanAttributes.INPUT_VALUE: trace_info.message_data.query,
-            SpanAttributes.OUTPUT_VALUE: trace_info.message_data.answer,
            SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
-            SpanAttributes.METADATA: json.dumps(message_metadata, ensure_ascii=False),
-            SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
+            SpanAttributes.INPUT_VALUE: trace_info.message_data.query,
+            SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.TEXT.value,
+            SpanAttributes.OUTPUT_VALUE: trace_info.message_data.answer,
+            SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.TEXT.value,
+            SpanAttributes.METADATA: safe_json_dumps(metadata),
+            SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id or "",
        }

        dify_trace_id = trace_info.trace_id or trace_info.message_id

@ -393,8 +429,10 @@ class ArizePhoenixDataTrace(BaseTraceInstance):

        try:
            # Convert outputs to string based on type
+            outputs_mime_type = OpenInferenceMimeTypeValues.TEXT.value
            if isinstance(trace_info.outputs, dict | list):
-                outputs_str = json.dumps(trace_info.outputs, ensure_ascii=False)
+                outputs_str = safe_json_dumps(trace_info.outputs)
+                outputs_mime_type = OpenInferenceMimeTypeValues.JSON.value
            elif isinstance(trace_info.outputs, str):
                outputs_str = trace_info.outputs
            else:

@ -402,10 +440,12 @@ class ArizePhoenixDataTrace(BaseTraceInstance):

            llm_attributes = {
                SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.LLM.value,
-                SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
+                SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.inputs),
+                SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
                SpanAttributes.OUTPUT_VALUE: outputs_str,
-                SpanAttributes.METADATA: json.dumps(message_metadata, ensure_ascii=False),
-                SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
+                SpanAttributes.OUTPUT_MIME_TYPE: outputs_mime_type,
+                SpanAttributes.METADATA: safe_json_dumps(metadata),
+                SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id or "",
            }
            llm_attributes.update(self._construct_llm_attributes(trace_info.inputs))
            if trace_info.total_tokens is not None and trace_info.total_tokens > 0:
@ -449,16 +489,20 @@ class ArizePhoenixDataTrace(BaseTraceInstance):

    def moderation_trace(self, trace_info: ModerationTraceInfo):
        if trace_info.message_data is None:
            logger.warning("[Arize/Phoenix] Message data is None, skipping moderation trace.")
            return

-        metadata = {
-            "message_id": trace_info.message_id,
-            "tool_name": "moderation",
-            "status": trace_info.message_data.status,
-            "status_message": trace_info.message_data.error or "",
-            "level": "ERROR" if trace_info.message_data.error else "DEFAULT",
-        }
-        metadata.update(trace_info.metadata)
+        metadata = wrap_span_metadata(
+            trace_info.metadata,
+            trace_id=trace_info.trace_id or "",
+            message_id=trace_info.message_id or "",
+            status=trace_info.message_data.status or "",
+            status_message=trace_info.message_data.error or "",
+            level="ERROR" if trace_info.message_data.error else "DEFAULT",
+            trace_entity_type="moderation",
+            model_provider=trace_info.message_data.model_provider or "",
+            model_id=trace_info.message_data.model_id or "",
+        )

        dify_trace_id = trace_info.trace_id or trace_info.message_id
        self.ensure_root_span(dify_trace_id)

@ -467,18 +511,19 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
        span = self.tracer.start_span(
            name=TraceTaskName.MODERATION_TRACE.value,
            attributes={
-                SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
-                SpanAttributes.OUTPUT_VALUE: json.dumps(
+                SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.TOOL.value,
+                SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.inputs),
+                SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
+                SpanAttributes.OUTPUT_VALUE: safe_json_dumps(
                    {
-                        "action": trace_info.action,
                        "flagged": trace_info.flagged,
+                        "action": trace_info.action,
                        "preset_response": trace_info.preset_response,
-                        "inputs": trace_info.inputs,
-                    },
-                    ensure_ascii=False,
+                        "query": trace_info.query,
+                    }
                ),
-                SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
-                SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
+                SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
+                SpanAttributes.METADATA: safe_json_dumps(metadata),
            },
            start_time=datetime_to_nanos(trace_info.start_time),
            context=root_span_context,

@ -494,22 +539,28 @@ class ArizePhoenixDataTrace(BaseTraceInstance):

    def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
        if trace_info.message_data is None:
            logger.warning("[Arize/Phoenix] Message data is None, skipping suggested question trace.")
            return

        start_time = trace_info.start_time or trace_info.message_data.created_at
        end_time = trace_info.end_time or trace_info.message_data.updated_at

-        metadata = {
-            "message_id": trace_info.message_id,
-            "tool_name": "suggested_question",
-            "status": trace_info.status,
-            "status_message": trace_info.error or "",
-            "level": "ERROR" if trace_info.error else "DEFAULT",
-            "total_tokens": trace_info.total_tokens,
-            "ls_provider": trace_info.model_provider or "",
-            "ls_model_name": trace_info.model_id or "",
-        }
-        metadata.update(trace_info.metadata)
+        metadata = wrap_span_metadata(
+            trace_info.metadata,
+            trace_id=trace_info.trace_id or "",
+            message_id=trace_info.message_id or "",
+            status=trace_info.status or "",
+            status_message=trace_info.status_message or "",
+            level=trace_info.level or "",
+            trace_entity_type="suggested_question",
+            total_tokens=trace_info.total_tokens or 0,
+            from_account_id=trace_info.from_account_id or "",
+            agent_based=trace_info.agent_based or False,
+            from_source=trace_info.from_source or "",
+            model_provider=trace_info.model_provider or "",
+            model_id=trace_info.model_id or "",
+            workflow_run_id=trace_info.workflow_run_id or "",
+        )

        dify_trace_id = trace_info.trace_id or trace_info.message_id
        self.ensure_root_span(dify_trace_id)

@ -518,10 +569,12 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
        span = self.tracer.start_span(
            name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
            attributes={
-                SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
-                SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.suggested_question, ensure_ascii=False),
-                SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
-                SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
+                SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.TOOL.value,
+                SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.inputs),
+                SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
+                SpanAttributes.OUTPUT_VALUE: safe_json_dumps(trace_info.suggested_question),
+                SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
+                SpanAttributes.METADATA: safe_json_dumps(metadata),
            },
            start_time=datetime_to_nanos(start_time),
            context=root_span_context,

@ -537,21 +590,23 @@ class ArizePhoenixDataTrace(BaseTraceInstance):

    def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
        if trace_info.message_data is None:
            logger.warning("[Arize/Phoenix] Message data is None, skipping dataset retrieval trace.")
            return

        start_time = trace_info.start_time or trace_info.message_data.created_at
        end_time = trace_info.end_time or trace_info.message_data.updated_at

-        metadata = {
-            "message_id": trace_info.message_id,
-            "tool_name": "dataset_retrieval",
-            "status": trace_info.message_data.status,
-            "status_message": trace_info.message_data.error or "",
-            "level": "ERROR" if trace_info.message_data.error else "DEFAULT",
-            "ls_provider": trace_info.message_data.model_provider or "",
-            "ls_model_name": trace_info.message_data.model_id or "",
-        }
-        metadata.update(trace_info.metadata)
+        metadata = wrap_span_metadata(
+            trace_info.metadata,
+            trace_id=trace_info.trace_id or "",
+            message_id=trace_info.message_id or "",
+            status=trace_info.message_data.status or "",
+            status_message=trace_info.error or "",
+            level="ERROR" if trace_info.error else "DEFAULT",
+            trace_entity_type="dataset_retrieval",
+            model_provider=trace_info.message_data.model_provider or "",
+            model_id=trace_info.message_data.model_id or "",
+        )

        dify_trace_id = trace_info.trace_id or trace_info.message_id
        self.ensure_root_span(dify_trace_id)

@ -560,20 +615,20 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
        span = self.tracer.start_span(
            name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
            attributes={
-                SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
-                SpanAttributes.OUTPUT_VALUE: json.dumps({"documents": trace_info.documents}, ensure_ascii=False),
                SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.RETRIEVER.value,
-                SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
-                "start_time": start_time.isoformat() if start_time else "",
-                "end_time": end_time.isoformat() if end_time else "",
+                SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.inputs),
+                SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
+                SpanAttributes.OUTPUT_VALUE: safe_json_dumps({"documents": trace_info.documents}),
+                SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
+                SpanAttributes.METADATA: safe_json_dumps(metadata),
            },
            start_time=datetime_to_nanos(start_time),
            context=root_span_context,
        )

        try:
-            if trace_info.message_data.error:
-                set_span_status(span, trace_info.message_data.error)
+            if trace_info.error:
+                set_span_status(span, trace_info.error)
            else:
                set_span_status(span)
        finally:
@ -584,30 +639,34 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
|||
logger.warning("[Arize/Phoenix] Message data is None, skipping tool trace.")
|
||||
return
|
||||
|
||||
metadata = {
|
||||
"message_id": trace_info.message_id,
|
||||
"tool_config": json.dumps(trace_info.tool_config, ensure_ascii=False),
|
||||
}
|
||||
metadata = wrap_span_metadata(
|
||||
trace_info.metadata,
|
||||
trace_id=trace_info.trace_id or "",
|
||||
message_id=trace_info.message_id or "",
|
||||
status=trace_info.message_data.status or "",
|
||||
status_message=trace_info.error or "",
|
||||
level="ERROR" if trace_info.error else "DEFAULT",
|
||||
trace_entity_type="tool",
|
||||
tool_config=safe_json_dumps(trace_info.tool_config),
|
||||
time_cost=trace_info.time_cost or 0,
|
||||
file_url=trace_info.file_url or "",
|
||||
)
|
||||
|
||||
dify_trace_id = trace_info.trace_id or trace_info.message_id
|
||||
self.ensure_root_span(dify_trace_id)
|
||||
root_span_context = self.propagator.extract(carrier=self.carrier)
|
||||
|
||||
tool_params_str = (
|
||||
json.dumps(trace_info.tool_parameters, ensure_ascii=False)
|
||||
if isinstance(trace_info.tool_parameters, dict)
|
||||
else str(trace_info.tool_parameters)
|
||||
)
|
||||
|
||||
span = self.tracer.start_span(
|
||||
name=trace_info.tool_name,
|
||||
attributes={
|
||||
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.tool_inputs, ensure_ascii=False),
|
||||
SpanAttributes.OUTPUT_VALUE: trace_info.tool_outputs,
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.TOOL.value,
|
||||
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
|
||||
SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.tool_inputs),
|
||||
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
|
||||
SpanAttributes.OUTPUT_VALUE: trace_info.tool_outputs,
|
||||
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.TEXT.value,
|
||||
SpanAttributes.METADATA: safe_json_dumps(metadata),
|
||||
SpanAttributes.TOOL_NAME: trace_info.tool_name,
|
||||
SpanAttributes.TOOL_PARAMETERS: tool_params_str,
|
||||
SpanAttributes.TOOL_PARAMETERS: safe_json_dumps(trace_info.tool_parameters),
|
||||
},
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=root_span_context,
|
||||
|
|
@@ -623,16 +682,22 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
|||
|
||||
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
logger.warning("[Arize/Phoenix] Message data is None, skipping generate name trace.")
|
||||
return
|
||||
|
||||
metadata = {
|
||||
"project_name": self.project,
|
||||
"message_id": trace_info.message_id,
|
||||
"status": trace_info.message_data.status,
|
||||
"status_message": trace_info.message_data.error or "",
|
||||
"level": "ERROR" if trace_info.message_data.error else "DEFAULT",
|
||||
}
|
||||
metadata.update(trace_info.metadata)
|
||||
metadata = wrap_span_metadata(
|
||||
trace_info.metadata,
|
||||
trace_id=trace_info.trace_id or "",
|
||||
message_id=trace_info.message_id or "",
|
||||
status=trace_info.message_data.status or "",
|
||||
status_message=trace_info.message_data.error or "",
|
||||
level="ERROR" if trace_info.message_data.error else "DEFAULT",
|
||||
trace_entity_type="generate_name",
|
||||
model_provider=trace_info.message_data.model_provider or "",
|
||||
model_id=trace_info.message_data.model_id or "",
|
||||
conversation_id=trace_info.conversation_id or "",
|
||||
tenant_id=trace_info.tenant_id,
|
||||
)
|
||||
|
||||
dify_trace_id = trace_info.trace_id or trace_info.message_id or trace_info.conversation_id
|
||||
self.ensure_root_span(dify_trace_id)
|
||||
|
|
@@ -641,13 +706,13 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
|||
span = self.tracer.start_span(
|
||||
name=TraceTaskName.GENERATE_NAME_TRACE.value,
|
||||
attributes={
|
||||
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.outputs, ensure_ascii=False),
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
|
||||
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
|
||||
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
|
||||
"start_time": trace_info.start_time.isoformat() if trace_info.start_time else "",
|
||||
"end_time": trace_info.end_time.isoformat() if trace_info.end_time else "",
|
||||
SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.inputs),
|
||||
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
|
||||
SpanAttributes.OUTPUT_VALUE: safe_json_dumps(trace_info.outputs),
|
||||
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
|
||||
SpanAttributes.METADATA: safe_json_dumps(metadata),
|
||||
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
|
||||
},
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=root_span_context,
|
||||
|
|
@@ -688,32 +753,85 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
|||
raise ValueError(f"[Arize/Phoenix] API check failed: {str(e)}")
|
||||
|
||||
def get_project_url(self):
|
||||
"""Build a redirect URL that forwards the user to the correct project for Arize/Phoenix."""
|
||||
try:
|
||||
if self.arize_phoenix_config.endpoint == "https://otlp.arize.com":
|
||||
return "https://app.arize.com/"
|
||||
else:
|
||||
return f"{self.arize_phoenix_config.endpoint}/projects/"
|
||||
project_name = self.arize_phoenix_config.project
|
||||
endpoint = self.arize_phoenix_config.endpoint.rstrip("/")
|
||||
|
||||
# Arize
|
||||
if isinstance(self.arize_phoenix_config, ArizeConfig):
|
||||
return f"https://app.arize.com/?redirect_project_name={project_name}"
|
||||
|
||||
# Phoenix
|
||||
return f"{endpoint}/projects/?redirect_project_name={project_name}"
|
||||
|
||||
except Exception as e:
|
||||
logger.info("[Arize/Phoenix] Get run url failed: %s", str(e), exc_info=True)
|
||||
raise ValueError(f"[Arize/Phoenix] Get run url failed: {str(e)}")
|
||||
logger.info("[Arize/Phoenix] Failed to construct project URL: %s", str(e), exc_info=True)
|
||||
raise ValueError(f"[Arize/Phoenix] Failed to construct project URL: {str(e)}")
|
||||
|
||||
def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]:
|
||||
"""Helper method to construct LLM attributes with passed prompts."""
|
||||
attributes = {}
|
||||
"""Construct LLM attributes with passed prompts for Arize/Phoenix."""
|
||||
attributes: dict[str, str] = {}
|
||||
|
||||
def set_attribute(path: str, value: object) -> None:
|
||||
"""Store an attribute safely as a string."""
|
||||
if value is None:
|
||||
return
|
||||
try:
|
||||
if isinstance(value, (dict, list)):
|
||||
value = safe_json_dumps(value)
|
||||
attributes[path] = str(value)
|
||||
except Exception:
|
||||
attributes[path] = str(value)
|
||||
|
||||
def set_message_attribute(message_index: int, key: str, value: object) -> None:
|
||||
path = f"{SpanAttributes.LLM_INPUT_MESSAGES}.{message_index}.{key}"
|
||||
set_attribute(path, value)
|
||||
|
||||
def set_tool_call_attributes(message_index: int, tool_index: int, tool_call: dict | object | None) -> None:
|
||||
"""Extract and assign tool call details safely."""
|
||||
if not tool_call:
|
||||
return
|
||||
|
||||
def safe_get(obj, key, default=None):
|
||||
if isinstance(obj, dict):
|
||||
return obj.get(key, default)
|
||||
return getattr(obj, key, default)
|
||||
|
||||
function_obj = safe_get(tool_call, "function", {})
|
||||
function_name = safe_get(function_obj, "name", "")
|
||||
function_args = safe_get(function_obj, "arguments", {})
|
||||
call_id = safe_get(tool_call, "id", "")
|
||||
|
||||
base_path = (
|
||||
f"{SpanAttributes.LLM_INPUT_MESSAGES}."
|
||||
f"{message_index}.{MessageAttributes.MESSAGE_TOOL_CALLS}.{tool_index}"
|
||||
)
|
||||
|
||||
set_attribute(f"{base_path}.{ToolCallAttributes.TOOL_CALL_FUNCTION_NAME}", function_name)
|
||||
set_attribute(f"{base_path}.{ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON}", function_args)
|
||||
set_attribute(f"{base_path}.{ToolCallAttributes.TOOL_CALL_ID}", call_id)
|
||||
|
||||
# Handle list of messages
|
||||
if isinstance(prompts, list):
|
||||
for i, msg in enumerate(prompts):
|
||||
if isinstance(msg, dict):
|
||||
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.content"] = msg.get("text", "")
|
||||
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.role"] = msg.get("role", "user")
|
||||
# todo: handle assistant and tool role messages, as they don't always
|
||||
# have a text field, but may have a tool_calls field instead
|
||||
# e.g. 'tool_calls': [{'id': '98af3a29-b066-45a5-b4b1-46c74ddafc58',
|
||||
# 'type': 'function', 'function': {'name': 'current_time', 'arguments': '{}'}}]}
|
||||
elif isinstance(prompts, dict):
|
||||
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = json.dumps(prompts)
|
||||
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
|
||||
elif isinstance(prompts, str):
|
||||
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = prompts
|
||||
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
|
||||
for message_index, message in enumerate(prompts):
|
||||
if not isinstance(message, dict):
|
||||
continue
|
||||
|
||||
role = message.get("role", "user")
|
||||
content = message.get("text") or message.get("content") or ""
|
||||
|
||||
set_message_attribute(message_index, MessageAttributes.MESSAGE_ROLE, role)
|
||||
set_message_attribute(message_index, MessageAttributes.MESSAGE_CONTENT, content)
|
||||
|
||||
tool_calls = message.get("tool_calls") or []
|
||||
if isinstance(tool_calls, list):
|
||||
for tool_index, tool_call in enumerate(tool_calls):
|
||||
set_tool_call_attributes(message_index, tool_index, tool_call)
|
||||
|
||||
# Handle single dict or plain string prompt
|
||||
elif isinstance(prompts, (dict, str)):
|
||||
set_message_attribute(0, MessageAttributes.MESSAGE_CONTENT, prompts)
|
||||
set_message_attribute(0, MessageAttributes.MESSAGE_ROLE, "user")
|
||||
|
||||
return attributes
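For reference, a minimal standalone sketch of the message-flattening scheme used by `_construct_llm_attributes` above; the attribute paths are assumed to follow the OpenInference convention `llm.input_messages.<i>.message.<field>` (the real constants come from `openinference.semconv.trace`):

```python
import json

# Assumed OpenInference-style prefix; the diff uses SpanAttributes.LLM_INPUT_MESSAGES.
LLM_INPUT_MESSAGES = "llm.input_messages"

def flatten_messages(prompts) -> dict[str, str]:
    """Flatten chat prompts into span attributes, mirroring the logic above."""
    attributes: dict[str, str] = {}
    if isinstance(prompts, list):
        for i, msg in enumerate(prompts):
            if not isinstance(msg, dict):
                continue
            role = msg.get("role", "user")
            content = msg.get("text") or msg.get("content") or ""
            attributes[f"{LLM_INPUT_MESSAGES}.{i}.message.role"] = str(role)
            attributes[f"{LLM_INPUT_MESSAGES}.{i}.message.content"] = str(content)
    elif isinstance(prompts, (dict, str)):
        value = json.dumps(prompts) if isinstance(prompts, dict) else prompts
        attributes[f"{LLM_INPUT_MESSAGES}.0.message.content"] = value
        attributes[f"{LLM_INPUT_MESSAGES}.0.message.role"] = "user"
    return attributes

print(flatten_messages([{"role": "system", "content": "You are helpful."},
                        {"role": "user", "text": "What time is it?"}]))
# {'llm.input_messages.0.message.role': 'system', ...}
```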
|
||||
@@ -222,59 +222,6 @@ class TencentSpanBuilder:
|
|||
links=links,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def build_message_llm_span(
|
||||
trace_info: MessageTraceInfo, trace_id: int, parent_span_id: int, user_id: str
|
||||
) -> SpanData:
|
||||
"""Build LLM span for message traces with detailed LLM attributes."""
|
||||
status = Status(StatusCode.OK)
|
||||
if trace_info.error:
|
||||
status = Status(StatusCode.ERROR, trace_info.error)
|
||||
|
||||
# Extract model information from `metadata` or `message_data`
|
||||
trace_metadata = trace_info.metadata or {}
|
||||
message_data = trace_info.message_data or {}
|
||||
|
||||
model_provider = trace_metadata.get("ls_provider") or (
|
||||
message_data.get("model_provider", "") if isinstance(message_data, dict) else ""
|
||||
)
|
||||
model_name = trace_metadata.get("ls_model_name") or (
|
||||
message_data.get("model_id", "") if isinstance(message_data, dict) else ""
|
||||
)
|
||||
|
||||
inputs_str = str(trace_info.inputs or "")
|
||||
outputs_str = str(trace_info.outputs or "")
|
||||
|
||||
attributes = {
|
||||
GEN_AI_SESSION_ID: trace_metadata.get("conversation_id", ""),
|
||||
GEN_AI_USER_ID: str(user_id),
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.GENERATION.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
GEN_AI_MODEL_NAME: str(model_name),
|
||||
GEN_AI_PROVIDER: str(model_provider),
|
||||
GEN_AI_USAGE_INPUT_TOKENS: str(trace_info.message_tokens or 0),
|
||||
GEN_AI_USAGE_OUTPUT_TOKENS: str(trace_info.answer_tokens or 0),
|
||||
GEN_AI_USAGE_TOTAL_TOKENS: str(trace_info.total_tokens or 0),
|
||||
GEN_AI_PROMPT: inputs_str,
|
||||
GEN_AI_COMPLETION: outputs_str,
|
||||
INPUT_VALUE: inputs_str,
|
||||
OUTPUT_VALUE: outputs_str,
|
||||
}
|
||||
|
||||
if trace_info.is_streaming_request:
|
||||
attributes[GEN_AI_IS_STREAMING_REQUEST] = "true"
|
||||
|
||||
return SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=parent_span_id,
|
||||
span_id=TencentTraceUtils.convert_to_span_id(trace_info.message_id, "llm"),
|
||||
name="GENERATION",
|
||||
start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
|
||||
end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
|
||||
attributes=attributes,
|
||||
status=status,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def build_tool_span(trace_info: ToolTraceInfo, trace_id: int, parent_span_id: int) -> SpanData:
|
||||
"""Build tool span."""
|
||||
@@ -107,12 +107,8 @@ class TencentDataTrace(BaseTraceInstance):
|
|||
links.append(TencentTraceUtils.create_link(trace_info.trace_id))
|
||||
|
||||
message_span = TencentSpanBuilder.build_message_span(trace_info, trace_id, str(user_id), links)
|
||||
self.trace_client.add_span(message_span)
|
||||
|
||||
# Add LLM child span with detailed attributes
|
||||
parent_span_id = TencentTraceUtils.convert_to_span_id(trace_info.message_id, "message")
|
||||
llm_span = TencentSpanBuilder.build_message_llm_span(trace_info, trace_id, parent_span_id, str(user_id))
|
||||
self.trace_client.add_span(llm_span)
|
||||
self.trace_client.add_span(message_span)
|
||||
|
||||
self._record_message_llm_metrics(trace_info)
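The reordered block above builds the LLM child span before adding the message span, linking the two through span ids derived from the same `message_id`. A hypothetical sketch of that linkage, assuming `convert_to_span_id` hashes the id plus a suffix into a stable 64-bit integer (the real `TencentTraceUtils` implementation may differ):

```python
import hashlib

def convert_to_span_id(message_id: str, suffix: str) -> int:
    """Hypothetical stand-in: derive a stable 64-bit span id from a message id.

    Parent ("message") and child ("llm") ids come from the same message id,
    so the child span can reference its parent deterministically.
    """
    digest = hashlib.sha256(f"{message_id}-{suffix}".encode()).digest()
    return int.from_bytes(digest[:8], "big") & ((1 << 63) - 1)

parent_span_id = convert_to_span_id("msg-123", "message")
llm_span_id = convert_to_span_id("msg-123", "llm")
assert parent_span_id != llm_span_id  # distinct but reproducible ids
```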
|
||||
|
||||
@@ -371,7 +371,7 @@ class RetrievalService:
|
|||
include_segment_ids = set()
|
||||
segment_child_map = {}
|
||||
segment_file_map = {}
|
||||
with Session(db.engine) as session:
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
# Process documents
|
||||
for document in documents:
|
||||
segment_id = None
|
||||
|
|
@@ -395,7 +395,7 @@ class RetrievalService:
|
|||
session,
|
||||
)
|
||||
if attachment_info_dict:
|
||||
attachment_info = attachment_info_dict["attchment_info"]
|
||||
attachment_info = attachment_info_dict["attachment_info"]
|
||||
segment_id = attachment_info_dict["segment_id"]
|
||||
else:
|
||||
child_index_node_id = document.metadata.get("doc_id")
|
||||
|
|
@@ -417,13 +417,6 @@ class RetrievalService:
|
|||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.id == segment_id,
|
||||
)
|
||||
.options(
|
||||
load_only(
|
||||
DocumentSegment.id,
|
||||
DocumentSegment.content,
|
||||
DocumentSegment.answer,
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
|
|
@@ -458,12 +451,21 @@ class RetrievalService:
|
|||
"position": child_chunk.position,
|
||||
"score": document.metadata.get("score", 0.0),
|
||||
}
|
||||
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
|
||||
segment_child_map[segment.id]["max_score"] = max(
|
||||
segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
|
||||
)
|
||||
if segment.id in segment_child_map:
|
||||
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
|
||||
segment_child_map[segment.id]["max_score"] = max(
|
||||
segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
|
||||
)
|
||||
else:
|
||||
segment_child_map[segment.id] = {
|
||||
"max_score": document.metadata.get("score", 0.0),
|
||||
"child_chunks": [child_chunk_detail],
|
||||
}
|
||||
if attachment_info:
|
||||
segment_file_map[segment.id].append(attachment_info)
|
||||
if segment.id in segment_file_map:
|
||||
segment_file_map[segment.id].append(attachment_info)
|
||||
else:
|
||||
segment_file_map[segment.id] = [attachment_info]
|
||||
else:
|
||||
# Handle normal documents
|
||||
segment = None
|
||||
|
|
@@ -475,7 +477,7 @@ class RetrievalService:
|
|||
session,
|
||||
)
|
||||
if attachment_info_dict:
|
||||
attachment_info = attachment_info_dict["attchment_info"]
|
||||
attachment_info = attachment_info_dict["attachment_info"]
|
||||
segment_id = attachment_info_dict["segment_id"]
|
||||
document_segment_stmt = select(DocumentSegment).where(
|
||||
DocumentSegment.dataset_id == dataset_document.dataset_id,
|
||||
|
|
@@ -483,7 +485,7 @@ class RetrievalService:
|
|||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.id == segment_id,
|
||||
)
|
||||
segment = db.session.scalar(document_segment_stmt)
|
||||
segment = session.scalar(document_segment_stmt)
|
||||
if segment:
|
||||
segment_file_map[segment.id] = [attachment_info]
|
||||
else:
|
||||
|
|
@@ -496,7 +498,7 @@ class RetrievalService:
|
|||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.index_node_id == index_node_id,
|
||||
)
|
||||
segment = db.session.scalar(document_segment_stmt)
|
||||
segment = session.scalar(document_segment_stmt)
|
||||
|
||||
if not segment:
|
||||
continue
|
||||
|
|
@@ -684,7 +686,7 @@ class RetrievalService:
|
|||
.first()
|
||||
)
|
||||
if attachment_binding:
|
||||
attchment_info = {
|
||||
attachment_info = {
|
||||
"id": upload_file.id,
|
||||
"name": upload_file.name,
|
||||
"extension": "." + upload_file.extension,
|
||||
|
|
@@ -692,5 +694,5 @@ class RetrievalService:
|
|||
"source_url": sign_upload_file(upload_file.id, upload_file.extension),
|
||||
"size": upload_file.size,
|
||||
}
|
||||
return {"attchment_info": attchment_info, "segment_id": attachment_binding.segment_id}
|
||||
return {"attachment_info": attachment_info, "segment_id": attachment_binding.segment_id}
|
||||
return None
|
||||
@@ -0,0 +1,407 @@
|
|||
"""InterSystems IRIS vector database implementation for Dify.
|
||||
|
||||
This module provides vector storage and retrieval using IRIS native VECTOR type
|
||||
with HNSW indexing for efficient similarity search.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from configs import dify_config
|
||||
from configs.middleware.vdb.iris_config import IrisVectorConfig
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.embedding.embedding_base import Embeddings
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import iris
|
||||
else:
|
||||
try:
|
||||
import iris
|
||||
except ImportError:
|
||||
iris = None # type: ignore[assignment]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Singleton connection pool to minimize IRIS license usage
|
||||
_pool_lock = threading.Lock()
|
||||
_pool_instance: IrisConnectionPool | None = None
|
||||
|
||||
|
||||
def get_iris_pool(config: IrisVectorConfig) -> IrisConnectionPool:
|
||||
"""Get or create the global IRIS connection pool (singleton pattern)."""
|
||||
global _pool_instance # pylint: disable=global-statement
|
||||
with _pool_lock:
|
||||
if _pool_instance is None:
|
||||
logger.info("Initializing IRIS connection pool")
|
||||
_pool_instance = IrisConnectionPool(config)
|
||||
return _pool_instance
|
||||
|
||||
|
||||
class IrisConnectionPool:
|
||||
"""Thread-safe connection pool for IRIS database."""
|
||||
|
||||
def __init__(self, config: IrisVectorConfig) -> None:
|
||||
self.config = config
|
||||
self._pool: list[Any] = []
|
||||
self._lock = threading.Lock()
|
||||
self._min_size = config.IRIS_MIN_CONNECTION
|
||||
self._max_size = config.IRIS_MAX_CONNECTION
|
||||
self._in_use = 0
|
||||
self._schemas_initialized: set[str] = set() # Cache for initialized schemas
|
||||
self._initialize_pool()
|
||||
|
||||
def _initialize_pool(self) -> None:
|
||||
for _ in range(self._min_size):
|
||||
self._pool.append(self._create_connection())
|
||||
|
||||
def _create_connection(self) -> Any:
|
||||
return iris.connect(
|
||||
hostname=self.config.IRIS_HOST,
|
||||
port=self.config.IRIS_SUPER_SERVER_PORT,
|
||||
namespace=self.config.IRIS_DATABASE,
|
||||
username=self.config.IRIS_USER,
|
||||
password=self.config.IRIS_PASSWORD,
|
||||
)
|
||||
|
||||
def get_connection(self) -> Any:
|
||||
"""Get a connection from pool or create new if available."""
|
||||
with self._lock:
|
||||
if self._pool:
|
||||
conn = self._pool.pop()
|
||||
self._in_use += 1
|
||||
return conn
|
||||
if self._in_use < self._max_size:
|
||||
conn = self._create_connection()
|
||||
self._in_use += 1
|
||||
return conn
|
||||
raise RuntimeError("Connection pool exhausted")
|
||||
|
||||
def return_connection(self, conn: Any) -> None:
|
||||
"""Return connection to pool after validating it."""
|
||||
if not conn:
|
||||
return
|
||||
|
||||
# Validate connection health
|
||||
is_valid = False
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT 1")
|
||||
cursor.close()
|
||||
is_valid = True
|
||||
except (OSError, RuntimeError) as e:
|
||||
logger.debug("Connection validation failed: %s", e)
|
||||
try:
|
||||
conn.close()
|
||||
except (OSError, RuntimeError):
|
||||
pass
|
||||
|
||||
with self._lock:
|
||||
self._pool.append(conn if is_valid else self._create_connection())
|
||||
self._in_use -= 1
|
||||
|
||||
def ensure_schema_exists(self, schema: str) -> None:
|
||||
"""Ensure schema exists in IRIS database.
|
||||
|
||||
This method is idempotent and thread-safe. It uses a memory cache to avoid
|
||||
redundant database queries for already-verified schemas.
|
||||
|
||||
Args:
|
||||
schema: Schema name to ensure exists
|
||||
|
||||
Raises:
|
||||
Exception: If schema creation fails
|
||||
"""
|
||||
# Fast path: check cache first (no lock needed for read-only set lookup)
|
||||
if schema in self._schemas_initialized:
|
||||
return
|
||||
|
||||
# Slow path: acquire lock and check again (double-checked locking)
|
||||
with self._lock:
|
||||
if schema in self._schemas_initialized:
|
||||
return
|
||||
|
||||
# Get a connection to check/create schema
|
||||
conn = self._pool[0] if self._pool else self._create_connection()
|
||||
cursor = conn.cursor()
|
||||
try:
|
||||
# Check if schema exists using INFORMATION_SCHEMA
|
||||
check_sql = """
|
||||
SELECT COUNT(*) FROM INFORMATION_SCHEMA.SCHEMATA
|
||||
WHERE SCHEMA_NAME = ?
|
||||
"""
|
||||
cursor.execute(check_sql, (schema,)) # Must be tuple or list
|
||||
exists = cursor.fetchone()[0] > 0
|
||||
|
||||
if not exists:
|
||||
# Schema doesn't exist, create it
|
||||
cursor.execute(f"CREATE SCHEMA {schema}")
|
||||
conn.commit()
|
||||
logger.info("Created schema: %s", schema)
|
||||
else:
|
||||
logger.debug("Schema already exists: %s", schema)
|
||||
|
||||
# Add to cache to skip future checks
|
||||
self._schemas_initialized.add(schema)
|
||||
|
||||
except Exception as e:
|
||||
conn.rollback()
|
||||
logger.exception("Failed to ensure schema %s exists", schema)
|
||||
raise
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
def close_all(self) -> None:
|
||||
"""Close all connections (application shutdown only)."""
|
||||
with self._lock:
|
||||
for conn in self._pool:
|
||||
try:
|
||||
conn.close()
|
||||
except (OSError, RuntimeError):
|
||||
pass
|
||||
self._pool.clear()
|
||||
self._in_use = 0
|
||||
self._schemas_initialized.clear()
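A usage sketch of the pool above, assuming a populated `IrisVectorConfig` instance named `config`; `IrisVector._get_cursor` below wraps exactly this get/use/return cycle:

```python
# Hypothetical usage; `config` is assumed to be a populated IrisVectorConfig.
pool = get_iris_pool(config)
conn = pool.get_connection()
try:
    cursor = conn.cursor()
    cursor.execute("SELECT 1")  # any work against the connection
    print(cursor.fetchone())
    cursor.close()
finally:
    # Always hand the connection back; the pool validates it with SELECT 1
    # and replaces it with a fresh connection if it has gone stale.
    pool.return_connection(conn)
```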
|
||||
|
||||
|
||||
class IrisVector(BaseVector):
|
||||
"""IRIS vector database implementation using native VECTOR type and HNSW indexing."""
|
||||
|
||||
def __init__(self, collection_name: str, config: IrisVectorConfig) -> None:
|
||||
super().__init__(collection_name)
|
||||
self.config = config
|
||||
self.table_name = f"embedding_{collection_name}".upper()
|
||||
self.schema = config.IRIS_SCHEMA or "dify"
|
||||
self.pool = get_iris_pool(config)
|
||||
|
||||
def get_type(self) -> str:
|
||||
return VectorType.IRIS
|
||||
|
||||
@contextmanager
|
||||
def _get_cursor(self):
|
||||
"""Context manager for database cursor with connection pooling."""
|
||||
conn = self.pool.get_connection()
|
||||
cursor = conn.cursor()
|
||||
try:
|
||||
yield cursor
|
||||
conn.commit()
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
cursor.close()
|
||||
self.pool.return_connection(conn)
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]:
|
||||
dimension = len(embeddings[0])
|
||||
self._create_collection(dimension)
|
||||
return self.add_texts(texts, embeddings)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **_kwargs) -> list[str]:
|
||||
"""Add documents with embeddings to the collection."""
|
||||
added_ids = []
|
||||
with self._get_cursor() as cursor:
|
||||
for i, doc in enumerate(documents):
|
||||
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) if doc.metadata else str(uuid.uuid4())
|
||||
metadata = json.dumps(doc.metadata) if doc.metadata else "{}"
|
||||
embedding_str = json.dumps(embeddings[i])
|
||||
|
||||
sql = f"INSERT INTO {self.schema}.{self.table_name} (id, text, meta, embedding) VALUES (?, ?, ?, ?)"
|
||||
cursor.execute(sql, (doc_id, doc.page_content, metadata, embedding_str))
|
||||
added_ids.append(doc_id)
|
||||
|
||||
return added_ids
|
||||
|
||||
def text_exists(self, id: str) -> bool: # pylint: disable=redefined-builtin
|
||||
try:
|
||||
with self._get_cursor() as cursor:
|
||||
sql = f"SELECT 1 FROM {self.schema}.{self.table_name} WHERE id = ?"
|
||||
cursor.execute(sql, (id,))
|
||||
return cursor.fetchone() is not None
|
||||
except (OSError, RuntimeError, ValueError):
|
||||
return False
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
if not ids:
|
||||
return
|
||||
|
||||
with self._get_cursor() as cursor:
|
||||
placeholders = ",".join(["?" for _ in ids])
|
||||
sql = f"DELETE FROM {self.schema}.{self.table_name} WHERE id IN ({placeholders})"
|
||||
cursor.execute(sql, ids)
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||
"""Delete documents by metadata field (JSON LIKE pattern matching)."""
|
||||
with self._get_cursor() as cursor:
|
||||
pattern = f'%"{key}": "{value}"%'
|
||||
sql = f"DELETE FROM {self.schema}.{self.table_name} WHERE meta LIKE ?"
|
||||
cursor.execute(sql, (pattern,))
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
"""Search similar documents using VECTOR_COSINE with HNSW index."""
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
embedding_str = json.dumps(query_vector)
|
||||
|
||||
with self._get_cursor() as cursor:
|
||||
sql = f"""
|
||||
SELECT TOP {top_k} id, text, meta, VECTOR_COSINE(embedding, ?) as score
|
||||
FROM {self.schema}.{self.table_name}
|
||||
ORDER BY score DESC
|
||||
"""
|
||||
cursor.execute(sql, (embedding_str,))
|
||||
|
||||
docs = []
|
||||
for row in cursor.fetchall():
|
||||
if len(row) >= 4:
|
||||
text, meta_str, score = row[1], row[2], float(row[3])
|
||||
if score >= score_threshold:
|
||||
metadata = json.loads(meta_str) if meta_str else {}
|
||||
metadata["score"] = score
|
||||
docs.append(Document(page_content=text, metadata=metadata))
|
||||
return docs
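For illustration, with the default `top_k=4` the statement sent to IRIS renders roughly as follows; the query embedding travels as a JSON-string bind parameter rather than being interpolated into the SQL (schema and table names here are hypothetical):

```python
# Rendered form of the query built above, assuming schema "dify" and a
# collection table EMBEDDING_FOO.
sql = """
SELECT TOP 4 id, text, meta, VECTOR_COSINE(embedding, ?) as score
FROM dify.EMBEDDING_FOO
ORDER BY score DESC
"""
params = ("[0.12, -0.07, 0.91]",)  # json.dumps(query_vector)
```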
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
"""Search documents by full-text using iFind index or fallback to LIKE search."""
|
||||
top_k = kwargs.get("top_k", 5)
|
||||
|
||||
with self._get_cursor() as cursor:
|
||||
if self.config.IRIS_TEXT_INDEX:
|
||||
# Use iFind full-text search with index
|
||||
text_index_name = f"idx_{self.table_name}_text"
|
||||
sql = f"""
|
||||
SELECT TOP {top_k} id, text, meta
|
||||
FROM {self.schema}.{self.table_name}
|
||||
WHERE %ID %FIND search_index({text_index_name}, ?)
|
||||
"""
|
||||
cursor.execute(sql, (query,))
|
||||
else:
|
||||
# Fallback to LIKE search (inefficient for large datasets)
|
||||
query_pattern = f"%{query}%"
|
||||
sql = f"""
|
||||
SELECT TOP {top_k} id, text, meta
|
||||
FROM {self.schema}.{self.table_name}
|
||||
WHERE text LIKE ?
|
||||
"""
|
||||
cursor.execute(sql, (query_pattern,))
|
||||
|
||||
docs = []
|
||||
for row in cursor.fetchall():
|
||||
if len(row) >= 3:
|
||||
metadata = json.loads(row[2]) if row[2] else {}
|
||||
docs.append(Document(page_content=row[1], metadata=metadata))
|
||||
|
||||
if not docs:
|
||||
logger.info("Full-text search for '%s' returned no results", query)
|
||||
|
||||
return docs
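Only the WHERE clause differs between the two branches above; roughly, again with a hypothetical table name:

```python
# iFind-backed search (IRIS_TEXT_INDEX enabled); search_index() consults the
# %iFind index created in _create_collection below.
ifind_sql = """
SELECT TOP 5 id, text, meta
FROM dify.EMBEDDING_FOO
WHERE %ID %FIND search_index(idx_EMBEDDING_FOO_text, ?)
"""

# LIKE fallback: a full scan, so acceptable only for small collections.
like_sql = """
SELECT TOP 5 id, text, meta
FROM dify.EMBEDDING_FOO
WHERE text LIKE ?
"""
```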
|
||||
|
||||
def delete(self) -> None:
|
||||
"""Delete the entire collection (drop table - permanent)."""
|
||||
with self._get_cursor() as cursor:
|
||||
sql = f"DROP TABLE {self.schema}.{self.table_name}"
|
||||
cursor.execute(sql)
|
||||
|
||||
def _create_collection(self, dimension: int) -> None:
|
||||
"""Create table with VECTOR column and HNSW index.
|
||||
|
||||
Uses Redis lock to prevent concurrent creation attempts across multiple
|
||||
API server instances (api, worker, worker_beat).
|
||||
"""
|
||||
cache_key = f"vector_indexing_{self._collection_name}"
|
||||
lock_name = f"{cache_key}_lock"
|
||||
|
||||
with redis_client.lock(lock_name, timeout=20): # pylint: disable=not-context-manager
|
||||
if redis_client.get(cache_key):
|
||||
return
|
||||
|
||||
# Ensure schema exists (idempotent, cached after first call)
|
||||
self.pool.ensure_schema_exists(self.schema)
|
||||
|
||||
with self._get_cursor() as cursor:
|
||||
# Create table with VECTOR column
|
||||
sql = f"""
|
||||
CREATE TABLE {self.schema}.{self.table_name} (
|
||||
id VARCHAR(255) PRIMARY KEY,
|
||||
text CLOB,
|
||||
meta CLOB,
|
||||
embedding VECTOR(DOUBLE, {dimension})
|
||||
)
|
||||
"""
|
||||
logger.info("Creating table: %s.%s", self.schema, self.table_name)
|
||||
cursor.execute(sql)
|
||||
|
||||
# Create HNSW index for vector similarity search
|
||||
index_name = f"idx_{self.table_name}_embedding"
|
||||
sql_index = (
|
||||
f"CREATE INDEX {index_name} ON {self.schema}.{self.table_name} "
|
||||
"(embedding) AS HNSW(Distance='Cosine')"
|
||||
)
|
||||
logger.info("Creating HNSW index: %s", index_name)
|
||||
cursor.execute(sql_index)
|
||||
logger.info("HNSW index created successfully: %s", index_name)
|
||||
|
||||
# Create full-text search index if enabled
|
||||
logger.info(
|
||||
"IRIS_TEXT_INDEX config value: %s (type: %s)",
|
||||
self.config.IRIS_TEXT_INDEX,
|
||||
type(self.config.IRIS_TEXT_INDEX),
|
||||
)
|
||||
if self.config.IRIS_TEXT_INDEX:
|
||||
text_index_name = f"idx_{self.table_name}_text"
|
||||
language = self.config.IRIS_TEXT_INDEX_LANGUAGE
|
||||
# Fixed: Removed extra parentheses and corrected syntax
|
||||
sql_text_index = f"""
|
||||
CREATE INDEX {text_index_name} ON {self.schema}.{self.table_name} (text)
|
||||
AS %iFind.Index.Basic
|
||||
(LANGUAGE = '{language}', LOWER = 1, INDEXOPTION = 0)
|
||||
"""
|
||||
logger.info("Creating text index: %s with language: %s", text_index_name, language)
|
||||
logger.info("SQL for text index: %s", sql_text_index)
|
||||
cursor.execute(sql_text_index)
|
||||
logger.info("Text index created successfully: %s", text_index_name)
|
||||
else:
|
||||
logger.warning("Text index creation skipped - IRIS_TEXT_INDEX is disabled")
|
||||
|
||||
redis_client.set(cache_key, 1, ex=3600)
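The method above follows a lock-then-recheck guard so that only one worker runs the DDL. A minimal sketch of the pattern in isolation, with hypothetical `redis_client` and `create_table` arguments:

```python
def create_once(redis_client, collection_name: str, create_table) -> None:
    """Run create_table() at most once per collection across processes.

    Hypothetical helper mirroring _create_collection above: the Redis lock
    serializes concurrent creators, and the cache key records completion so
    later callers return immediately for up to an hour.
    """
    cache_key = f"vector_indexing_{collection_name}"
    with redis_client.lock(f"{cache_key}_lock", timeout=20):
        if redis_client.get(cache_key):
            return  # another worker already created the table
        create_table()
        redis_client.set(cache_key, 1, ex=3600)
```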
|
||||
|
||||
|
||||
class IrisVectorFactory(AbstractVectorFactory):
|
||||
"""Factory for creating IrisVector instances."""
|
||||
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> IrisVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
index_struct_dict = self.gen_index_struct_dict(VectorType.IRIS, collection_name)
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
|
||||
return IrisVector(
|
||||
collection_name=collection_name,
|
||||
config=IrisVectorConfig(
|
||||
IRIS_HOST=dify_config.IRIS_HOST,
|
||||
IRIS_SUPER_SERVER_PORT=dify_config.IRIS_SUPER_SERVER_PORT,
|
||||
IRIS_USER=dify_config.IRIS_USER,
|
||||
IRIS_PASSWORD=dify_config.IRIS_PASSWORD,
|
||||
IRIS_DATABASE=dify_config.IRIS_DATABASE,
|
||||
IRIS_SCHEMA=dify_config.IRIS_SCHEMA,
|
||||
IRIS_CONNECTION_URL=dify_config.IRIS_CONNECTION_URL,
|
||||
IRIS_MIN_CONNECTION=dify_config.IRIS_MIN_CONNECTION,
|
||||
IRIS_MAX_CONNECTION=dify_config.IRIS_MAX_CONNECTION,
|
||||
IRIS_TEXT_INDEX=dify_config.IRIS_TEXT_INDEX,
|
||||
IRIS_TEXT_INDEX_LANGUAGE=dify_config.IRIS_TEXT_INDEX_LANGUAGE,
|
||||
),
|
||||
)
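The factory persists an index struct on the dataset so later loads resolve the same collection via `class_prefix`. Based on the read path above, the stored JSON is assumed to look roughly like this (the exact collection-name format comes from `Dataset.gen_collection_name_by_id` and may differ):

```python
# Assumed shape of gen_index_struct_dict(VectorType.IRIS, collection_name);
# the read path above only relies on ["vector_store"]["class_prefix"].
index_struct = {
    "type": "iris",
    "vector_store": {"class_prefix": "<collection_name>"},
}
```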
|
||||
@@ -188,6 +188,10 @@ class Vector:
|
|||
from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaVectorFactory
|
||||
|
||||
return ClickzettaVectorFactory
|
||||
case VectorType.IRIS:
|
||||
from core.rag.datasource.vdb.iris.iris_vector import IrisVectorFactory
|
||||
|
||||
return IrisVectorFactory
|
||||
case _:
|
||||
raise ValueError(f"Vector store {vector_type} is not supported.")
|
||||
|
||||
@@ -32,3 +32,4 @@ class VectorType(StrEnum):
|
|||
HUAWEI_CLOUD = "huawei_cloud"
|
||||
MATRIXONE = "matrixone"
|
||||
CLICKZETTA = "clickzetta"
|
||||
IRIS = "iris"
|
||||
@@ -1,7 +1,7 @@
|
|||
"""Abstract interface for document loader implementations."""
|
||||
|
||||
import os
|
||||
from typing import cast
|
||||
from typing import TypedDict
|
||||
|
||||
import pandas as pd
|
||||
from openpyxl import load_workbook
|
||||
|
|
@@ -10,6 +10,12 @@ from core.rag.extractor.extractor_base import BaseExtractor
|
|||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
class Candidate(TypedDict):
|
||||
idx: int
|
||||
count: int
|
||||
map: dict[int, str]
|
||||
|
||||
|
||||
class ExcelExtractor(BaseExtractor):
|
||||
"""Load Excel files.
|
||||
|
||||
|
|
@@ -30,32 +36,38 @@ class ExcelExtractor(BaseExtractor):
|
|||
file_extension = os.path.splitext(self._file_path)[-1].lower()
|
||||
|
||||
if file_extension == ".xlsx":
|
||||
wb = load_workbook(self._file_path, data_only=True)
|
||||
for sheet_name in wb.sheetnames:
|
||||
sheet = wb[sheet_name]
|
||||
data = sheet.values
|
||||
cols = next(data, None)
|
||||
if cols is None:
|
||||
continue
|
||||
df = pd.DataFrame(data, columns=cols)
|
||||
|
||||
df.dropna(how="all", inplace=True)
|
||||
|
||||
for index, row in df.iterrows():
|
||||
page_content = []
|
||||
for col_index, (k, v) in enumerate(row.items()):
|
||||
if pd.notna(v):
|
||||
cell = sheet.cell(
|
||||
row=cast(int, index) + 2, column=col_index + 1
|
||||
) # +2 to account for header and 1-based index
|
||||
if cell.hyperlink:
|
||||
value = f"[{v}]({cell.hyperlink.target})"
|
||||
page_content.append(f'"{k}":"{value}"')
|
||||
else:
|
||||
page_content.append(f'"{k}":"{v}"')
|
||||
documents.append(
|
||||
Document(page_content=";".join(page_content), metadata={"source": self._file_path})
|
||||
)
|
||||
wb = load_workbook(self._file_path, read_only=True, data_only=True)
|
||||
try:
|
||||
for sheet_name in wb.sheetnames:
|
||||
sheet = wb[sheet_name]
|
||||
header_row_idx, column_map, max_col_idx = self._find_header_and_columns(sheet)
|
||||
if not column_map:
|
||||
continue
|
||||
start_row = header_row_idx + 1
|
||||
for row in sheet.iter_rows(min_row=start_row, max_col=max_col_idx, values_only=False):
|
||||
if all(cell.value is None for cell in row):
|
||||
continue
|
||||
page_content = []
|
||||
for col_idx, cell in enumerate(row):
|
||||
value = cell.value
|
||||
if col_idx in column_map:
|
||||
col_name = column_map[col_idx]
|
||||
if hasattr(cell, "hyperlink") and cell.hyperlink:
|
||||
target = getattr(cell.hyperlink, "target", None)
|
||||
if target:
|
||||
value = f"[{value}]({target})"
|
||||
if value is None:
|
||||
value = ""
|
||||
elif not isinstance(value, str):
|
||||
value = str(value)
|
||||
value = value.strip().replace('"', '\\"')
|
||||
page_content.append(f'"{col_name}":"{value}"')
|
||||
if page_content:
|
||||
documents.append(
|
||||
Document(page_content=";".join(page_content), metadata={"source": self._file_path})
|
||||
)
|
||||
finally:
|
||||
wb.close()
|
||||
|
||||
elif file_extension == ".xls":
|
||||
excel_file = pd.ExcelFile(self._file_path, engine="xlrd")
|
||||
|
|
@@ -63,9 +75,9 @@ class ExcelExtractor(BaseExtractor):
|
|||
df = excel_file.parse(sheet_name=excel_sheet_name)
|
||||
df.dropna(how="all", inplace=True)
|
||||
|
||||
for _, row in df.iterrows():
|
||||
for _, series_row in df.iterrows():
|
||||
page_content = []
|
||||
for k, v in row.items():
|
||||
for k, v in series_row.items():
|
||||
if pd.notna(v):
|
||||
page_content.append(f'"{k}":"{v}"')
|
||||
documents.append(
|
||||
|
|
@@ -75,3 +87,61 @@ class ExcelExtractor(BaseExtractor):
|
|||
raise ValueError(f"Unsupported file extension: {file_extension}")
|
||||
|
||||
return documents
|
||||
|
||||
def _find_header_and_columns(self, sheet, scan_rows=10) -> tuple[int, dict[int, str], int]:
|
||||
"""
|
||||
Scan first N rows to find the most likely header row.
|
||||
Returns:
|
||||
header_row_idx: 1-based index of the header row
|
||||
column_map: Dict mapping 0-based column index to column name
|
||||
max_col_idx: 1-based index of the last valid column (for iter_rows boundary)
|
||||
"""
|
||||
# Store potential candidates: (row_index, non_empty_count, column_map)
|
||||
candidates: list[Candidate] = []
|
||||
|
||||
# Limit scan to avoid performance issues on huge files
|
||||
# We iterate manually to control the read scope
|
||||
for current_row_idx, row in enumerate(sheet.iter_rows(min_row=1, max_row=scan_rows, values_only=True), start=1):
|
||||
# Filter out empty cells and build a temp map for this row
|
||||
# col_idx is 0-based
|
||||
row_map = {}
|
||||
for col_idx, cell_value in enumerate(row):
|
||||
if cell_value is not None and str(cell_value).strip():
|
||||
row_map[col_idx] = str(cell_value).strip().replace('"', '\\"')
|
||||
|
||||
if not row_map:
|
||||
continue
|
||||
|
||||
non_empty_count = len(row_map)
|
||||
|
||||
# Header selection heuristic (implemented):
|
||||
# - Prefer the first row with at least 2 non-empty columns.
|
||||
# - Fallback: choose the row with the most non-empty columns
|
||||
# (tie-breaker: smaller row index).
|
||||
candidates.append({"idx": current_row_idx, "count": non_empty_count, "map": row_map})
|
||||
|
||||
if not candidates:
|
||||
return 0, {}, 0
|
||||
|
||||
# Choose the best candidate header row.
|
||||
|
||||
best_candidate: Candidate | None = None
|
||||
|
||||
# Strategy: prefer the first row with >= 2 non-empty columns; otherwise fallback.
|
||||
|
||||
for cand in candidates:
|
||||
if cand["count"] >= 2:
|
||||
best_candidate = cand
|
||||
break
|
||||
|
||||
# Fallback: if no row has >= 2 columns, or all have 1, just take the one with max columns
|
||||
if not best_candidate:
|
||||
# Sort by count desc, then index asc
|
||||
candidates.sort(key=lambda x: (-x["count"], x["idx"]))
|
||||
best_candidate = candidates[0]
|
||||
|
||||
# Determine max_col_idx (1-based for openpyxl)
|
||||
# It is the index of the last valid column in our map + 1
|
||||
max_col_idx = max(best_candidate["map"].keys()) + 1
|
||||
|
||||
return best_candidate["idx"], best_candidate["map"], max_col_idx
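A small worked example of the heuristic above: a one-cell title row loses to the first row with at least two non-empty columns:

```python
# Worked example of the header heuristic: row 1 has a single title cell,
# row 2 is the real header with two non-empty columns, so row 2 wins.
rows = [
    ("Quarterly report", None, None),   # row 1: 1 non-empty cell
    ("Name", "Score", None),            # row 2: 2 non-empty cells -> header
    ("alice", 10, None),
]
candidates = []
for idx, row in enumerate(rows, start=1):
    row_map = {i: str(v).strip() for i, v in enumerate(row)
               if v is not None and str(v).strip()}
    if row_map:
        candidates.append({"idx": idx, "count": len(row_map), "map": row_map})

best = next((c for c in candidates if c["count"] >= 2), None)
if best is None:  # fallback: most columns, then earliest row
    best = sorted(candidates, key=lambda c: (-c["count"], c["idx"]))[0]

print(best["idx"], best["map"], max(best["map"]) + 1)
# 2 {0: 'Name', 1: 'Score'} 2
```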
|
||||
@@ -84,22 +84,45 @@ class WordExtractor(BaseExtractor):
|
|||
image_count = 0
|
||||
image_map = {}
|
||||
|
||||
for rel in doc.part.rels.values():
|
||||
for r_id, rel in doc.part.rels.items():
|
||||
if "image" in rel.target_ref:
|
||||
image_count += 1
|
||||
if rel.is_external:
|
||||
url = rel.target_ref
|
||||
response = ssrf_proxy.get(url)
|
||||
if not self._is_valid_url(url):
|
||||
continue
|
||||
try:
|
||||
response = ssrf_proxy.get(url)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to download image from URL: %s: %s", url, str(e))
|
||||
continue
|
||||
if response.status_code == 200:
|
||||
image_ext = mimetypes.guess_extension(response.headers["Content-Type"])
|
||||
image_ext = mimetypes.guess_extension(response.headers.get("Content-Type", ""))
|
||||
if image_ext is None:
|
||||
continue
|
||||
file_uuid = str(uuid.uuid4())
|
||||
file_key = "image_files/" + self.tenant_id + "/" + file_uuid + "." + image_ext
|
||||
file_key = "image_files/" + self.tenant_id + "/" + file_uuid + image_ext
|
||||
mime_type, _ = mimetypes.guess_type(file_key)
|
||||
storage.save(file_key, response.content)
|
||||
else:
|
||||
continue
|
||||
# save file to db
|
||||
upload_file = UploadFile(
|
||||
tenant_id=self.tenant_id,
|
||||
storage_type=dify_config.STORAGE_TYPE,
|
||||
key=file_key,
|
||||
name=file_key,
|
||||
size=0,
|
||||
extension=str(image_ext),
|
||||
mime_type=mime_type or "",
|
||||
created_by=self.user_id,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_at=naive_utc_now(),
|
||||
used=True,
|
||||
used_by=self.user_id,
|
||||
used_at=naive_utc_now(),
|
||||
)
|
||||
db.session.add(upload_file)
|
||||
# Use r_id as key for external images since target_part is undefined
|
||||
image_map[r_id] = f""
|
||||
else:
|
||||
image_ext = rel.target_ref.split(".")[-1]
|
||||
if image_ext is None:
|
||||
|
|
@@ -110,27 +133,28 @@ class WordExtractor(BaseExtractor):
|
|||
mime_type, _ = mimetypes.guess_type(file_key)
|
||||
|
||||
storage.save(file_key, rel.target_part.blob)
|
||||
# save file to db
|
||||
upload_file = UploadFile(
|
||||
tenant_id=self.tenant_id,
|
||||
storage_type=dify_config.STORAGE_TYPE,
|
||||
key=file_key,
|
||||
name=file_key,
|
||||
size=0,
|
||||
extension=str(image_ext),
|
||||
mime_type=mime_type or "",
|
||||
created_by=self.user_id,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_at=naive_utc_now(),
|
||||
used=True,
|
||||
used_by=self.user_id,
|
||||
used_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
db.session.add(upload_file)
|
||||
db.session.commit()
|
||||
image_map[rel.target_part] = f""
|
||||
|
||||
# save file to db
|
||||
upload_file = UploadFile(
|
||||
tenant_id=self.tenant_id,
|
||||
storage_type=dify_config.STORAGE_TYPE,
|
||||
key=file_key,
|
||||
name=file_key,
|
||||
size=0,
|
||||
extension=str(image_ext),
|
||||
mime_type=mime_type or "",
|
||||
created_by=self.user_id,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_at=naive_utc_now(),
|
||||
used=True,
|
||||
used_by=self.user_id,
|
||||
used_at=naive_utc_now(),
|
||||
)
|
||||
db.session.add(upload_file)
|
||||
# Use target_part as key for internal images
|
||||
image_map[rel.target_part] = (
|
||||
f""
|
||||
)
|
||||
db.session.commit()
|
||||
return image_map
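Note the dual keying above: external images are stored under their relationship id (their `target_part` is undefined), embedded images under the target part object. A compact sketch of the lookup contract the paragraph-parsing code below relies on (hypothetical helper, not part of the diff):

```python
def lookup_image_link(rel, image_id, image_map):
    """Resolve a run's image markdown link using the dual-key scheme.

    External relationships are keyed by r_id (image_id); internal ones by
    the concrete target part object. Returns None when the image was skipped
    (e.g. the download failed or the extension was unknown).
    """
    if rel.is_external:
        return image_map.get(image_id)
    return image_map.get(rel.target_part)
```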
|
||||
|
||||
def _table_to_markdown(self, table, image_map):
|
||||
|
|
@@ -186,11 +210,17 @@ class WordExtractor(BaseExtractor):
|
|||
image_id = blip.get("{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed")
|
||||
if not image_id:
|
||||
continue
|
||||
image_part = paragraph.part.rels[image_id].target_part
|
||||
|
||||
if image_part in image_map:
|
||||
image_link = image_map[image_part]
|
||||
paragraph_content.append(image_link)
|
||||
rel = paragraph.part.rels.get(image_id)
|
||||
if rel is None:
|
||||
continue
|
||||
# For external images, use image_id as key; for internal, use target_part
|
||||
if rel.is_external:
|
||||
if image_id in image_map:
|
||||
paragraph_content.append(image_map[image_id])
|
||||
else:
|
||||
image_part = rel.target_part
|
||||
if image_part in image_map:
|
||||
paragraph_content.append(image_map[image_part])
|
||||
else:
|
||||
paragraph_content.append(run.text)
|
||||
return "".join(paragraph_content).strip()
|
||||
|
|
@@ -227,6 +257,18 @@ class WordExtractor(BaseExtractor):
|
|||
|
||||
def parse_paragraph(paragraph):
|
||||
paragraph_content = []
|
||||
|
||||
def append_image_link(image_id, has_drawing):
|
||||
"""Helper to append image link from image_map based on relationship type."""
|
||||
rel = doc.part.rels[image_id]
|
||||
if rel.is_external:
|
||||
if image_id in image_map and not has_drawing:
|
||||
paragraph_content.append(image_map[image_id])
|
||||
else:
|
||||
image_part = rel.target_part
|
||||
if image_part in image_map and not has_drawing:
|
||||
paragraph_content.append(image_map[image_part])
|
||||
|
||||
for run in paragraph.runs:
|
||||
if hasattr(run.element, "tag") and isinstance(run.element.tag, str) and run.element.tag.endswith("r"):
|
||||
# Process drawing type images
|
||||
|
|
@@ -243,10 +285,18 @@ class WordExtractor(BaseExtractor):
|
|||
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed"
|
||||
)
|
||||
if embed_id:
|
||||
image_part = doc.part.related_parts.get(embed_id)
|
||||
if image_part in image_map:
|
||||
has_drawing = True
|
||||
paragraph_content.append(image_map[image_part])
|
||||
rel = doc.part.rels.get(embed_id)
|
||||
if rel is not None and rel.is_external:
|
||||
# External image: use embed_id as key
|
||||
if embed_id in image_map:
|
||||
has_drawing = True
|
||||
paragraph_content.append(image_map[embed_id])
|
||||
else:
|
||||
# Internal image: use target_part as key
|
||||
image_part = doc.part.related_parts.get(embed_id)
|
||||
if image_part in image_map:
|
||||
has_drawing = True
|
||||
paragraph_content.append(image_map[image_part])
|
||||
# Process pict type images
|
||||
shape_elements = run.element.findall(
|
||||
".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}pict"
|
||||
|
|
@@ -261,9 +311,7 @@ class WordExtractor(BaseExtractor):
|
|||
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id"
|
||||
)
|
||||
if image_id and image_id in doc.part.rels:
|
||||
image_part = doc.part.rels[image_id].target_part
|
||||
if image_part in image_map and not has_drawing:
|
||||
paragraph_content.append(image_map[image_part])
|
||||
append_image_link(image_id, has_drawing)
|
||||
# Find imagedata element in VML
|
||||
image_data = shape.find(".//{urn:schemas-microsoft-com:vml}imagedata")
|
||||
if image_data is not None:
|
||||
|
|
@@ -271,9 +319,7 @@ class WordExtractor(BaseExtractor):
|
|||
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id"
|
||||
)
|
||||
if image_id and image_id in doc.part.rels:
|
||||
image_part = doc.part.rels[image_id].target_part
|
||||
if image_part in image_map and not has_drawing:
|
||||
paragraph_content.append(image_map[image_part])
|
||||
append_image_link(image_id, has_drawing)
|
||||
if run.text.strip():
|
||||
paragraph_content.append(run.text.strip())
|
||||
return "".join(paragraph_content) if paragraph_content else ""
|
||||
@@ -209,7 +209,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
|||
if dataset.indexing_technique == "high_quality":
|
||||
vector = Vector(dataset)
|
||||
vector.create(documents)
|
||||
if all_multimodal_documents:
|
||||
if all_multimodal_documents and dataset.is_multimodal:
|
||||
vector.create_multimodal(all_multimodal_documents)
|
||||
elif dataset.indexing_technique == "economy":
|
||||
keyword = Keyword(dataset)
|
||||
@@ -312,7 +312,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
|||
vector = Vector(dataset)
|
||||
if all_child_documents:
|
||||
vector.create(all_child_documents)
|
||||
if all_multimodal_documents:
|
||||
if all_multimodal_documents and dataset.is_multimodal:
|
||||
vector.create_multimodal(all_multimodal_documents)
|
||||
|
||||
def format_preview(self, chunks: Any) -> Mapping[str, Any]:
|
||||
@@ -266,7 +266,7 @@ class DatasetRetrieval:
|
|||
).all()
|
||||
if attachments_with_bindings:
|
||||
for _, upload_file in attachments_with_bindings:
|
||||
attchment_info = File(
|
||||
attachment_info = File(
|
||||
id=upload_file.id,
|
||||
filename=upload_file.name,
|
||||
extension="." + upload_file.extension,
|
||||
|
|
@@ -280,7 +280,7 @@ class DatasetRetrieval:
|
|||
storage_key=upload_file.key,
|
||||
url=sign_upload_file(upload_file.id, upload_file.extension),
|
||||
)
|
||||
context_files.append(attchment_info)
|
||||
context_files.append(attachment_info)
|
||||
if show_retrieve_source:
|
||||
for record in records:
|
||||
segment = record.segment
|
||||
|
|
@@ -592,111 +592,116 @@ class DatasetRetrieval:
|
|||
"""Handle retrieval end."""
|
||||
with flask_app.app_context():
|
||||
dify_documents = [document for document in documents if document.provider == "dify"]
|
||||
segment_ids = []
|
||||
segment_index_node_ids = []
|
||||
if not dify_documents:
|
||||
self._send_trace_task(message_id, documents, timer)
|
||||
return
|
||||
|
||||
with Session(db.engine) as session:
|
||||
for document in dify_documents:
|
||||
if document.metadata is not None:
|
||||
dataset_document_stmt = select(DatasetDocument).where(
|
||||
DatasetDocument.id == document.metadata["document_id"]
|
||||
)
|
||||
dataset_document = session.scalar(dataset_document_stmt)
|
||||
if dataset_document:
|
||||
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||
segment_id = None
|
||||
if (
|
||||
"doc_type" not in document.metadata
|
||||
or document.metadata.get("doc_type") == DocType.TEXT
|
||||
):
|
||||
child_chunk_stmt = select(ChildChunk).where(
|
||||
ChildChunk.index_node_id == document.metadata["doc_id"],
|
||||
ChildChunk.dataset_id == dataset_document.dataset_id,
|
||||
ChildChunk.document_id == dataset_document.id,
|
||||
)
|
||||
child_chunk = session.scalar(child_chunk_stmt)
|
||||
if child_chunk:
|
||||
segment_id = child_chunk.segment_id
|
||||
elif (
|
||||
"doc_type" in document.metadata
|
||||
and document.metadata.get("doc_type") == DocType.IMAGE
|
||||
):
|
||||
attachment_info_dict = RetrievalService.get_segment_attachment_info(
|
||||
dataset_document.dataset_id,
|
||||
dataset_document.tenant_id,
|
||||
document.metadata.get("doc_id") or "",
|
||||
session,
|
||||
)
|
||||
if attachment_info_dict:
|
||||
segment_id = attachment_info_dict["segment_id"]
|
||||
# Collect all document_ids and batch fetch DatasetDocuments
|
||||
document_ids = {
|
||||
doc.metadata["document_id"]
|
||||
for doc in dify_documents
|
||||
if doc.metadata and "document_id" in doc.metadata
|
||||
}
|
||||
if not document_ids:
|
||||
self._send_trace_task(message_id, documents, timer)
|
||||
return
|
||||
|
||||
dataset_docs_stmt = select(DatasetDocument).where(DatasetDocument.id.in_(document_ids))
|
||||
dataset_docs = session.scalars(dataset_docs_stmt).all()
|
||||
dataset_doc_map = {str(doc.id): doc for doc in dataset_docs}
|
||||
|
||||
# Categorize documents by type and collect necessary IDs
|
||||
parent_child_text_docs: list[tuple[Document, DatasetDocument]] = []
|
||||
parent_child_image_docs: list[tuple[Document, DatasetDocument]] = []
|
||||
normal_text_docs: list[tuple[Document, DatasetDocument]] = []
|
||||
normal_image_docs: list[tuple[Document, DatasetDocument]] = []
|
||||
|
||||
for doc in dify_documents:
|
||||
if not doc.metadata or "document_id" not in doc.metadata:
|
||||
continue
|
||||
dataset_doc = dataset_doc_map.get(doc.metadata["document_id"])
|
||||
if not dataset_doc:
|
||||
continue
|
||||
|
||||
is_image = doc.metadata.get("doc_type") == DocType.IMAGE
|
||||
is_parent_child = dataset_doc.doc_form == IndexStructureType.PARENT_CHILD_INDEX
|
||||
|
||||
if is_parent_child:
|
||||
if is_image:
|
||||
parent_child_image_docs.append((doc, dataset_doc))
|
||||
else:
|
||||
parent_child_text_docs.append((doc, dataset_doc))
|
||||
else:
|
||||
if is_image:
|
||||
normal_image_docs.append((doc, dataset_doc))
|
||||
else:
|
||||
normal_text_docs.append((doc, dataset_doc))
|
||||
|
||||
segment_ids_to_update: set[str] = set()
|
||||
|
||||
# Process PARENT_CHILD_INDEX text documents - batch fetch ChildChunks
|
||||
if parent_child_text_docs:
|
||||
index_node_ids = [doc.metadata["doc_id"] for doc, _ in parent_child_text_docs if doc.metadata]
|
||||
if index_node_ids:
|
||||
child_chunks_stmt = select(ChildChunk).where(ChildChunk.index_node_id.in_(index_node_ids))
|
||||
child_chunks = session.scalars(child_chunks_stmt).all()
|
||||
child_chunk_map = {chunk.index_node_id: chunk.segment_id for chunk in child_chunks}
|
||||
for doc, _ in parent_child_text_docs:
|
||||
if doc.metadata:
|
||||
segment_id = child_chunk_map.get(doc.metadata["doc_id"])
|
||||
if segment_id:
|
||||
if segment_id not in segment_ids:
|
||||
segment_ids.append(segment_id)
|
||||
_ = (
|
||||
session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == segment_id)
|
||||
.update(
|
||||
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
|
||||
synchronize_session=False,
|
||||
)
|
||||
)
|
||||
else:
|
||||
query = None
|
||||
if (
|
||||
"doc_type" not in document.metadata
|
||||
or document.metadata.get("doc_type") == DocType.TEXT
|
||||
):
|
||||
if document.metadata["doc_id"] not in segment_index_node_ids:
|
||||
segment = (
|
||||
session.query(DocumentSegment)
|
||||
.where(DocumentSegment.index_node_id == document.metadata["doc_id"])
|
||||
.first()
|
||||
)
|
||||
if segment:
|
||||
segment_index_node_ids.append(document.metadata["doc_id"])
|
||||
segment_ids.append(segment.id)
|
||||
query = session.query(DocumentSegment).where(
|
||||
DocumentSegment.id == segment.id
|
||||
)
|
||||
elif (
|
||||
"doc_type" in document.metadata
|
||||
and document.metadata.get("doc_type") == DocType.IMAGE
|
||||
):
|
||||
attachment_info_dict = RetrievalService.get_segment_attachment_info(
|
||||
dataset_document.dataset_id,
|
||||
dataset_document.tenant_id,
|
||||
document.metadata.get("doc_id") or "",
|
||||
session,
|
||||
)
|
||||
if attachment_info_dict:
|
||||
segment_id = attachment_info_dict["segment_id"]
|
||||
if segment_id not in segment_ids:
|
||||
segment_ids.append(segment_id)
|
||||
query = session.query(DocumentSegment).where(DocumentSegment.id == segment_id)
|
||||
if query:
|
||||
# if 'dataset_id' in document.metadata:
|
||||
if "dataset_id" in document.metadata:
|
||||
query = query.where(
|
||||
DocumentSegment.dataset_id == document.metadata["dataset_id"]
|
||||
)
|
||||
segment_ids_to_update.add(str(segment_id))
|
||||
|
||||
# add hit count to document segment
|
||||
query.update(
|
||||
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
|
||||
synchronize_session=False,
|
||||
)
|
||||
# Process non-PARENT_CHILD_INDEX text documents - batch fetch DocumentSegments
|
||||
if normal_text_docs:
|
||||
index_node_ids = [doc.metadata["doc_id"] for doc, _ in normal_text_docs if doc.metadata]
|
||||
if index_node_ids:
|
||||
segments_stmt = select(DocumentSegment).where(DocumentSegment.index_node_id.in_(index_node_ids))
|
||||
segments = session.scalars(segments_stmt).all()
|
||||
segment_map = {seg.index_node_id: seg.id for seg in segments}
|
||||
for doc, _ in normal_text_docs:
|
||||
if doc.metadata:
|
||||
segment_id = segment_map.get(doc.metadata["doc_id"])
|
||||
if segment_id:
|
||||
segment_ids_to_update.add(str(segment_id))
|
||||
|
||||
db.session.commit()
|
||||
# Process IMAGE documents - batch fetch SegmentAttachmentBindings
|
||||
all_image_docs = parent_child_image_docs + normal_image_docs
|
||||
if all_image_docs:
|
||||
attachment_ids = [
|
||||
doc.metadata["doc_id"]
|
||||
for doc, _ in all_image_docs
|
||||
if doc.metadata and doc.metadata.get("doc_id")
|
||||
]
|
||||
if attachment_ids:
|
||||
bindings_stmt = select(SegmentAttachmentBinding).where(
|
||||
SegmentAttachmentBinding.attachment_id.in_(attachment_ids)
|
||||
)
|
||||
bindings = session.scalars(bindings_stmt).all()
|
||||
segment_ids_to_update.update(str(binding.segment_id) for binding in bindings)
|
||||
|
||||
# get tracing instance
|
||||
trace_manager: TraceQueueManager | None = (
|
||||
self.application_generate_entity.trace_manager if self.application_generate_entity else None
|
||||
)
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
TraceTask(
|
||||
TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer
|
||||
# Batch update hit_count for all segments
|
||||
if segment_ids_to_update:
|
||||
session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids_to_update)).update(
|
||||
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
|
||||
synchronize_session=False,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
self._send_trace_task(message_id, documents, timer)
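The rewrite above collapses N per-segment UPDATEs into one set-based statement. A sketch of the same batched increment in SQLAlchemy 2.0 style, assuming the `DocumentSegment` model, an open `session`, and the collected `segment_ids_to_update`:

```python
from sqlalchemy import update

# One round trip for the whole batch instead of one UPDATE per segment.
stmt = (
    update(DocumentSegment)
    .where(DocumentSegment.id.in_(segment_ids_to_update))
    .values(hit_count=DocumentSegment.hit_count + 1)
    .execution_options(synchronize_session=False)
)
session.execute(stmt)
session.commit()
```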
|
||||
|
||||
def _send_trace_task(self, message_id: str | None, documents: list[Document], timer: dict | None):
|
||||
"""Send trace task if trace manager is available."""
|
||||
trace_manager: TraceQueueManager | None = (
|
||||
self.application_generate_entity.trace_manager if self.application_generate_entity else None
|
||||
)
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
TraceTask(
|
||||
TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer
|
||||
)
|
||||
)
|
||||
|
||||
def _on_query(
|
||||
self,
|
||||
@@ -101,6 +101,8 @@ class ToolFileMessageTransformer:
|
|||
meta = message.meta or {}
|
||||
|
||||
mimetype = meta.get("mime_type", "application/octet-stream")
|
||||
if not mimetype:
|
||||
mimetype = "application/octet-stream"
|
||||
# get filename from meta
|
||||
filename = meta.get("filename", None)
|
||||
# if message is str, encode it to bytes
|
||||
|
|
|
|||
|
|
@@ -13,5 +13,5 @@ def remove_leading_symbols(text: str) -> str:
    """
    # Match Unicode ranges for punctuation and symbols
    # FIXME this pattern is confused quick fix for #11868 maybe refactor it later
    pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F\"#$%&'()*+,./:;<=>?@^_`~]+"
    pattern = r'^[\[\]\u2000-\u2025\u2027-\u206F\u2E00-\u2E7F\u3000-\u300F\u3011-\u303F"#$%&\'()*+,./:;<=>?@^_`~]+'
    return re.sub(pattern, "", text)
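Reading the two character classes side by side: the new pattern additionally strips leading `[` and `]`, while carving U+2026 (the horizontal ellipsis) and U+3010 (【) out of the ranges so they now survive. A quick check of that reading (plain Python, inputs chosen for illustration):

```python
import re

OLD = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F\"#$%&'()*+,./:;<=>?@^_`~]+"
NEW = r'^[\[\]\u2000-\u2025\u2027-\u206F\u2E00-\u2E7F\u3000-\u300F\u3011-\u303F"#$%&\'()*+,./:;<=>?@^_`~]+'

text = "\u2026continued"  # leading horizontal ellipsis
print(re.sub(OLD, "", text))             # "continued"   - old class swallowed the ellipsis
print(re.sub(NEW, "", text))             # "…continued"  - new class keeps it
print(re.sub(NEW, "", "[[note]] body"))  # "note]] body" - leading brackets now stripped
```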
@@ -221,7 +221,7 @@ class WorkflowToolProviderController(ToolProviderController):
            session.query(WorkflowToolProvider)
            .where(
                WorkflowToolProvider.tenant_id == tenant_id,
                WorkflowToolProvider.app_id == self.provider_id,
                WorkflowToolProvider.id == self.provider_id,
            )
            .first()
        )
@@ -59,7 +59,7 @@ class OutputVariableEntity(BaseModel):
    """

    variable: str
    value_type: OutputVariableType
    value_type: OutputVariableType = OutputVariableType.ANY
    value_selector: Sequence[str]

    @field_validator("value_type", mode="before")
@@ -412,16 +412,20 @@ class Executor:
                body_string += f"--{boundary}\r\n"
                body_string += f'Content-Disposition: form-data; name="{key}"\r\n\r\n'
                # decode content safely
                try:
                    body_string += content.decode("utf-8")
                except UnicodeDecodeError:
                    body_string += content.decode("utf-8", errors="replace")
                body_string += "\r\n"
                # Do not decode binary content; use a placeholder with file metadata instead.
                # Includes filename, size, and MIME type for better logging context.
                body_string += (
                    f"<file_content_binary: '{file_entry[1][0] or 'unknown'}', "
                    f"type='{file_entry[1][2] if len(file_entry[1]) > 2 else 'unknown'}', "
                    f"size={len(content)} bytes>\r\n"
                )
            body_string += f"--{boundary}--\r\n"
        elif self.node_data.body:
            if self.content:
                # If content is bytes, do not decode it; show a placeholder with size.
                # Provides content size information for binary data without exposing the raw bytes.
                if isinstance(self.content, bytes):
                    body_string = self.content.decode("utf-8", errors="replace")
                    body_string = f"<binary_content: size={len(self.content)} bytes>"
                else:
                    body_string = self.content
            elif self.data and self.node_data.body.type == "x-www-form-urlencoded":
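The replacement logic never decodes raw bytes for display; it emits a size placeholder instead. A self-contained sketch of the idea (the function name is illustrative):

```python
def render_body_for_log(content: bytes | str) -> str:
    """Render a request body for logging without dumping raw binary bytes."""
    if isinstance(content, bytes):
        # Decoding arbitrary bytes with errors="replace" yields mojibake;
        # a size placeholder is both safer and more readable.
        return f"<binary_content: size={len(content)} bytes>"
    return content


print(render_body_for_log(b"\x89PNG\r\n\x1a\n"))  # <binary_content: size=8 bytes>
print(render_body_for_log("plain text"))          # plain text
```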
@@ -334,6 +334,7 @@ class LLMNode(Node[LLMNodeData]):
                    inputs=node_inputs,
                    process_data=process_data,
                    error_type=type(e).__name__,
                    llm_usage=usage,
                )
            )
        except Exception as e:

@@ -344,6 +345,8 @@ class LLMNode(Node[LLMNodeData]):
                    error=str(e),
                    inputs=node_inputs,
                    process_data=process_data,
                    error_type=type(e).__name__,
                    llm_usage=usage,
                )
            )
@@ -694,7 +697,7 @@ class LLMNode(Node[LLMNodeData]):
        ).all()
        if attachments_with_bindings:
            for _, upload_file in attachments_with_bindings:
                attchment_info = File(
                attachment_info = File(
                    id=upload_file.id,
                    filename=upload_file.name,
                    extension="." + upload_file.extension,

@@ -708,7 +711,7 @@ class LLMNode(Node[LLMNodeData]):
                    storage_key=upload_file.key,
                    url=sign_upload_file(upload_file.id, upload_file.extension),
                )
                context_files.append(attchment_info)
                context_files.append(attachment_info)
        yield RunRetrieverResourceEvent(
            retriever_resources=original_retriever_resource,
            context=context_str.strip(),
@@ -221,6 +221,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=variables,
                error=str(e),
                error_type=type(e).__name__,
                metadata={
                    WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
                    WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
@@ -15,4 +15,5 @@ def handle(sender: Dataset, **kwargs):
        dataset.index_struct,
        dataset.collection_binding_id,
        dataset.doc_form,
        dataset.pipeline_id,
    )
@@ -9,11 +9,21 @@ FILES_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN)
EXPOSED_HEADERS: tuple[str, ...] = ("X-Version", "X-Env", "X-Trace-Id")


def init_app(app: DifyApp):
    # register blueprint routers
def _apply_cors_once(bp, /, **cors_kwargs):
    """Make CORS idempotent so blueprints can be reused across multiple app instances."""

    if getattr(bp, "_dify_cors_applied", False):
        return

    from flask_cors import CORS

    CORS(bp, **cors_kwargs)
    bp._dify_cors_applied = True


def init_app(app: DifyApp):
    # register blueprint routers

    from controllers.console import bp as console_app_bp
    from controllers.files import bp as files_bp
    from controllers.inner_api import bp as inner_api_bp

@@ -22,7 +32,7 @@ def init_app(app: DifyApp):
    from controllers.trigger import bp as trigger_bp
    from controllers.web import bp as web_bp

    CORS(
    _apply_cors_once(
        service_api_bp,
        allow_headers=list(SERVICE_API_HEADERS),
        methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],

@@ -30,7 +40,7 @@ def init_app(app: DifyApp):
    )
    app.register_blueprint(service_api_bp)

    CORS(
    _apply_cors_once(
        web_bp,
        resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}},
        supports_credentials=True,

@@ -40,7 +50,7 @@ def init_app(app: DifyApp):
    )
    app.register_blueprint(web_bp)

    CORS(
    _apply_cors_once(
        console_app_bp,
        resources={r"/*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}},
        supports_credentials=True,

@@ -50,7 +60,7 @@ def init_app(app: DifyApp):
    )
    app.register_blueprint(console_app_bp)

    CORS(
    _apply_cors_once(
        files_bp,
        allow_headers=list(FILES_HEADERS),
        methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],

@@ -62,7 +72,7 @@ def init_app(app: DifyApp):
    app.register_blueprint(mcp_bp)

    # Register trigger blueprint with CORS for webhook calls
    CORS(
    _apply_cors_once(
        trigger_bp,
        allow_headers=["Content-Type", "Authorization", "X-App-Code"],
        methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH", "HEAD"],
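The motivation for `_apply_cors_once`: `flask_cors.CORS` mutates the blueprint it is given, so applying it twice (for example when tests build several app instances from the same module-level blueprint) stacks duplicate handlers. A marker attribute makes the second call a no-op. A runnable sketch of the same guard (names are illustrative):

```python
from flask import Blueprint, Flask
from flask_cors import CORS

bp = Blueprint("api", __name__)


def apply_cors_once(blueprint, **cors_kwargs):
    if getattr(blueprint, "_cors_applied", False):
        return  # already configured; skip re-registration
    CORS(blueprint, **cors_kwargs)
    blueprint._cors_applied = True


for _ in range(2):  # a second app reuses the module-level blueprint safely
    app = Flask(__name__)
    apply_cors_once(bp, supports_credentials=True)
    app.register_blueprint(bp)
```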
@@ -22,8 +22,8 @@ login_manager = flask_login.LoginManager()
@login_manager.request_loader
def load_user_from_request(request_from_flask_login):
    """Load user based on the request."""
    # Skip authentication for documentation endpoints
    if dify_config.SWAGGER_UI_ENABLED and request.path.endswith((dify_config.SWAGGER_UI_PATH, "/swagger.json")):
    # Skip authentication for documentation endpoints (only when Swagger is enabled)
    if dify_config.swagger_ui_enabled and request.path.endswith((dify_config.SWAGGER_UI_PATH, "/swagger.json")):
        return None

    auth_token = extract_access_token(request)
@@ -0,0 +1,7 @@
from core.db.session_factory import configure_session_factory
from extensions.ext_database import db


def init_app(app):
    with app.app_context():
        configure_session_factory(db.engine)
@@ -131,12 +131,28 @@ class ExternalApi(Api):
    }

    def __init__(self, app: Blueprint | Flask, *args, **kwargs):
        import logging
        import os

        kwargs.setdefault("authorizations", self._authorizations)
        kwargs.setdefault("security", "Bearer")
        kwargs["add_specs"] = dify_config.SWAGGER_UI_ENABLED
        kwargs["doc"] = dify_config.SWAGGER_UI_PATH if dify_config.SWAGGER_UI_ENABLED else False

        # Security: Use computed swagger_ui_enabled which respects DEPLOY_ENV
        swagger_enabled = dify_config.swagger_ui_enabled
        kwargs["add_specs"] = swagger_enabled
        kwargs["doc"] = dify_config.SWAGGER_UI_PATH if swagger_enabled else False

        # manual separate call on construction and init_app to ensure configs in kwargs effective
        super().__init__(app=None, *args, **kwargs)
        self.init_app(app, **kwargs)
        register_external_error_handlers(self)

        # Security: Log warning when Swagger is enabled in production environment
        deploy_env = os.environ.get("DEPLOY_ENV", "PRODUCTION")
        if swagger_enabled and deploy_env.upper() == "PRODUCTION":
            logger = logging.getLogger(__name__)
            logger.warning(
                "SECURITY WARNING: Swagger UI is ENABLED in PRODUCTION environment. "
                "This may expose sensitive API documentation. "
                "Set SWAGGER_UI_ENABLED=false or remove the explicit setting to disable."
            )
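The diff swaps the raw `SWAGGER_UI_ENABLED` flag for a computed `swagger_ui_enabled` property whose internals are not shown here. A hedged sketch of what such a DEPLOY_ENV-aware property could look like (this is an assumption for illustration, not the real dify_config implementation):

```python
import os


class Config:
    SWAGGER_UI_ENABLED: bool | None = None  # explicit operator override, if any

    @property
    def swagger_ui_enabled(self) -> bool:
        if self.SWAGGER_UI_ENABLED is not None:
            return self.SWAGGER_UI_ENABLED  # honor an explicit setting
        # assumed default: docs on everywhere except production
        return os.environ.get("DEPLOY_ENV", "PRODUCTION").upper() != "PRODUCTION"


cfg = Config()
print(cfg.swagger_ui_enabled)  # False under the default PRODUCTION deploy env
```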
@@ -107,7 +107,7 @@ def email(email):
EmailStr = Annotated[str, AfterValidator(email)]


def uuid_value(value):
def uuid_value(value: Any) -> str:
    if value == "":
        return str(value)
@@ -215,7 +215,11 @@ def generate_text_hash(text: str) -> str:

def compact_generate_response(response: Union[Mapping, Generator, RateLimitGenerator]) -> Response:
    if isinstance(response, dict):
        return Response(response=json.dumps(jsonable_encoder(response)), status=200, mimetype="application/json")
        return Response(
            response=json.dumps(jsonable_encoder(response)),
            status=200,
            content_type="application/json; charset=utf-8",
        )
    else:

        def generate() -> Generator:
@@ -1,4 +1,4 @@
"""empty message
"""mysql adaptation

Revision ID: 09cfdda155d1
Revises: 669ffd70119c
@@ -97,11 +97,31 @@ def downgrade():
        batch_op.alter_column('include_plugins',
               existing_type=sa.JSON(),
               type_=postgresql.ARRAY(sa.VARCHAR(length=255)),
               existing_nullable=False)
               existing_nullable=False,
               postgresql_using="""
               COALESCE(
                   regexp_replace(
                       replace(replace(include_plugins::text, '[', '{'), ']', '}'),
                       '"',
                       '',
                       'g'
                   )::varchar(255)[],
                   ARRAY[]::varchar(255)[]
               )""")
        batch_op.alter_column('exclude_plugins',
               existing_type=sa.JSON(),
               type_=postgresql.ARRAY(sa.VARCHAR(length=255)),
               existing_nullable=False)
               existing_nullable=False,
               postgresql_using="""
               COALESCE(
                   regexp_replace(
                       replace(replace(exclude_plugins::text, '[', '{'), ']', '}'),
                       '"',
                       '',
                       'g'
                   )::varchar(255)[],
                   ARRAY[]::varchar(255)[]
               )""")

    with op.batch_alter_table('external_knowledge_bindings', schema=None) as batch_op:
        batch_op.alter_column('external_knowledge_id',
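Postgres has no implicit cast from `json` to `varchar[]`, so the downgrade needs a `USING` expression; the one above rewrites the JSON text into a Postgres array literal before casting, falling back to an empty array. The same string munging, rendered in plain Python for clarity (illustrative only; the migration runs it in SQL):

```python
def json_list_to_pg_array_literal(json_text: str) -> str:
    """'["a","b"]' -> '{a,b}', which Postgres can cast to varchar(255)[]."""
    munged = json_text.replace("[", "{").replace("]", "}")  # brackets -> braces
    return munged.replace('"', "")  # drop JSON quotes (regexp_replace ... 'g' in SQL)


print(json_list_to_pg_array_literal('["plugin-a","plugin-b"]'))  # {plugin-a,plugin-b}
print(json_list_to_pg_array_literal("[]"))                       # {}
```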
@@ -78,7 +78,7 @@ class Dataset(Base):
    pipeline_id = mapped_column(StringUUID, nullable=True)
    chunk_structure = mapped_column(sa.String(255), nullable=True)
    enable_api = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
    is_multimodal = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false"))
    is_multimodal = mapped_column(sa.Boolean, default=False, nullable=False, server_default=db.text("false"))

    @property
    def total_documents(self):
@@ -111,7 +111,11 @@ class App(Base):
        else:
            app_model_config = self.app_model_config
            if app_model_config:
                return app_model_config.pre_prompt
                pre_prompt = app_model_config.pre_prompt or ""
                # Truncate to 200 characters with ellipsis if using prompt as description
                if len(pre_prompt) > 200:
                    return pre_prompt[:200] + "..."
                return pre_prompt
            else:
                return ""

@@ -259,7 +263,7 @@ class App(Base):
            provider_id = tool.get("provider_id", "")

            if provider_type == ToolProviderType.API:
                if uuid.UUID(provider_id) not in existing_api_providers:
                if provider_id not in existing_api_providers:
                    deleted_tools.append(
                        {
                            "type": ToolProviderType.API,

@@ -835,7 +839,29 @@ class Conversation(Base):

    @property
    def status_count(self):
        messages = db.session.scalars(select(Message).where(Message.conversation_id == self.id)).all()
        from models.workflow import WorkflowRun

        # Get all messages with workflow_run_id for this conversation
        messages = db.session.scalars(
            select(Message).where(Message.conversation_id == self.id, Message.workflow_run_id.isnot(None))
        ).all()

        if not messages:
            return None

        # Batch load all workflow runs in a single query, filtered by this conversation's app_id
        workflow_run_ids = [msg.workflow_run_id for msg in messages if msg.workflow_run_id]
        workflow_runs = {}

        if workflow_run_ids:
            workflow_runs_query = db.session.scalars(
                select(WorkflowRun).where(
                    WorkflowRun.id.in_(workflow_run_ids),
                    WorkflowRun.app_id == self.app_id,  # Filter by this conversation's app_id
                )
            ).all()
            workflow_runs = {run.id: run for run in workflow_runs_query}

        status_counts = {
            WorkflowExecutionStatus.RUNNING: 0,
            WorkflowExecutionStatus.SUCCEEDED: 0,

@@ -845,18 +871,24 @@ class Conversation(Base):
        }

        for message in messages:
            if message.workflow_run:
                status_counts[WorkflowExecutionStatus(message.workflow_run.status)] += 1
            # Guard against None to satisfy type checker and avoid invalid dict lookups
            if message.workflow_run_id is None:
                continue
            workflow_run = workflow_runs.get(message.workflow_run_id)
            if not workflow_run:
                continue

        return (
            {
                "success": status_counts[WorkflowExecutionStatus.SUCCEEDED],
                "failed": status_counts[WorkflowExecutionStatus.FAILED],
                "partial_success": status_counts[WorkflowExecutionStatus.PARTIAL_SUCCEEDED],
            }
            if messages
            else None
        )
            try:
                status_counts[WorkflowExecutionStatus(workflow_run.status)] += 1
            except (ValueError, KeyError):
                # Handle invalid status values gracefully
                pass

        return {
            "success": status_counts[WorkflowExecutionStatus.SUCCEEDED],
            "failed": status_counts[WorkflowExecutionStatus.FAILED],
            "partial_success": status_counts[WorkflowExecutionStatus.PARTIAL_SUCCEEDED],
        }

    @property
    def first_message(self):
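The rewritten `status_count` avoids the per-message lazy load of `message.workflow_run`: ids are collected first, all runs are fetched in one `IN()` query scoped to the app, and the loop resolves each run from a dict. A toy, self-contained version of that pattern (in-memory stand-ins replace the ORM):

```python
from dataclasses import dataclass


@dataclass
class Message:
    workflow_run_id: str | None


@dataclass
class WorkflowRun:
    id: str
    status: str


ALL_RUNS = [WorkflowRun("r1", "succeeded"), WorkflowRun("r2", "failed")]


def load_runs_by_ids(ids: list[str]) -> dict[str, WorkflowRun]:
    # Stands in for one IN() query: select(WorkflowRun).where(WorkflowRun.id.in_(ids))
    wanted = set(ids)
    return {run.id: run for run in ALL_RUNS if run.id in wanted}


messages = [Message("r1"), Message("r2"), Message(None)]
run_ids = [m.workflow_run_id for m in messages if m.workflow_run_id]
runs = load_runs_by_ids(run_ids)  # one "query" instead of one per message

counts: dict[str, int] = {}
for m in messages:
    if m.workflow_run_id is None:
        continue
    run = runs.get(m.workflow_run_id)
    if run is None:
        continue  # missing run (or, in the real code, a run from another app)
    counts[run.status] = counts.get(run.status, 0) + 1

print(counts)  # {'succeeded': 1, 'failed': 1}
```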
@@ -1255,13 +1287,9 @@ class Message(Base):
            "id": self.id,
            "app_id": self.app_id,
            "conversation_id": self.conversation_id,
            "model_provider": self.model_provider,
            "model_id": self.model_id,
            "inputs": self.inputs,
            "query": self.query,
            "message_tokens": self.message_tokens,
            "answer_tokens": self.answer_tokens,
            "provider_response_latency": self.provider_response_latency,
            "total_price": self.total_price,
            "message": self.message,
            "answer": self.answer,

@@ -1283,12 +1311,8 @@ class Message(Base):
            id=data["id"],
            app_id=data["app_id"],
            conversation_id=data["conversation_id"],
            model_provider=data.get("model_provider"),
            model_id=data["model_id"],
            inputs=data["inputs"],
            message_tokens=data.get("message_tokens", 0),
            answer_tokens=data.get("answer_tokens", 0),
            provider_response_latency=data.get("provider_response_latency", 0.0),
            total_price=data["total_price"],
            query=data["query"],
            message=data["message"],
@@ -907,19 +907,29 @@ class WorkflowNodeExecutionModel(Base):  # This model is expected to have `offlo
    @property
    def extras(self) -> dict[str, Any]:
        from core.tools.tool_manager import ToolManager
        from core.trigger.trigger_manager import TriggerManager

        extras: dict[str, Any] = {}
        if self.execution_metadata_dict:
            if self.node_type == NodeType.TOOL and "tool_info" in self.execution_metadata_dict:
                tool_info: dict[str, Any] = self.execution_metadata_dict["tool_info"]
        execution_metadata = self.execution_metadata_dict
        if execution_metadata:
            if self.node_type == NodeType.TOOL and "tool_info" in execution_metadata:
                tool_info: dict[str, Any] = execution_metadata["tool_info"]
                extras["icon"] = ToolManager.get_tool_icon(
                    tenant_id=self.tenant_id,
                    provider_type=tool_info["provider_type"],
                    provider_id=tool_info["provider_id"],
                )
            elif self.node_type == NodeType.DATASOURCE and "datasource_info" in self.execution_metadata_dict:
                datasource_info = self.execution_metadata_dict["datasource_info"]
            elif self.node_type == NodeType.DATASOURCE and "datasource_info" in execution_metadata:
                datasource_info = execution_metadata["datasource_info"]
                extras["icon"] = datasource_info.get("icon")
            elif self.node_type == NodeType.TRIGGER_PLUGIN and "trigger_info" in execution_metadata:
                trigger_info = execution_metadata["trigger_info"] or {}
                provider_id = trigger_info.get("provider_id")
                if provider_id:
                    extras["icon"] = TriggerManager.get_trigger_plugin_icon(
                        tenant_id=self.tenant_id,
                        provider_id=provider_id,
                    )
        return extras

    def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> Optional["WorkflowNodeExecutionOffload"]:
@@ -1,6 +1,6 @@
[project]
name = "dify-api"
version = "1.10.1"
version = "1.11.1"
requires-python = ">=3.11,<3.13"

dependencies = [

@@ -151,7 +151,7 @@ dev = [
    "types-pywin32~=310.0.0",
    "types-pyyaml~=6.0.12",
    "types-regex~=2024.11.6",
    "types-shapely~=2.0.0",
    "types-shapely~=2.1.0",
    "types-simplejson>=3.20.0",
    "types-six>=1.17.0",
    "types-tensorflow>=2.18.0",

@@ -216,6 +216,7 @@ vdb = [
    "pymochow==2.2.9",
    "pyobvector~=0.2.17",
    "qdrant-client==1.9.0",
    "intersystems-irispython>=5.1.0",
    "tablestore==6.3.7",
    "tcvectordb~=1.6.4",
    "tidb-vector==0.0.9",
@@ -1,5 +1,5 @@
[pytest]
addopts = --cov=./api --cov-report=json --cov-report=xml
addopts = --cov=./api --cov-report=json
env =
    ANTHROPIC_API_KEY = sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz
    AZURE_OPENAI_API_BASE = https://difyai-openai.openai.azure.com
@@ -118,7 +118,7 @@ class ConversationService:
        app_model: App,
        conversation_id: str,
        user: Union[Account, EndUser] | None,
        name: str,
        name: str | None,
        auto_generate: bool,
    ):
        conversation = cls.get_conversation(app_model, conversation_id, user)
@@ -673,6 +673,8 @@ class DatasetService:
        Returns:
            str: Action to perform ('add', 'remove', 'update', or None)
        """
        if "indexing_technique" not in data:
            return None
        if dataset.indexing_technique != data["indexing_technique"]:
            if data["indexing_technique"] == "economy":
                # Remove embedding model configuration for economy mode

@@ -1634,6 +1636,20 @@ class DocumentService:
            return [], ""
            db.session.add(dataset_process_rule)
            db.session.flush()
        else:
            # Fallback when no process_rule provided in knowledge_config:
            # 1) reuse dataset.latest_process_rule if present
            # 2) otherwise create an automatic rule
            dataset_process_rule = getattr(dataset, "latest_process_rule", None)
            if not dataset_process_rule:
                dataset_process_rule = DatasetProcessRule(
                    dataset_id=dataset.id,
                    mode="automatic",
                    rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
                    created_by=account.id,
                )
                db.session.add(dataset_process_rule)
                db.session.flush()
        lock_name = f"add_document_lock_dataset_id_{dataset.id}"
        try:
            with redis_client.lock(lock_name, timeout=600):

@@ -1645,65 +1661,67 @@ class DocumentService:
                if not knowledge_config.data_source.info_list.file_info_list:
                    raise ValueError("File source info is required")
                upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
                for file_id in upload_file_list:
                    file = (
                        db.session.query(UploadFile)
                        .where(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
                        .first()
                files = (
                    db.session.query(UploadFile)
                    .where(
                        UploadFile.tenant_id == dataset.tenant_id,
                        UploadFile.id.in_(upload_file_list),
                    )
                    .all()
                )
                if len(files) != len(set(upload_file_list)):
                    raise FileNotExistsError("One or more files not found.")

                    # raise error if file not found
                    if not file:
                        raise FileNotExistsError()

                    file_name = file.name
                file_names = [file.name for file in files]
                db_documents = (
                    db.session.query(Document)
                    .where(
                        Document.dataset_id == dataset.id,
                        Document.tenant_id == current_user.current_tenant_id,
                        Document.data_source_type == "upload_file",
                        Document.enabled == True,
                        Document.name.in_(file_names),
                    )
                    .all()
                )
                documents_map = {document.name: document for document in db_documents}
                for file in files:
                    data_source_info: dict[str, str | bool] = {
                        "upload_file_id": file_id,
                        "upload_file_id": file.id,
                    }
                    # check duplicate
                    if knowledge_config.duplicate:
                        document = (
                            db.session.query(Document)
                            .filter_by(
                                dataset_id=dataset.id,
                                tenant_id=current_user.current_tenant_id,
                                data_source_type="upload_file",
                                enabled=True,
                                name=file_name,
                            )
                            .first()
                    document = documents_map.get(file.name)
                    if knowledge_config.duplicate and document:
                        document.dataset_process_rule_id = dataset_process_rule.id
                        document.updated_at = naive_utc_now()
                        document.created_from = created_from
                        document.doc_form = knowledge_config.doc_form
                        document.doc_language = knowledge_config.doc_language
                        document.data_source_info = json.dumps(data_source_info)
                        document.batch = batch
                        document.indexing_status = "waiting"
                        db.session.add(document)
                        documents.append(document)
                        duplicate_document_ids.append(document.id)
                        continue
                    else:
                        document = DocumentService.build_document(
                            dataset,
                            dataset_process_rule.id,
                            knowledge_config.data_source.info_list.data_source_type,
                            knowledge_config.doc_form,
                            knowledge_config.doc_language,
                            data_source_info,
                            created_from,
                            position,
                            account,
                            file.name,
                            batch,
                        )
                        if document:
                            document.dataset_process_rule_id = dataset_process_rule.id
                            document.updated_at = naive_utc_now()
                            document.created_from = created_from
                            document.doc_form = knowledge_config.doc_form
                            document.doc_language = knowledge_config.doc_language
                            document.data_source_info = json.dumps(data_source_info)
                            document.batch = batch
                            document.indexing_status = "waiting"
                            db.session.add(document)
                            documents.append(document)
                            duplicate_document_ids.append(document.id)
                            continue
                    document = DocumentService.build_document(
                        dataset,
                        dataset_process_rule.id,
                        knowledge_config.data_source.info_list.data_source_type,
                        knowledge_config.doc_form,
                        knowledge_config.doc_language,
                        data_source_info,
                        created_from,
                        position,
                        account,
                        file_name,
                        batch,
                    )
                    db.session.add(document)
                    db.session.flush()
                    document_ids.append(document.id)
                    documents.append(document)
                    position += 1
                    db.session.add(document)
                    db.session.flush()
                    document_ids.append(document.id)
                    documents.append(document)
                    position += 1
            elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
                notion_info_list = knowledge_config.data_source.info_list.notion_info_list  # type: ignore
                if not notion_info_list:
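The rework above replaces a query-per-file loop with one `IN()` fetch of all `UploadFile` rows, a length check to detect missing files, and a prebuilt name-to-document map for the duplicate check. A toy, self-contained rendering of that shape (plain dicts stand in for the ORM rows):

```python
requested_ids = ["f1", "f2", "f2"]     # user-supplied, may contain repeats
rows = {"f1": "a.txt", "f2": "b.txt"}  # pretend result of one IN() query over UploadFile

files = [(fid, rows[fid]) for fid in set(requested_ids) if fid in rows]
if len(files) != len(set(requested_ids)):
    raise ValueError("One or more files not found.")  # FileNotExistsError in the service

# name -> existing document, built once, so duplicate checks are dict lookups
documents_map = {"b.txt": "existing-doc-id"}
for fid, name in sorted(files):
    duplicate = documents_map.get(name)
    print(fid, "update existing" if duplicate else "create new")
```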
@@ -2799,20 +2817,20 @@ class SegmentService:
                    db.session.add(binding)
                db.session.commit()

            # save vector index
            try:
                VectorService.create_segments_vector(
                    [args["keywords"]], [segment_document], dataset, document.doc_form
                )
            except Exception as e:
                logger.exception("create segment index failed")
                segment_document.enabled = False
                segment_document.disabled_at = naive_utc_now()
                segment_document.status = "error"
                segment_document.error = str(e)
            db.session.commit()
            segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_document.id).first()
            return segment
                # save vector index
                try:
                    keywords = args.get("keywords")
                    keywords_list = [keywords] if keywords is not None else None
                    VectorService.create_segments_vector(keywords_list, [segment_document], dataset, document.doc_form)
                except Exception as e:
                    logger.exception("create segment index failed")
                    segment_document.enabled = False
                    segment_document.disabled_at = naive_utc_now()
                    segment_document.status = "error"
                    segment_document.error = str(e)
                db.session.commit()
                segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_document.id).first()
                return segment
        except LockNotOwnedError:
            pass
@@ -29,8 +29,14 @@ def get_current_user():
    from models.account import Account
    from models.model import EndUser

    if not isinstance(current_user._get_current_object(), (Account, EndUser)):  # type: ignore
        raise TypeError(f"current_user must be Account or EndUser, got {type(current_user).__name__}")
    try:
        user_object = current_user._get_current_object()
    except AttributeError:
        # Handle case where current_user might not be a LocalProxy in test environments
        user_object = current_user

    if not isinstance(user_object, (Account, EndUser)):
        raise TypeError(f"current_user must be Account or EndUser, got {type(user_object).__name__}")
    return current_user
@@ -178,8 +178,8 @@ class HitTestingService:

    @classmethod
    def hit_testing_args_check(cls, args):
        query = args["query"]
        attachment_ids = args["attachment_ids"]
        query = args.get("query")
        attachment_ids = args.get("attachment_ids")

        if not attachment_ids and not query:
            raise ValueError("Query or attachment_ids is required")
@@ -70,9 +70,28 @@ class ModelProviderService:
                continue

            provider_config = provider_configuration.custom_configuration.provider
            model_config = provider_configuration.custom_configuration.models
            models = provider_configuration.custom_configuration.models
            can_added_models = provider_configuration.custom_configuration.can_added_models

            # IMPORTANT: Never expose decrypted credentials in the provider list API.
            # Sanitize custom model configurations by dropping the credentials payload.
            sanitized_model_config = []
            if models:
                from core.entities.provider_entities import CustomModelConfiguration  # local import to avoid cycles

                for model in models:
                    sanitized_model_config.append(
                        CustomModelConfiguration(
                            model=model.model,
                            model_type=model.model_type,
                            credentials=None,  # strip secrets from list view
                            current_credential_id=model.current_credential_id,
                            current_credential_name=model.current_credential_name,
                            available_model_credentials=model.available_model_credentials,
                            unadded_to_model_list=model.unadded_to_model_list,
                        )
                    )

            provider_response = ProviderResponse(
                tenant_id=tenant_id,
                provider=provider_configuration.provider.provider,

@@ -95,7 +114,7 @@ class ModelProviderService:
                    current_credential_id=getattr(provider_config, "current_credential_id", None),
                    current_credential_name=getattr(provider_config, "current_credential_name", None),
                    available_credentials=getattr(provider_config, "available_credentials", []),
                    custom_models=model_config,
                    custom_models=sanitized_model_config,
                    can_added_models=can_added_models,
                ),
                system_configuration=SystemConfigurationResponse(
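The sanitization idea is simply to rebuild the per-model config objects with `credentials=None` before they are serialized into the provider-list response. A minimal sketch of the same move (a dataclass stands in for the real `CustomModelConfiguration` entity):

```python
from dataclasses import dataclass, replace


@dataclass
class CustomModelConfiguration:  # illustrative stand-in for the real entity
    model: str
    model_type: str
    credentials: dict | None


models = [CustomModelConfiguration("gpt-x", "llm", {"api_key": "secret"})]
sanitized = [replace(m, credentials=None) for m in models]  # never echo secrets back
print(sanitized[0].credentials)  # None
```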
@@ -9,6 +9,7 @@ from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
from extensions.ext_database import db
from extensions.ext_storage import storage
from models import WorkflowType
from models.dataset import (
    AppDatasetJoin,
    Dataset,

@@ -18,9 +19,11 @@ from models.dataset import (
    DatasetQuery,
    Document,
    DocumentSegment,
    Pipeline,
    SegmentAttachmentBinding,
)
from models.model import UploadFile
from models.workflow import Workflow

logger = logging.getLogger(__name__)

@@ -34,6 +37,7 @@ def clean_dataset_task(
    index_struct: str,
    collection_binding_id: str,
    doc_form: str,
    pipeline_id: str | None = None,
):
    """
    Clean dataset when dataset deleted.

@@ -135,6 +139,14 @@ def clean_dataset_task(
        # delete dataset metadata
        db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset_id).delete()
        db.session.query(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset_id).delete()
        # delete pipeline and workflow
        if pipeline_id:
            db.session.query(Pipeline).where(Pipeline.id == pipeline_id).delete()
            db.session.query(Workflow).where(
                Workflow.tenant_id == tenant_id,
                Workflow.app_id == pipeline_id,
                Workflow.type == WorkflowType.RAG_PIPELINE,
            ).delete()
        # delete files
        if documents:
            for document in documents:
@@ -2,6 +2,7 @@ import logging

from celery import shared_task

from configs import dify_config
from extensions.ext_database import db
from models import Account
from services.billing_service import BillingService

@@ -14,7 +15,8 @@ logger = logging.getLogger(__name__)
def delete_account_task(account_id):
    account = db.session.query(Account).where(Account.id == account_id).first()
    try:
        BillingService.delete_account(account_id)
        if dify_config.BILLING_ENABLED:
            BillingService.delete_account(account_id)
    except Exception:
        logger.exception("Failed to delete account %s from billing service.", account_id)
        raise
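The fix gates the remote billing call on the billing flag while keeping the try/except so a billing failure still aborts the deletion. A small sketch of that control flow (function names are illustrative):

```python
def delete_account(account_id: str, billing_enabled: bool, billing_delete=lambda _id: None) -> None:
    try:
        if billing_enabled:
            billing_delete(account_id)  # remote call; may raise
    except Exception:
        print(f"Failed to delete account {account_id} from billing service.")
        raise  # abort deletion rather than leaving billing out of sync


delete_account("acc-1", billing_enabled=False)  # no remote call in self-hosted setups
```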
@@ -0,0 +1,127 @@
app:
  description: 'End node without value_type field reproduction'
  icon: 🤖
  icon_background: '#FFEAD5'
  mode: workflow
  name: end_node_without_value_type_field_reproduction
  use_icon_as_answer_icon: false
dependencies: []
kind: app
version: 0.5.0
workflow:
  conversation_variables: []
  environment_variables: []
  features:
    file_upload:
      allowed_file_extensions:
      - .JPG
      - .JPEG
      - .PNG
      - .GIF
      - .WEBP
      - .SVG
      allowed_file_types:
      - image
      allowed_file_upload_methods:
      - local_file
      - remote_url
      enabled: false
      fileUploadConfig:
        audio_file_size_limit: 50
        batch_count_limit: 5
        file_size_limit: 15
        image_file_batch_limit: 10
        image_file_size_limit: 10
        single_chunk_attachment_limit: 10
        video_file_size_limit: 100
        workflow_file_upload_limit: 10
      image:
        enabled: false
        number_limits: 3
        transfer_methods:
        - local_file
        - remote_url
      number_limits: 3
    opening_statement: ''
    retriever_resource:
      enabled: true
    sensitive_word_avoidance:
      enabled: false
    speech_to_text:
      enabled: false
    suggested_questions: []
    suggested_questions_after_answer:
      enabled: false
    text_to_speech:
      enabled: false
      language: ''
      voice: ''
  graph:
    edges:
    - data:
        isInIteration: false
        isInLoop: false
        sourceType: start
        targetType: end
      id: 1765423445456-source-1765423454810-target
      source: '1765423445456'
      sourceHandle: source
      target: '1765423454810'
      targetHandle: target
      type: custom
      zIndex: 0
    nodes:
    - data:
        selected: false
        title: User Input
        type: start
        variables:
        - default: ''
          hint: ''
          label: query
          max_length: 48
          options: []
          placeholder: ''
          required: true
          type: text-input
          variable: query
      height: 109
      id: '1765423445456'
      position:
        x: -48
        y: 261
      positionAbsolute:
        x: -48
        y: 261
      selected: false
      sourcePosition: right
      targetPosition: left
      type: custom
      width: 242
    - data:
        outputs:
        - value_selector:
          - '1765423445456'
          - query
          variable: query
        selected: true
        title: Output
        type: end
      height: 88
      id: '1765423454810'
      position:
        x: 382
        y: 282
      positionAbsolute:
        x: 382
        y: 282
      selected: true
      sourcePosition: right
      targetPosition: left
      type: custom
      width: 242
    viewport:
      x: 139
      y: -135
      zoom: 1
  rag_pipeline_variables: []
@@ -55,7 +55,7 @@ WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*

# Vector database configuration
# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase
# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase, iris
VECTOR_STORE=weaviate
# Weaviate configuration
WEAVIATE_ENDPOINT=http://localhost:8080

@@ -64,6 +64,20 @@ WEAVIATE_GRPC_ENABLED=false
WEAVIATE_BATCH_SIZE=100
WEAVIATE_TOKENIZATION=word

# InterSystems IRIS configuration
IRIS_HOST=localhost
IRIS_SUPER_SERVER_PORT=1972
IRIS_WEB_SERVER_PORT=52773
IRIS_USER=_SYSTEM
IRIS_PASSWORD=Dify@1234
IRIS_DATABASE=USER
IRIS_SCHEMA=dify
IRIS_CONNECTION_URL=
IRIS_MIN_CONNECTION=1
IRIS_MAX_CONNECTION=3
IRIS_TEXT_INDEX=true
IRIS_TEXT_INDEX_LANGUAGE=en


# Upload configuration
UPLOAD_FILE_SIZE_LIMIT=15
@@ -1,3 +1,4 @@
import os
import pathlib
import random
import secrets

@@ -32,6 +33,10 @@ def _load_env():


_load_env()
# Override storage root to tmp to avoid polluting repo during local runs
os.environ["OPENDAL_FS_ROOT"] = "/tmp/dify-storage"
os.environ.setdefault("STORAGE_TYPE", "opendal")
os.environ.setdefault("OPENDAL_SCHEME", "fs")

_CACHED_APP = create_app()
@@ -0,0 +1,44 @@
"""Integration tests for IRIS vector database."""

from core.rag.datasource.vdb.iris.iris_vector import IrisVector, IrisVectorConfig
from tests.integration_tests.vdb.test_vector_store import (
    AbstractVectorTest,
    setup_mock_redis,
)


class IrisVectorTest(AbstractVectorTest):
    """Test suite for IRIS vector store implementation."""

    def __init__(self):
        """Initialize IRIS vector test with hardcoded test configuration.

        Note: Uses 'host.docker.internal' to connect from DevContainer to
        host OS Docker, or 'localhost' when running directly on host OS.
        """
        super().__init__()
        self.vector = IrisVector(
            collection_name=self.collection_name,
            config=IrisVectorConfig(
                IRIS_HOST="host.docker.internal",
                IRIS_SUPER_SERVER_PORT=1972,
                IRIS_USER="_SYSTEM",
                IRIS_PASSWORD="Dify@1234",
                IRIS_DATABASE="USER",
                IRIS_SCHEMA="dify",
                IRIS_CONNECTION_URL=None,
                IRIS_MIN_CONNECTION=1,
                IRIS_MAX_CONNECTION=3,
                IRIS_TEXT_INDEX=True,
                IRIS_TEXT_INDEX_LANGUAGE="en",
            ),
        )


def test_iris_vector(setup_mock_redis) -> None:
    """Run all IRIS vector store tests.

    Args:
        setup_mock_redis: Pytest fixture for mock Redis setup
    """
    IrisVectorTest().run_all_tests()
@@ -138,9 +138,9 @@ class DifyTestContainers:
            logger.warning("Failed to create plugin database: %s", e)

        # Set up storage environment variables
        os.environ["STORAGE_TYPE"] = "opendal"
        os.environ["OPENDAL_SCHEME"] = "fs"
        os.environ["OPENDAL_FS_ROOT"] = "storage"
        os.environ.setdefault("STORAGE_TYPE", "opendal")
        os.environ.setdefault("OPENDAL_SCHEME", "fs")
        os.environ.setdefault("OPENDAL_FS_ROOT", "/tmp/dify-storage")

        # Start Redis container for caching and session management
        # Redis is used for storing session data, cache entries, and temporary data

@@ -348,6 +348,13 @@ def _create_app_with_containers() -> Flask:
    """
    logger.info("Creating Flask application with test container configuration...")

    # Ensure Redis client reconnects to the containerized Redis (no auth)
    from extensions import ext_redis

    ext_redis.redis_client._client = None
    os.environ["REDIS_USERNAME"] = ""
    os.environ["REDIS_PASSWORD"] = ""

    # Re-create the config after environment variables have been set
    from configs import dify_config

@@ -486,3 +493,29 @@ def db_session_with_containers(flask_app_with_containers) -> Generator[Session,
    finally:
        session.close()
        logger.debug("Database session closed")


@pytest.fixture(scope="package", autouse=True)
def mock_ssrf_proxy_requests():
    """
    Avoid outbound network during containerized tests by stubbing SSRF proxy helpers.
    """

    from unittest.mock import patch

    import httpx

    def _fake_request(method, url, **kwargs):
        request = httpx.Request(method=method, url=url)
        return httpx.Response(200, request=request, content=b"")

    with (
        patch("core.helper.ssrf_proxy.make_request", side_effect=_fake_request),
        patch("core.helper.ssrf_proxy.get", side_effect=lambda url, **kw: _fake_request("GET", url, **kw)),
        patch("core.helper.ssrf_proxy.post", side_effect=lambda url, **kw: _fake_request("POST", url, **kw)),
        patch("core.helper.ssrf_proxy.put", side_effect=lambda url, **kw: _fake_request("PUT", url, **kw)),
        patch("core.helper.ssrf_proxy.patch", side_effect=lambda url, **kw: _fake_request("PATCH", url, **kw)),
        patch("core.helper.ssrf_proxy.delete", side_effect=lambda url, **kw: _fake_request("DELETE", url, **kw)),
        patch("core.helper.ssrf_proxy.head", side_effect=lambda url, **kw: _fake_request("HEAD", url, **kw)),
    ):
        yield
@@ -240,8 +240,7 @@ class TestShardedRedisBroadcastChannelIntegration:
        for future in as_completed(producer_futures, timeout=30.0):
            sent_msgs.update(future.result())

        subscription.close()
        consumer_received_msgs = consumer_future.result(timeout=30.0)
        consumer_received_msgs = consumer_future.result(timeout=60.0)

        assert sent_msgs == consumer_received_msgs
@@ -0,0 +1 @@

@@ -0,0 +1,182 @@
"""
Fixtures for trigger integration tests.

This module provides fixtures for creating test data (tenant, account, app)
and mock objects used across trigger-related tests.
"""

from __future__ import annotations

from collections.abc import Generator
from typing import Any

import pytest
from sqlalchemy.orm import Session

from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.model import App


@pytest.fixture
def tenant_and_account(db_session_with_containers: Session) -> Generator[tuple[Tenant, Account], None, None]:
    """
    Create a tenant and account for testing.

    This fixture creates a tenant, account, and their association,
    then cleans up after the test completes.

    Yields:
        tuple[Tenant, Account]: The created tenant and account
    """
    tenant = Tenant(name="trigger-e2e")
    account = Account(name="tester", email="tester@example.com", interface_language="en-US")
    db_session_with_containers.add_all([tenant, account])
    db_session_with_containers.commit()

    join = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role=TenantAccountRole.OWNER.value)
    db_session_with_containers.add(join)
    db_session_with_containers.commit()

    yield tenant, account

    # Cleanup
    db_session_with_containers.query(TenantAccountJoin).filter_by(tenant_id=tenant.id).delete()
    db_session_with_containers.query(Account).filter_by(id=account.id).delete()
    db_session_with_containers.query(Tenant).filter_by(id=tenant.id).delete()
    db_session_with_containers.commit()


@pytest.fixture
def app_model(
    db_session_with_containers: Session, tenant_and_account: tuple[Tenant, Account]
) -> Generator[App, None, None]:
    """
    Create an app for testing.

    This fixture creates a workflow app associated with the tenant and account,
    then cleans up after the test completes.

    Yields:
        App: The created app
    """
    tenant, account = tenant_and_account
    app = App(
        tenant_id=tenant.id,
        name="trigger-app",
        description="trigger e2e",
        mode="workflow",
        icon_type="emoji",
        icon="robot",
        icon_background="#FFEAD5",
        enable_site=True,
        enable_api=True,
        api_rpm=100,
        api_rph=1000,
        is_demo=False,
        is_public=False,
        is_universal=False,
        created_by=account.id,
    )
    db_session_with_containers.add(app)
    db_session_with_containers.commit()

    yield app

    # Cleanup - delete related records first
    from models.trigger import (
        AppTrigger,
        TriggerSubscription,
        WorkflowPluginTrigger,
        WorkflowSchedulePlan,
        WorkflowTriggerLog,
        WorkflowWebhookTrigger,
    )
    from models.workflow import Workflow

    db_session_with_containers.query(WorkflowTriggerLog).filter_by(app_id=app.id).delete()
    db_session_with_containers.query(WorkflowSchedulePlan).filter_by(app_id=app.id).delete()
    db_session_with_containers.query(WorkflowWebhookTrigger).filter_by(app_id=app.id).delete()
    db_session_with_containers.query(WorkflowPluginTrigger).filter_by(app_id=app.id).delete()
    db_session_with_containers.query(AppTrigger).filter_by(app_id=app.id).delete()
    db_session_with_containers.query(TriggerSubscription).filter_by(tenant_id=tenant.id).delete()
    db_session_with_containers.query(Workflow).filter_by(app_id=app.id).delete()
    db_session_with_containers.query(App).filter_by(id=app.id).delete()
    db_session_with_containers.commit()


class MockCeleryGroup:
    """Mock for celery group() function that collects dispatched tasks."""

    def __init__(self) -> None:
        self.collected: list[dict[str, Any]] = []
        self._applied = False

    def __call__(self, items: Any) -> MockCeleryGroup:
        self.collected = list(items)
        return self

    def apply_async(self) -> None:
        self._applied = True

    @property
    def applied(self) -> bool:
        return self._applied


class MockCelerySignature:
    """Mock for celery task signature that returns task info dict."""

    def s(self, schedule_id: str) -> dict[str, str]:
        return {"schedule_id": schedule_id}


@pytest.fixture
def mock_celery_group() -> MockCeleryGroup:
    """
    Provide a mock celery group for testing task dispatch.

    Returns:
        MockCeleryGroup: Mock group that collects dispatched tasks
    """
    return MockCeleryGroup()


@pytest.fixture
def mock_celery_signature() -> MockCelerySignature:
    """
    Provide a mock celery signature for testing task dispatch.

    Returns:
        MockCelerySignature: Mock signature generator
    """
    return MockCelerySignature()


class MockPluginSubscription:
    """Mock plugin subscription for testing plugin triggers."""

    def __init__(
        self,
        subscription_id: str = "sub-1",
        tenant_id: str = "tenant-1",
        provider_id: str = "provider-1",
    ) -> None:
        self.id = subscription_id
        self.tenant_id = tenant_id
        self.provider_id = provider_id
        self.credentials: dict[str, str] = {"token": "secret"}
        self.credential_type = "api-key"

    def to_entity(self) -> MockPluginSubscription:
        return self


@pytest.fixture
def mock_plugin_subscription() -> MockPluginSubscription:
    """
    Provide a mock plugin subscription for testing.

    Returns:
        MockPluginSubscription: Mock subscription instance
    """
    return MockPluginSubscription()
@ -0,0 +1,911 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import time
|
||||
from datetime import timedelta
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from flask import Flask, Response
|
||||
from flask.testing import FlaskClient
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.plugin.entities.request import TriggerInvokeEventResponse
|
||||
from core.trigger.debug import event_selectors
|
||||
from core.trigger.debug.event_bus import TriggerDebugEventBus
|
||||
from core.trigger.debug.event_selectors import PluginTriggerDebugEventPoller, WebhookTriggerDebugEventPoller
|
||||
from core.trigger.debug.events import PluginTriggerDebugEvent, build_plugin_pool_key
|
||||
from core.workflow.enums import NodeType
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.account import Account, Tenant
|
||||
from models.enums import AppTriggerStatus, AppTriggerType, CreatorUserRole, WorkflowTriggerStatus
|
||||
from models.model import App
|
||||
from models.trigger import (
|
||||
AppTrigger,
|
||||
TriggerSubscription,
|
||||
WorkflowPluginTrigger,
|
||||
WorkflowSchedulePlan,
|
||||
WorkflowTriggerLog,
|
||||
WorkflowWebhookTrigger,
|
||||
)
|
||||
from models.workflow import Workflow
|
||||
from schedule import workflow_schedule_task
|
||||
from schedule.workflow_schedule_task import poll_workflow_schedules
|
||||
from services import feature_service as feature_service_module
|
||||
from services.trigger import webhook_service
|
||||
from services.trigger.schedule_service import ScheduleService
|
||||
from services.workflow_service import WorkflowService
|
||||
from tasks import trigger_processing_tasks
|
||||
|
||||
from .conftest import MockCeleryGroup, MockCelerySignature, MockPluginSubscription
|
||||
|
||||
# Test constants
|
||||
WEBHOOK_ID_PRODUCTION = "wh1234567890123456789012"
|
||||
WEBHOOK_ID_DEBUG = "whdebug1234567890123456"
|
||||
TEST_TRIGGER_URL = "https://trigger.example.com/base"
|
||||
|
||||
|
||||
def _build_workflow_graph(root_node_id: str, trigger_type: NodeType) -> str:
|
||||
"""Build a minimal workflow graph JSON for testing."""
|
||||
node_data: dict[str, Any] = {"type": trigger_type.value, "title": "trigger"}
|
||||
if trigger_type == NodeType.TRIGGER_WEBHOOK:
|
||||
node_data.update(
|
||||
{
|
||||
"method": "POST",
|
||||
"content_type": "application/json",
|
||||
"headers": [],
|
||||
"params": [],
|
||||
"body": [],
|
||||
}
|
||||
)
|
||||
graph = {
|
||||
"nodes": [
|
||||
{"id": root_node_id, "data": node_data},
|
||||
{"id": "answer-1", "data": {"type": NodeType.ANSWER.value, "title": "answer"}},
|
||||
],
|
||||
"edges": [{"source": root_node_id, "target": "answer-1", "sourceHandle": "success"}],
|
||||
}
|
||||
return json.dumps(graph)
|
||||
|
||||
|
||||
def test_publish_blocks_start_and_trigger_coexistence(
|
||||
db_session_with_containers: Session,
|
||||
tenant_and_account: tuple[Tenant, Account],
|
||||
app_model: App,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Publishing should fail when both start and trigger nodes coexist."""
|
||||
tenant, account = tenant_and_account
|
||||
|
||||
graph = {
|
||||
"nodes": [
|
||||
{"id": "start", "data": {"type": NodeType.START.value}},
|
||||
{"id": "trig", "data": {"type": NodeType.TRIGGER_WEBHOOK.value}},
|
||||
],
|
||||
"edges": [],
|
||||
}
|
||||
draft_workflow = Workflow.new(
|
||||
tenant_id=tenant.id,
|
||||
app_id=app_model.id,
|
||||
type="workflow",
|
||||
version=Workflow.VERSION_DRAFT,
|
||||
graph=json.dumps(graph),
|
||||
features=json.dumps({}),
|
||||
created_by=account.id,
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
rag_pipeline_variables=[],
|
||||
)
|
||||
db_session_with_containers.add(draft_workflow)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
monkeypatch.setattr(
|
||||
feature_service_module.FeatureService,
|
||||
"get_system_features",
|
||||
classmethod(lambda _cls: SimpleNamespace(plugin_manager=SimpleNamespace(enabled=False))),
|
||||
)
|
||||
monkeypatch.setattr("services.workflow_service.dify_config", SimpleNamespace(BILLING_ENABLED=False))
|
||||
|
||||
with pytest.raises(ValueError, match="Start node and trigger nodes cannot coexist"):
|
||||
workflow_service.publish_workflow(session=db_session_with_containers, app_model=app_model, account=account)
|
||||
|
||||
|
||||
def test_trigger_url_uses_config_base(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""TRIGGER_URL config should be reflected in generated webhook and plugin endpoints."""
|
||||
original_url = getattr(dify_config, "TRIGGER_URL", None)
|
||||
|
||||
try:
|
||||
monkeypatch.setattr(dify_config, "TRIGGER_URL", TEST_TRIGGER_URL)
|
||||
endpoint_module = importlib.reload(importlib.import_module("core.trigger.utils.endpoint"))
|
||||
|
||||
assert (
|
||||
endpoint_module.generate_webhook_trigger_endpoint(WEBHOOK_ID_PRODUCTION)
|
||||
== f"{TEST_TRIGGER_URL}/triggers/webhook/{WEBHOOK_ID_PRODUCTION}"
|
||||
)
|
||||
assert (
|
||||
endpoint_module.generate_webhook_trigger_endpoint(WEBHOOK_ID_PRODUCTION, True)
|
||||
== f"{TEST_TRIGGER_URL}/triggers/webhook-debug/{WEBHOOK_ID_PRODUCTION}"
|
||||
)
|
||||
assert (
|
||||
endpoint_module.generate_plugin_trigger_endpoint_url("end-1") == f"{TEST_TRIGGER_URL}/triggers/plugin/end-1"
|
||||
)
|
||||
finally:
|
||||
# Restore original config and reload module
|
||||
if original_url is not None:
|
||||
monkeypatch.setattr(dify_config, "TRIGGER_URL", original_url)
|
||||
importlib.reload(importlib.import_module("core.trigger.utils.endpoint"))
|
||||
|
||||
|
||||
def test_webhook_trigger_creates_trigger_log(
|
||||
test_client_with_containers: FlaskClient,
|
||||
db_session_with_containers: Session,
|
||||
tenant_and_account: tuple[Tenant, Account],
|
||||
app_model: App,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Production webhook trigger should create a trigger log in the database."""
|
||||
tenant, account = tenant_and_account
|
||||
|
||||
webhook_node_id = "webhook-node"
|
||||
graph_json = _build_workflow_graph(webhook_node_id, NodeType.TRIGGER_WEBHOOK)
|
||||
published_workflow = Workflow.new(
|
||||
tenant_id=tenant.id,
|
||||
app_id=app_model.id,
|
||||
type="workflow",
|
||||
version=Workflow.version_from_datetime(naive_utc_now()),
|
||||
graph=graph_json,
|
||||
features=json.dumps({}),
|
||||
created_by=account.id,
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
rag_pipeline_variables=[],
|
||||
)
|
||||
db_session_with_containers.add(published_workflow)
|
||||
app_model.workflow_id = published_workflow.id
|
||||
db_session_with_containers.commit()
|
||||
|
||||
webhook_trigger = WorkflowWebhookTrigger(
|
||||
app_id=app_model.id,
|
||||
node_id=webhook_node_id,
|
||||
tenant_id=tenant.id,
|
||||
webhook_id=WEBHOOK_ID_PRODUCTION,
|
||||
created_by=account.id,
|
||||
)
|
||||
app_trigger = AppTrigger(
|
||||
tenant_id=tenant.id,
|
||||
app_id=app_model.id,
|
||||
node_id=webhook_node_id,
|
||||
trigger_type=AppTriggerType.TRIGGER_WEBHOOK,
|
||||
status=AppTriggerStatus.ENABLED,
|
||||
title="webhook",
|
||||
)
|
||||
|
||||
db_session_with_containers.add_all([webhook_trigger, app_trigger])
|
||||
db_session_with_containers.commit()
|
||||
|
||||
def _fake_trigger_workflow_async(session: Session, user: Any, trigger_data: Any) -> SimpleNamespace:
|
||||
log = WorkflowTriggerLog(
|
||||
tenant_id=trigger_data.tenant_id,
|
||||
app_id=trigger_data.app_id,
|
||||
workflow_id=trigger_data.workflow_id,
|
||||
root_node_id=trigger_data.root_node_id,
|
||||
trigger_metadata=trigger_data.trigger_metadata.model_dump_json() if trigger_data.trigger_metadata else "{}",
|
||||
trigger_type=trigger_data.trigger_type,
|
||||
workflow_run_id=None,
|
||||
outputs=None,
|
||||
trigger_data=trigger_data.model_dump_json(),
|
||||
inputs=json.dumps(dict(trigger_data.inputs)),
|
||||
status=WorkflowTriggerStatus.SUCCEEDED,
|
||||
error="",
|
||||
queue_name="triggered_workflow_dispatcher",
|
||||
celery_task_id="celery-test",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
)
|
||||
session.add(log)
|
||||
session.commit()
|
||||
return SimpleNamespace(workflow_trigger_log_id=log.id, task_id=None, status="queued", queue="test")
|
||||
|
||||
monkeypatch.setattr(
|
||||
webhook_service.AsyncWorkflowService,
|
||||
"trigger_workflow_async",
|
||||
_fake_trigger_workflow_async,
|
||||
)
|
||||
|
||||
response = test_client_with_containers.post(f"/triggers/webhook/{webhook_trigger.webhook_id}", json={"foo": "bar"})
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
db_session_with_containers.expire_all()
|
||||
logs = db_session_with_containers.query(WorkflowTriggerLog).filter_by(app_id=app_model.id).all()
|
||||
assert logs, "Webhook trigger should create trigger log"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("schedule_type", ["visual", "cron"])
|
||||
def test_schedule_poll_dispatches_due_plan(
|
||||
db_session_with_containers: Session,
|
||||
tenant_and_account: tuple[Tenant, Account],
|
||||
app_model: App,
|
||||
mock_celery_group: MockCeleryGroup,
|
||||
mock_celery_signature: MockCelerySignature,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
schedule_type: str,
|
||||
) -> None:
|
||||
"""Schedule plans (both visual and cron) should be polled and dispatched when due."""
|
||||
tenant, _ = tenant_and_account
|
||||
|
||||
app_trigger = AppTrigger(
|
||||
tenant_id=tenant.id,
|
||||
app_id=app_model.id,
|
||||
node_id=f"schedule-{schedule_type}",
|
||||
trigger_type=AppTriggerType.TRIGGER_SCHEDULE,
|
||||
status=AppTriggerStatus.ENABLED,
|
||||
title=f"schedule-{schedule_type}",
|
||||
)
|
||||
plan = WorkflowSchedulePlan(
|
||||
app_id=app_model.id,
|
||||
node_id=f"schedule-{schedule_type}",
|
||||
tenant_id=tenant.id,
|
||||
cron_expression="* * * * *",
|
||||
timezone="UTC",
|
||||
next_run_at=naive_utc_now() - timedelta(minutes=1),
|
||||
)
|
||||
db_session_with_containers.add_all([app_trigger, plan])
|
||||
db_session_with_containers.commit()
|
||||
|
||||
next_time = naive_utc_now() + timedelta(hours=1)
|
||||
monkeypatch.setattr(workflow_schedule_task, "calculate_next_run_at", lambda *_args, **_kwargs: next_time)
|
||||
monkeypatch.setattr(workflow_schedule_task, "group", mock_celery_group)
|
||||
monkeypatch.setattr(workflow_schedule_task, "run_schedule_trigger", mock_celery_signature)
|
||||
|
||||
poll_workflow_schedules()
|
||||
|
||||
assert mock_celery_group.collected, f"Should dispatch signatures for due {schedule_type} schedules"
|
||||
scheduled_ids = {sig["schedule_id"] for sig in mock_celery_group.collected}
|
||||
assert plan.id in scheduled_ids
|
||||
|
||||
|
||||
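
# The `poll_workflow_schedules` task itself is not shown in this diff. A minimal sketch of
# the selection it implies, under assumed names: pick plans whose next_run_at has passed
# (the setup above suggests an ENABLED AppTrigger is also required), then advance
# next_run_at via calculate_next_run_at so each plan is dispatched at most once per window.
def _due_plans_sketch(session: Session, now: Any) -> list[WorkflowSchedulePlan]:
    """Illustrative only: a plan is due for dispatch when its next_run_at is in the past."""
    return list(session.query(WorkflowSchedulePlan).filter(WorkflowSchedulePlan.next_run_at <= now))
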
def test_schedule_visual_debug_poll_generates_event(monkeypatch: pytest.MonkeyPatch) -> None:
    """Visual mode schedule node should generate an event in single-step debug."""
    base_now = naive_utc_now()
    monkeypatch.setattr(event_selectors, "naive_utc_now", lambda: base_now)
    monkeypatch.setattr(
        event_selectors,
        "calculate_next_run_at",
        lambda *_args, **_kwargs: base_now - timedelta(minutes=1),
    )
    node_config = {
        "id": "schedule-visual",
        "data": {
            "type": NodeType.TRIGGER_SCHEDULE.value,
            "mode": "visual",
            "frequency": "daily",
            "visual_config": {"time": "3:00 PM"},
            "timezone": "UTC",
        },
    }
    poller = event_selectors.ScheduleTriggerDebugEventPoller(
        tenant_id="tenant",
        user_id="user",
        app_id="app",
        node_config=node_config,
        node_id="schedule-visual",
    )
    event = poller.poll()
    assert event is not None
    assert event.workflow_args["inputs"] == {}

def test_plugin_trigger_dispatches_and_debug_events(
    test_client_with_containers: FlaskClient,
    mock_plugin_subscription: MockPluginSubscription,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Plugin trigger endpoint should dispatch events and generate debug events."""
    endpoint_id = "1cc7fa12-3f7b-4f6a-9c8d-1234567890ab"

    debug_events: list[dict[str, Any]] = []
    dispatched_payloads: list[dict[str, Any]] = []

    def _fake_process_endpoint(_endpoint_id: str, _request: Any) -> Response:
        dispatch_data = {
            "user_id": "end-user",
            "tenant_id": mock_plugin_subscription.tenant_id,
            "endpoint_id": _endpoint_id,
            "provider_id": mock_plugin_subscription.provider_id,
            "subscription_id": mock_plugin_subscription.id,
            "timestamp": int(time.time()),
            "events": ["created", "updated"],
            "request_id": f"req-{_endpoint_id}",
        }
        trigger_processing_tasks.dispatch_triggered_workflows_async.delay(dispatch_data)
        return Response("ok", status=202)

    monkeypatch.setattr(
        "services.trigger.trigger_service.TriggerService.process_endpoint",
        staticmethod(_fake_process_endpoint),
    )

    monkeypatch.setattr(
        trigger_processing_tasks.TriggerDebugEventBus,
        "dispatch",
        staticmethod(lambda **kwargs: debug_events.append(kwargs) or 1),
    )

    def _fake_delay(dispatch_data: dict[str, Any]) -> None:
        dispatched_payloads.append(dispatch_data)
        trigger_processing_tasks.dispatch_trigger_debug_event(
            events=dispatch_data["events"],
            user_id=dispatch_data["user_id"],
            timestamp=dispatch_data["timestamp"],
            request_id=dispatch_data["request_id"],
            subscription=mock_plugin_subscription,
        )

    monkeypatch.setattr(
        trigger_processing_tasks.dispatch_triggered_workflows_async,
        "delay",
        staticmethod(_fake_delay),
    )

    response = test_client_with_containers.post(f"/triggers/plugin/{endpoint_id}", json={"hello": "world"})

    assert response.status_code == 202
    assert dispatched_payloads, "Plugin trigger should enqueue workflow dispatch payload"
    assert debug_events, "Plugin trigger should dispatch debug events"
    dispatched_event_names = {event["event"].name for event in debug_events}
    assert dispatched_event_names == {"created", "updated"}

def test_webhook_debug_dispatches_event(
    test_client_with_containers: FlaskClient,
    db_session_with_containers: Session,
    tenant_and_account: tuple[Tenant, Account],
    app_model: App,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Webhook single-step debug should dispatch a debug event and be pollable."""
    tenant, account = tenant_and_account
    webhook_node_id = "webhook-debug-node"
    graph_json = _build_workflow_graph(webhook_node_id, NodeType.TRIGGER_WEBHOOK)
    draft_workflow = Workflow.new(
        tenant_id=tenant.id,
        app_id=app_model.id,
        type="workflow",
        version=Workflow.VERSION_DRAFT,
        graph=graph_json,
        features=json.dumps({}),
        created_by=account.id,
        environment_variables=[],
        conversation_variables=[],
        rag_pipeline_variables=[],
    )
    db_session_with_containers.add(draft_workflow)
    db_session_with_containers.commit()

    webhook_trigger = WorkflowWebhookTrigger(
        app_id=app_model.id,
        node_id=webhook_node_id,
        tenant_id=tenant.id,
        webhook_id=WEBHOOK_ID_DEBUG,
        created_by=account.id,
    )
    db_session_with_containers.add(webhook_trigger)
    db_session_with_containers.commit()

    debug_events: list[dict[str, Any]] = []
    original_dispatch = TriggerDebugEventBus.dispatch
    monkeypatch.setattr(
        "controllers.trigger.webhook.TriggerDebugEventBus.dispatch",
        lambda **kwargs: (debug_events.append(kwargs), original_dispatch(**kwargs))[1],
    )

    # Listener polls first to enter the waiting pool
    poller = WebhookTriggerDebugEventPoller(
        tenant_id=tenant.id,
        user_id=account.id,
        app_id=app_model.id,
        node_config=draft_workflow.get_node_config_by_id(webhook_node_id),
        node_id=webhook_node_id,
    )
    assert poller.poll() is None

    response = test_client_with_containers.post(
        f"/triggers/webhook-debug/{webhook_trigger.webhook_id}",
        json={"foo": "bar"},
        headers={"Content-Type": "application/json"},
    )

    assert response.status_code == 200
    assert debug_events, "Debug event should be sent to event bus"
    # Second poll should get the event
    event = poller.poll()
    assert event is not None
    assert event.workflow_args["inputs"]["webhook_body"]["foo"] == "bar"
    assert debug_events[0]["pool_key"].endswith(f":{app_model.id}:{webhook_node_id}")

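
# Note on the final assertion above: it implies webhook debug events are routed through a
# waiting pool keyed roughly as "<prefix>:<tenant_id>:<app_id>:<node_id>". Only the trailing
# ":{app_id}:{node_id}" segment is actually verified here; the full key shape is an assumption.
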
def test_plugin_single_step_debug_flow(
    flask_app_with_containers: Flask,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Plugin single-step debug: listen -> dispatch event -> poller receives and returns variables."""
    tenant_id = "tenant-1"
    app_id = "app-1"
    user_id = "user-1"
    node_id = "plugin-node"
    provider_id = "langgenius/provider-1/provider-1"
    node_config = {
        "id": node_id,
        "data": {
            "type": NodeType.TRIGGER_PLUGIN.value,
            "title": "plugin",
            "plugin_id": "plugin-1",
            "plugin_unique_identifier": "plugin-1",
            "provider_id": provider_id,
            "event_name": "created",
            "subscription_id": "sub-1",
            "parameters": {},
        },
    }
    # Start listening
    poller = PluginTriggerDebugEventPoller(
        tenant_id=tenant_id,
        user_id=user_id,
        app_id=app_id,
        node_config=node_config,
        node_id=node_id,
    )
    assert poller.poll() is None

    from core.trigger.debug.events import build_plugin_pool_key

    pool_key = build_plugin_pool_key(
        tenant_id=tenant_id,
        provider_id=provider_id,
        subscription_id="sub-1",
        name="created",
    )
    TriggerDebugEventBus.dispatch(
        tenant_id=tenant_id,
        event=PluginTriggerDebugEvent(
            timestamp=int(time.time()),
            user_id=user_id,
            name="created",
            request_id="req-1",
            subscription_id="sub-1",
            provider_id="provider-1",
        ),
        pool_key=pool_key,
    )

    from core.plugin.entities.request import TriggerInvokeEventResponse

    monkeypatch.setattr(
        "services.trigger.trigger_service.TriggerService.invoke_trigger_event",
        staticmethod(
            lambda **_kwargs: TriggerInvokeEventResponse(
                variables={"echo": "pong"},
                cancelled=False,
            )
        ),
    )

    event = poller.poll()
    assert event is not None
    assert event.workflow_args["inputs"]["echo"] == "pong"

def test_schedule_trigger_creates_trigger_log(
    db_session_with_containers: Session,
    tenant_and_account: tuple[Tenant, Account],
    app_model: App,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Schedule trigger execution should create a WorkflowTriggerLog in the database."""
    from tasks import workflow_schedule_tasks

    tenant, account = tenant_and_account

    # Create a published workflow with a schedule trigger node
    schedule_node_id = "schedule-node"
    graph = {
        "nodes": [
            {
                "id": schedule_node_id,
                "data": {
                    "type": NodeType.TRIGGER_SCHEDULE.value,
                    "title": "schedule",
                    "mode": "cron",
                    "cron_expression": "0 9 * * *",
                    "timezone": "UTC",
                },
            },
            {"id": "answer-1", "data": {"type": NodeType.ANSWER.value, "title": "answer"}},
        ],
        "edges": [{"source": schedule_node_id, "target": "answer-1", "sourceHandle": "success"}],
    }
    published_workflow = Workflow.new(
        tenant_id=tenant.id,
        app_id=app_model.id,
        type="workflow",
        version=Workflow.version_from_datetime(naive_utc_now()),
        graph=json.dumps(graph),
        features=json.dumps({}),
        created_by=account.id,
        environment_variables=[],
        conversation_variables=[],
        rag_pipeline_variables=[],
    )
    db_session_with_containers.add(published_workflow)
    app_model.workflow_id = published_workflow.id
    db_session_with_containers.commit()

    # Create a schedule plan
    plan = WorkflowSchedulePlan(
        app_id=app_model.id,
        node_id=schedule_node_id,
        tenant_id=tenant.id,
        cron_expression="0 9 * * *",
        timezone="UTC",
        next_run_at=naive_utc_now() - timedelta(minutes=1),
    )
    app_trigger = AppTrigger(
        tenant_id=tenant.id,
        app_id=app_model.id,
        node_id=schedule_node_id,
        trigger_type=AppTriggerType.TRIGGER_SCHEDULE,
        status=AppTriggerStatus.ENABLED,
        title="schedule",
    )
    db_session_with_containers.add_all([plan, app_trigger])
    db_session_with_containers.commit()

    # Mock AsyncWorkflowService to create a WorkflowTriggerLog
    def _fake_trigger_workflow_async(session: Session, user: Any, trigger_data: Any) -> SimpleNamespace:
        log = WorkflowTriggerLog(
            tenant_id=trigger_data.tenant_id,
            app_id=trigger_data.app_id,
            workflow_id=published_workflow.id,
            root_node_id=trigger_data.root_node_id,
            trigger_metadata="{}",
            trigger_type=AppTriggerType.TRIGGER_SCHEDULE,
            workflow_run_id=None,
            outputs=None,
            trigger_data=trigger_data.model_dump_json(),
            inputs=json.dumps(dict(trigger_data.inputs)),
            status=WorkflowTriggerStatus.SUCCEEDED,
            error="",
            queue_name="schedule_executor",
            celery_task_id="celery-schedule-test",
            created_by_role=CreatorUserRole.ACCOUNT,
            created_by=account.id,
        )
        session.add(log)
        session.commit()
        return SimpleNamespace(workflow_trigger_log_id=log.id, task_id=None, status="queued", queue="test")

    monkeypatch.setattr(
        workflow_schedule_tasks.AsyncWorkflowService,
        "trigger_workflow_async",
        _fake_trigger_workflow_async,
    )

    # Mock quota to avoid rate limiting
    from enums import quota_type

    monkeypatch.setattr(quota_type.QuotaType.TRIGGER, "consume", lambda _tenant_id: quota_type.unlimited())

    # Execute the schedule trigger
    workflow_schedule_tasks.run_schedule_trigger(plan.id)

    # Verify a WorkflowTriggerLog was created
    db_session_with_containers.expire_all()
    logs = db_session_with_containers.query(WorkflowTriggerLog).filter_by(app_id=app_model.id).all()
    assert logs, "Schedule trigger should create WorkflowTriggerLog"
    assert logs[0].trigger_type == AppTriggerType.TRIGGER_SCHEDULE
    assert logs[0].root_node_id == schedule_node_id

@pytest.mark.parametrize(
    ("mode", "frequency", "visual_config", "cron_expression", "expected_cron"),
    [
        # Visual mode: hourly
        ("visual", "hourly", {"on_minute": 30}, None, "30 * * * *"),
        # Visual mode: daily
        ("visual", "daily", {"time": "3:00 PM"}, None, "0 15 * * *"),
        # Visual mode: weekly
        ("visual", "weekly", {"time": "9:00 AM", "weekdays": ["mon", "wed", "fri"]}, None, "0 9 * * 1,3,5"),
        # Visual mode: monthly
        ("visual", "monthly", {"time": "10:30 AM", "monthly_days": [1, 15]}, None, "30 10 1,15 * *"),
        # Cron mode: direct expression
        ("cron", None, None, "*/5 * * * *", "*/5 * * * *"),
    ],
)
def test_schedule_visual_cron_conversion(
    mode: str,
    frequency: str | None,
    visual_config: dict[str, Any] | None,
    cron_expression: str | None,
    expected_cron: str,
) -> None:
    """Schedule visual config should correctly convert to a cron expression."""
    node_config: dict[str, Any] = {
        "id": "schedule-node",
        "data": {
            "type": NodeType.TRIGGER_SCHEDULE.value,
            "mode": mode,
            "timezone": "UTC",
        },
    }

    if mode == "visual":
        node_config["data"]["frequency"] = frequency
        node_config["data"]["visual_config"] = visual_config
    else:
        node_config["data"]["cron_expression"] = cron_expression

    config = ScheduleService.to_schedule_config(node_config)

    assert config.cron_expression == expected_cron, f"Expected {expected_cron}, got {config.cron_expression}"
    assert config.timezone == "UTC"
    assert config.node_id == "schedule-node"

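
# A minimal sketch of the visual-to-cron mapping the cases above pin down; illustrative
# only, the real ScheduleService.to_schedule_config may handle more frequencies and
# validation. Weekday names map to cron day-of-week numbers (the cases above only
# exercise mon/wed/fri = 1/3/5; sunday's numbering, 0 or 7, is an assumption here).
def _visual_to_cron_sketch(frequency: str, cfg: dict[str, Any]) -> str:
    if frequency == "hourly":
        return f"{cfg['on_minute']} * * * *"
    hour, minute = _parse_12h_time_sketch(cfg["time"])  # e.g. "3:00 PM" -> (15, 0)
    if frequency == "daily":
        return f"{minute} {hour} * * *"
    if frequency == "weekly":
        days = {"mon": 1, "tue": 2, "wed": 3, "thu": 4, "fri": 5, "sat": 6, "sun": 0}
        return f"{minute} {hour} * * {','.join(str(days[d]) for d in cfg['weekdays'])}"
    if frequency == "monthly":
        return f"{minute} {hour} {','.join(str(d) for d in cfg['monthly_days'])} * *"
    raise ValueError(f"unsupported frequency: {frequency}")


def _parse_12h_time_sketch(value: str) -> tuple[int, int]:
    """Illustrative 12-hour time parser: "10:30 AM" -> (10, 30), "3:00 PM" -> (15, 0)."""
    clock, meridiem = value.split()
    hour, minute = (int(part) for part in clock.split(":"))
    if meridiem.upper() == "PM" and hour != 12:
        hour += 12
    if meridiem.upper() == "AM" and hour == 12:
        hour = 0
    return hour, minute
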
def test_plugin_trigger_full_chain_with_db_verification(
    test_client_with_containers: FlaskClient,
    db_session_with_containers: Session,
    tenant_and_account: tuple[Tenant, Account],
    app_model: App,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Plugin trigger should create WorkflowTriggerLog and WorkflowPluginTrigger records."""
    tenant, account = tenant_and_account

    # Create a published workflow with a plugin trigger node
    plugin_node_id = "plugin-trigger-node"
    provider_id = "langgenius/test-provider/test-provider"
    subscription_id = "sub-plugin-test"
    endpoint_id = "2cc7fa12-3f7b-4f6a-9c8d-1234567890ab"

    graph = {
        "nodes": [
            {
                "id": plugin_node_id,
                "data": {
                    "type": NodeType.TRIGGER_PLUGIN.value,
                    "title": "plugin",
                    "plugin_id": "test-plugin",
                    "plugin_unique_identifier": "test-plugin",
                    "provider_id": provider_id,
                    "event_name": "test_event",
                    "subscription_id": subscription_id,
                    "parameters": {},
                },
            },
            {"id": "answer-1", "data": {"type": NodeType.ANSWER.value, "title": "answer"}},
        ],
        "edges": [{"source": plugin_node_id, "target": "answer-1", "sourceHandle": "success"}],
    }
    published_workflow = Workflow.new(
        tenant_id=tenant.id,
        app_id=app_model.id,
        type="workflow",
        version=Workflow.version_from_datetime(naive_utc_now()),
        graph=json.dumps(graph),
        features=json.dumps({}),
        created_by=account.id,
        environment_variables=[],
        conversation_variables=[],
        rag_pipeline_variables=[],
    )
    db_session_with_containers.add(published_workflow)
    app_model.workflow_id = published_workflow.id
    db_session_with_containers.commit()

    # Create a trigger subscription
    subscription = TriggerSubscription(
        name="test-subscription",
        tenant_id=tenant.id,
        user_id=account.id,
        provider_id=provider_id,
        endpoint_id=endpoint_id,
        parameters={},
        properties={},
        credentials={"token": "test-secret"},
        credential_type="api-key",
    )
    db_session_with_containers.add(subscription)
    db_session_with_containers.commit()

    # Update subscription_id to match the created subscription
    graph["nodes"][0]["data"]["subscription_id"] = subscription.id
    published_workflow.graph = json.dumps(graph)
    db_session_with_containers.commit()

    # Create the WorkflowPluginTrigger
    plugin_trigger = WorkflowPluginTrigger(
        app_id=app_model.id,
        tenant_id=tenant.id,
        node_id=plugin_node_id,
        provider_id=provider_id,
        event_name="test_event",
        subscription_id=subscription.id,
    )
    app_trigger = AppTrigger(
        tenant_id=tenant.id,
        app_id=app_model.id,
        node_id=plugin_node_id,
        trigger_type=AppTriggerType.TRIGGER_PLUGIN,
        status=AppTriggerStatus.ENABLED,
        title="plugin",
    )
    db_session_with_containers.add_all([plugin_trigger, app_trigger])
    db_session_with_containers.commit()

    # Track dispatched data
    dispatched_data: list[dict[str, Any]] = []

    def _fake_process_endpoint(_endpoint_id: str, _request: Any) -> Response:
        dispatch_data = {
            "user_id": "end-user",
            "tenant_id": tenant.id,
            "endpoint_id": _endpoint_id,
            "provider_id": provider_id,
            "subscription_id": subscription.id,
            "timestamp": int(time.time()),
            "events": ["test_event"],
            "request_id": f"req-{_endpoint_id}",
        }
        dispatched_data.append(dispatch_data)
        return Response("ok", status=202)

    monkeypatch.setattr(
        "services.trigger.trigger_service.TriggerService.process_endpoint",
        staticmethod(_fake_process_endpoint),
    )

    response = test_client_with_containers.post(f"/triggers/plugin/{endpoint_id}", json={"test": "data"})

    assert response.status_code == 202
    assert dispatched_data, "Plugin trigger should dispatch event data"
    assert dispatched_data[0]["subscription_id"] == subscription.id
    assert dispatched_data[0]["events"] == ["test_event"]

    # Verify the database records exist
    db_session_with_containers.expire_all()
    plugin_triggers = (
        db_session_with_containers.query(WorkflowPluginTrigger)
        .filter_by(app_id=app_model.id, node_id=plugin_node_id)
        .all()
    )
    assert plugin_triggers, "WorkflowPluginTrigger record should exist"
    assert plugin_triggers[0].provider_id == provider_id
    assert plugin_triggers[0].event_name == "test_event"

def test_plugin_debug_via_http_endpoint(
    test_client_with_containers: FlaskClient,
    db_session_with_containers: Session,
    tenant_and_account: tuple[Tenant, Account],
    app_model: App,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Plugin single-step debug via HTTP endpoint should dispatch a debug event and be pollable."""
    tenant, account = tenant_and_account

    provider_id = "langgenius/debug-provider/debug-provider"
    endpoint_id = "3cc7fa12-3f7b-4f6a-9c8d-1234567890ab"
    event_name = "debug_event"

    # Create the subscription
    subscription = TriggerSubscription(
        name="debug-subscription",
        tenant_id=tenant.id,
        user_id=account.id,
        provider_id=provider_id,
        endpoint_id=endpoint_id,
        parameters={},
        properties={},
        credentials={"token": "debug-secret"},
        credential_type="api-key",
    )
    db_session_with_containers.add(subscription)
    db_session_with_containers.commit()

    # Create the plugin trigger node config
    node_id = "plugin-debug-node"
    node_config = {
        "id": node_id,
        "data": {
            "type": NodeType.TRIGGER_PLUGIN.value,
            "title": "plugin-debug",
            "plugin_id": "debug-plugin",
            "plugin_unique_identifier": "debug-plugin",
            "provider_id": provider_id,
            "event_name": event_name,
            "subscription_id": subscription.id,
            "parameters": {},
        },
    }

    # Start listening with the poller
    poller = PluginTriggerDebugEventPoller(
        tenant_id=tenant.id,
        user_id=account.id,
        app_id=app_model.id,
        node_config=node_config,
        node_id=node_id,
    )
    assert poller.poll() is None, "First poll should return None (waiting)"

    # Track dispatched debug events
    debug_events: list[dict[str, Any]] = []
    original_dispatch = TriggerDebugEventBus.dispatch

    def _tracking_dispatch(**kwargs: Any) -> int:
        debug_events.append(kwargs)
        return original_dispatch(**kwargs)

    monkeypatch.setattr(TriggerDebugEventBus, "dispatch", staticmethod(_tracking_dispatch))

    # Mock process_endpoint to trigger debug event dispatch
    def _fake_process_endpoint(_endpoint_id: str, _request: Any) -> Response:
        # Simulate what happens inside process_endpoint + dispatch_triggered_workflows_async
        pool_key = build_plugin_pool_key(
            tenant_id=tenant.id,
            provider_id=provider_id,
            subscription_id=subscription.id,
            name=event_name,
        )
        TriggerDebugEventBus.dispatch(
            tenant_id=tenant.id,
            event=PluginTriggerDebugEvent(
                timestamp=int(time.time()),
                user_id="end-user",
                name=event_name,
                request_id=f"req-{_endpoint_id}",
                subscription_id=subscription.id,
                provider_id=provider_id,
            ),
            pool_key=pool_key,
        )
        return Response("ok", status=202)

    monkeypatch.setattr(
        "services.trigger.trigger_service.TriggerService.process_endpoint",
        staticmethod(_fake_process_endpoint),
    )

    # Call the HTTP endpoint
    response = test_client_with_containers.post(f"/triggers/plugin/{endpoint_id}", json={"debug": "payload"})

    assert response.status_code == 202
    assert debug_events, "Debug event should be dispatched via HTTP endpoint"
    assert debug_events[0]["event"].name == event_name

    # Mock invoke_trigger_event for the poller
    monkeypatch.setattr(
        "services.trigger.trigger_service.TriggerService.invoke_trigger_event",
        staticmethod(
            lambda **_kwargs: TriggerInvokeEventResponse(
                variables={"http_debug": "success"},
                cancelled=False,
            )
        ),
    )

    # Second poll should receive the event
    event = poller.poll()
    assert event is not None, "Poller should receive debug event after HTTP trigger"
    assert event.workflow_args["inputs"]["http_debug"] == "success"

@@ -26,16 +26,29 @@ redis_mock.hgetall = MagicMock(return_value={})
redis_mock.hdel = MagicMock()
redis_mock.incr = MagicMock(return_value=1)

# Ensure OpenDAL fs writes to tmp to avoid polluting the workspace
os.environ.setdefault("OPENDAL_SCHEME", "fs")
os.environ.setdefault("OPENDAL_FS_ROOT", "/tmp/dify-storage")
os.environ.setdefault("STORAGE_TYPE", "opendal")

# Add the API directory to the Python path to ensure proper imports
import sys

sys.path.insert(0, PROJECT_DIR)

# Apply the mock to the Redis client in the Flask app
from extensions import ext_redis

redis_patcher = patch.object(ext_redis, "redis_client", redis_mock)
redis_patcher.start()


def _patch_redis_clients_on_loaded_modules():
    """Ensure any module-level redis_client references point to the shared redis_mock."""
    import sys

    for module in list(sys.modules.values()):
        if module is None:
            continue
        if hasattr(module, "redis_client"):
            module.redis_client = redis_mock

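
# Why the extra loop above: patch.object(ext_redis, "redis_client", ...) only rebinds the
# attribute on the extension module itself. Any module that did
# `from extensions.ext_redis import redis_client` holds its own reference to the real
# client, so each loaded module must be rebound to the shared mock as well.
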
@pytest.fixture
@@ -49,6 +62,15 @@ def _provide_app_context(app: Flask):
    yield


@pytest.fixture(autouse=True)
def _patch_redis_clients():
    """Patch redis_client to a MagicMock only for unit test executions."""
    with patch.object(ext_redis, "redis_client", redis_mock):
        _patch_redis_clients_on_loaded_modules()
        yield


@pytest.fixture(autouse=True)
def reset_redis_mock():
    """Reset the Redis mock before each test."""
@@ -63,3 +85,20 @@ def reset_redis_mock():
    redis_mock.hgetall.return_value = {}
    redis_mock.hdel.return_value = None
    redis_mock.incr.return_value = 1

    # Keep any imported modules pointing at the mock between tests
    _patch_redis_clients_on_loaded_modules()


@pytest.fixture(autouse=True)
def reset_secret_key():
    """Ensure SECRET_KEY-dependent logic sees an empty config value by default."""
    from configs import dify_config

    original = dify_config.SECRET_KEY
    dify_config.SECRET_KEY = ""
    try:
        yield
    finally:
        dify_config.SECRET_KEY = original

@@ -0,0 +1,407 @@
"""Final working unit tests for admin endpoints - tests business logic directly."""

import uuid
from unittest.mock import Mock, patch

import pytest
from werkzeug.exceptions import NotFound, Unauthorized

from controllers.console.admin import InsertExploreAppPayload
from models.model import App, RecommendedApp


class TestInsertExploreAppPayload:
    """Test InsertExploreAppPayload validation."""

    def test_valid_payload(self):
        """Test creating a payload with valid data."""
        payload_data = {
            "app_id": str(uuid.uuid4()),
            "desc": "Test app description",
            "copyright": "© 2024 Test Company",
            "privacy_policy": "https://example.com/privacy",
            "custom_disclaimer": "Custom disclaimer text",
            "language": "en-US",
            "category": "Productivity",
            "position": 1,
        }

        payload = InsertExploreAppPayload.model_validate(payload_data)

        assert payload.app_id == payload_data["app_id"]
        assert payload.desc == payload_data["desc"]
        assert payload.copyright == payload_data["copyright"]
        assert payload.privacy_policy == payload_data["privacy_policy"]
        assert payload.custom_disclaimer == payload_data["custom_disclaimer"]
        assert payload.language == payload_data["language"]
        assert payload.category == payload_data["category"]
        assert payload.position == payload_data["position"]

    def test_minimal_payload(self):
        """Test creating a payload with only required fields."""
        payload_data = {
            "app_id": str(uuid.uuid4()),
            "language": "en-US",
            "category": "Productivity",
            "position": 1,
        }

        payload = InsertExploreAppPayload.model_validate(payload_data)

        assert payload.app_id == payload_data["app_id"]
        assert payload.desc is None
        assert payload.copyright is None
        assert payload.privacy_policy is None
        assert payload.custom_disclaimer is None
        assert payload.language == payload_data["language"]
        assert payload.category == payload_data["category"]
        assert payload.position == payload_data["position"]

    def test_invalid_language(self):
        """Test payload with an invalid language code."""
        payload_data = {
            "app_id": str(uuid.uuid4()),
            "language": "invalid-lang",
            "category": "Productivity",
            "position": 1,
        }

        with pytest.raises(ValueError, match="invalid-lang is not a valid language"):
            InsertExploreAppPayload.model_validate(payload_data)

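
# A minimal sketch (not the project's actual model) of how the language check exercised
# above can be wired in, assuming a pydantic v2 field validator around the project's
# constants.languages.supported_language helper:
from pydantic import BaseModel, field_validator

from constants.languages import supported_language


class _LanguagePayloadSketch(BaseModel):
    language: str

    @field_validator("language")
    @classmethod
    def _check_language(cls, value: str) -> str:
        # supported_language returns the code unchanged or raises ValueError
        return supported_language(value)
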
class TestAdminRequiredDecorator:
    """Test the admin_required decorator."""

    def setup_method(self):
        """Set up test fixtures."""
        # Mock dify_config
        self.dify_config_patcher = patch("controllers.console.admin.dify_config")
        self.mock_dify_config = self.dify_config_patcher.start()
        self.mock_dify_config.ADMIN_API_KEY = "test-admin-key"

        # Mock extract_access_token
        self.token_patcher = patch("controllers.console.admin.extract_access_token")
        self.mock_extract_token = self.token_patcher.start()

    def teardown_method(self):
        """Clean up test fixtures."""
        self.dify_config_patcher.stop()
        self.token_patcher.stop()

    def test_admin_required_success(self):
        """Test successful admin authentication."""
        from controllers.console.admin import admin_required

        @admin_required
        def test_view():
            return {"success": True}

        self.mock_extract_token.return_value = "test-admin-key"
        result = test_view()
        assert result["success"] is True

    def test_admin_required_invalid_token(self):
        """Test admin_required with an invalid token."""
        from controllers.console.admin import admin_required

        @admin_required
        def test_view():
            return {"success": True}

        self.mock_extract_token.return_value = "wrong-key"
        with pytest.raises(Unauthorized, match="API key is invalid"):
            test_view()

    def test_admin_required_no_api_key_configured(self):
        """Test admin_required when no API key is configured."""
        from controllers.console.admin import admin_required

        self.mock_dify_config.ADMIN_API_KEY = None

        @admin_required
        def test_view():
            return {"success": True}

        with pytest.raises(Unauthorized, match="API key is invalid"):
            test_view()

    def test_admin_required_missing_authorization_header(self):
        """Test admin_required with a missing authorization header."""
        from controllers.console.admin import admin_required

        @admin_required
        def test_view():
            return {"success": True}

        self.mock_extract_token.return_value = None
        with pytest.raises(Unauthorized, match="Authorization header is missing"):
            test_view()

class TestExploreAppBusinessLogicDirect:
    """Test the core business logic of explore app management directly."""

    def test_data_fusion_logic(self):
        """Test the data fusion logic between payload and site data."""
        # Test cases for different data scenarios
        test_cases = [
            {
                "name": "site_data_overrides_payload",
                "payload": {"desc": "Payload desc", "copyright": "Payload copyright"},
                "site": {"description": "Site desc", "copyright": "Site copyright"},
                "expected": {
                    "desc": "Site desc",
                    "copyright": "Site copyright",
                    "privacy_policy": "",
                    "custom_disclaimer": "",
                },
            },
            {
                "name": "payload_used_when_no_site",
                "payload": {"desc": "Payload desc", "copyright": "Payload copyright"},
                "site": None,
                "expected": {
                    "desc": "Payload desc",
                    "copyright": "Payload copyright",
                    "privacy_policy": "",
                    "custom_disclaimer": "",
                },
            },
            {
                "name": "empty_defaults_when_no_data",
                "payload": {},
                "site": None,
                "expected": {"desc": "", "copyright": "", "privacy_policy": "", "custom_disclaimer": ""},
            },
        ]

        for case in test_cases:
            # Simulate the data fusion logic
            payload_desc = case["payload"].get("desc")
            payload_copyright = case["payload"].get("copyright")
            payload_privacy_policy = case["payload"].get("privacy_policy")
            payload_custom_disclaimer = case["payload"].get("custom_disclaimer")

            if case["site"]:
                site_desc = case["site"].get("description")
                site_copyright = case["site"].get("copyright")
                site_privacy_policy = case["site"].get("privacy_policy")
                site_custom_disclaimer = case["site"].get("custom_disclaimer")

                # Site data takes precedence
                desc = site_desc or payload_desc or ""
                copyright = site_copyright or payload_copyright or ""
                privacy_policy = site_privacy_policy or payload_privacy_policy or ""
                custom_disclaimer = site_custom_disclaimer or payload_custom_disclaimer or ""
            else:
                # Use payload data or empty defaults
                desc = payload_desc or ""
                copyright = payload_copyright or ""
                privacy_policy = payload_privacy_policy or ""
                custom_disclaimer = payload_custom_disclaimer or ""

            result = {
                "desc": desc,
                "copyright": copyright,
                "privacy_policy": privacy_policy,
                "custom_disclaimer": custom_disclaimer,
            }

            assert result == case["expected"], f"Failed test case: {case['name']}"

    def test_app_visibility_logic(self):
        """Test that apps are made public when added to the explore list."""
        # Create a mock app
        mock_app = Mock(spec=App)
        mock_app.is_public = False

        # Simulate the business logic
        mock_app.is_public = True

        assert mock_app.is_public is True

    def test_recommended_app_creation_logic(self):
        """Test the creation of RecommendedApp objects."""
        app_id = str(uuid.uuid4())
        payload_data = {
            "app_id": app_id,
            "desc": "Test app description",
            "copyright": "© 2024 Test Company",
            "privacy_policy": "https://example.com/privacy",
            "custom_disclaimer": "Custom disclaimer",
            "language": "en-US",
            "category": "Productivity",
            "position": 1,
        }

        # Simulate the creation logic
        recommended_app = Mock(spec=RecommendedApp)
        recommended_app.app_id = payload_data["app_id"]
        recommended_app.description = payload_data["desc"]
        recommended_app.copyright = payload_data["copyright"]
        recommended_app.privacy_policy = payload_data["privacy_policy"]
        recommended_app.custom_disclaimer = payload_data["custom_disclaimer"]
        recommended_app.language = payload_data["language"]
        recommended_app.category = payload_data["category"]
        recommended_app.position = payload_data["position"]

        # Verify the data
        assert recommended_app.app_id == app_id
        assert recommended_app.description == "Test app description"
        assert recommended_app.copyright == "© 2024 Test Company"
        assert recommended_app.privacy_policy == "https://example.com/privacy"
        assert recommended_app.custom_disclaimer == "Custom disclaimer"
        assert recommended_app.language == "en-US"
        assert recommended_app.category == "Productivity"
        assert recommended_app.position == 1

    def test_recommended_app_update_logic(self):
        """Test the update logic for existing RecommendedApp objects."""
        mock_recommended_app = Mock(spec=RecommendedApp)

        update_data = {
            "desc": "Updated description",
            "copyright": "© 2024 Updated",
            "language": "fr-FR",
            "category": "Tools",
            "position": 2,
        }

        # Simulate the update logic
        mock_recommended_app.description = update_data["desc"]
        mock_recommended_app.copyright = update_data["copyright"]
        mock_recommended_app.language = update_data["language"]
        mock_recommended_app.category = update_data["category"]
        mock_recommended_app.position = update_data["position"]

        # Verify the updates
        assert mock_recommended_app.description == "Updated description"
        assert mock_recommended_app.copyright == "© 2024 Updated"
        assert mock_recommended_app.language == "fr-FR"
        assert mock_recommended_app.category == "Tools"
        assert mock_recommended_app.position == 2

    def test_app_not_found_error_logic(self):
        """Test error handling when an app is not found."""
        app_id = str(uuid.uuid4())

        # Simulate app lookup returning None
        found_app = None

        # Test the error condition
        if not found_app:
            with pytest.raises(NotFound, match=f"App '{app_id}' is not found"):
                raise NotFound(f"App '{app_id}' is not found")

    def test_recommended_app_not_found_error_logic(self):
        """Test error handling when a recommended app is not found for deletion."""
        app_id = str(uuid.uuid4())

        # Simulate recommended app lookup returning None
        found_recommended_app = None

        # Test the error condition
        if not found_recommended_app:
            with pytest.raises(NotFound, match=f"App '{app_id}' is not found in the explore list"):
                raise NotFound(f"App '{app_id}' is not found in the explore list")

    def test_database_session_usage_patterns(self):
        """Test the expected database session usage patterns."""
        # Mock session usage patterns
        mock_session = Mock()

        # Test the session.add pattern
        mock_recommended_app = Mock(spec=RecommendedApp)
        mock_session.add(mock_recommended_app)
        mock_session.commit()

        # Verify the session was used correctly
        mock_session.add.assert_called_once_with(mock_recommended_app)
        mock_session.commit.assert_called_once()

        # Test the session.delete pattern
        mock_recommended_app_to_delete = Mock(spec=RecommendedApp)
        mock_session.delete(mock_recommended_app_to_delete)
        mock_session.commit()

        # Verify the delete pattern
        mock_session.delete.assert_called_once_with(mock_recommended_app_to_delete)

    def test_payload_validation_integration(self):
        """Test payload validation in the context of the business logic."""
        # Test a valid payload
        valid_payload_data = {
            "app_id": str(uuid.uuid4()),
            "desc": "Test app description",
            "language": "en-US",
            "category": "Productivity",
            "position": 1,
        }

        # This should succeed
        payload = InsertExploreAppPayload.model_validate(valid_payload_data)
        assert payload.app_id == valid_payload_data["app_id"]

        # Test an invalid payload
        invalid_payload_data = {
            "app_id": str(uuid.uuid4()),
            "language": "invalid-lang",  # This should fail validation
            "category": "Productivity",
            "position": 1,
        }

        # This should raise an exception
        with pytest.raises(ValueError, match="invalid-lang is not a valid language"):
            InsertExploreAppPayload.model_validate(invalid_payload_data)

class TestExploreAppDataHandling:
    """Test specific data handling scenarios."""

    def test_uuid_validation(self):
        """Test UUID validation and handling."""
        # Test a valid UUID
        valid_uuid = str(uuid.uuid4())

        # This should be a valid UUID
        assert uuid.UUID(valid_uuid) is not None

        # Test an invalid UUID
        invalid_uuid = "not-a-valid-uuid"

        # This should raise a ValueError
        with pytest.raises(ValueError):
            uuid.UUID(invalid_uuid)

    def test_language_validation(self):
        """Test language validation against supported languages."""
        from constants.languages import supported_language

        # Test supported languages
        assert supported_language("en-US") == "en-US"
        assert supported_language("fr-FR") == "fr-FR"

        # Test an unsupported language
        with pytest.raises(ValueError, match="invalid-lang is not a valid language"):
            supported_language("invalid-lang")

    def test_response_formatting(self):
        """Test API response formatting."""
        # Test success responses
        create_response = {"result": "success"}
        update_response = {"result": "success"}
        delete_response = None  # 204 No Content returns None

        assert create_response["result"] == "success"
        assert update_response["result"] == "success"
        assert delete_response is None

        # Test status codes
        create_status = 201  # Created
        update_status = 200  # OK
        delete_status = 204  # No Content

        assert create_status == 201
        assert update_status == 200
        assert delete_status == 204

@@ -0,0 +1,25 @@
import uuid

import pytest
from pydantic import ValidationError

from controllers.service_api.app.completion import ChatRequestPayload


def test_chat_request_payload_accepts_blank_conversation_id():
    payload = ChatRequestPayload.model_validate({"inputs": {}, "query": "hello", "conversation_id": ""})

    assert payload.conversation_id is None


def test_chat_request_payload_validates_uuid():
    conversation_id = str(uuid.uuid4())

    payload = ChatRequestPayload.model_validate({"inputs": {}, "query": "hello", "conversation_id": conversation_id})

    assert payload.conversation_id == conversation_id


def test_chat_request_payload_rejects_invalid_uuid():
    with pytest.raises(ValidationError):
        ChatRequestPayload.model_validate({"inputs": {}, "query": "hello", "conversation_id": "invalid"})

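
# A minimal sketch of the normalization these tests pin down; illustrative only, the real
# ChatRequestPayload may differ: blank strings become None, anything else must parse as a UUID.
from pydantic import BaseModel, field_validator


class _ConversationIdSketch(BaseModel):
    conversation_id: str | None = None

    @field_validator("conversation_id", mode="before")
    @classmethod
    def _normalize(cls, value: str | None) -> str | None:
        if not value:
            return None  # a blank or missing id means "start a new conversation"
        return str(uuid.UUID(value))  # ValueError here surfaces as a ValidationError
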
@@ -0,0 +1,20 @@
import pytest
from pydantic import ValidationError

from controllers.console.explore.conversation import ConversationRenamePayload as ConsolePayload
from controllers.service_api.app.conversation import ConversationRenamePayload as ServicePayload


@pytest.mark.parametrize("payload_cls", [ConsolePayload, ServicePayload])
def test_payload_allows_auto_generate_without_name(payload_cls):
    payload = payload_cls.model_validate({"auto_generate": True})

    assert payload.auto_generate is True
    assert payload.name is None


@pytest.mark.parametrize("payload_cls", [ConsolePayload, ServicePayload])
@pytest.mark.parametrize("value", [None, "", " "])
def test_payload_requires_name_when_not_auto_generate(payload_cls, value):
    with pytest.raises(ValidationError):
        payload_cls.model_validate({"name": value, "auto_generate": False})