mirror of https://github.com/langgenius/dify.git
Merge branch 'fix/polyfill-toSplice' into deploy/dev
This commit is contained in:
commit
d1b4bb247a
|
|
@ -0,0 +1,322 @@
|
|||
---
|
||||
name: frontend-testing
|
||||
description: Generate Vitest + React Testing Library tests for Dify frontend components, hooks, and utilities. Triggers on testing, spec files, coverage, Vitest, RTL, unit tests, integration tests, or write/review test requests.
|
||||
---
|
||||
|
||||
# Dify Frontend Testing Skill
|
||||
|
||||
This skill enables Claude to generate high-quality, comprehensive frontend tests for the Dify project following established conventions and best practices.
|
||||
|
||||
> **⚠️ Authoritative Source**: This skill is derived from `web/testing/testing.md`. Use Vitest mock/timer APIs (`vi.*`).
|
||||
|
||||
## When to Apply This Skill
|
||||
|
||||
Apply this skill when the user:
|
||||
|
||||
- Asks to **write tests** for a component, hook, or utility
|
||||
- Asks to **review existing tests** for completeness
|
||||
- Mentions **Vitest**, **React Testing Library**, **RTL**, or **spec files**
|
||||
- Requests **test coverage** improvement
|
||||
- Uses `pnpm analyze-component` output as context
|
||||
- Mentions **testing**, **unit tests**, or **integration tests** for frontend code
|
||||
- Wants to understand **testing patterns** in the Dify codebase
|
||||
|
||||
**Do NOT apply** when:
|
||||
|
||||
- User is asking about backend/API tests (Python/pytest)
|
||||
- User is asking about E2E tests (Playwright/Cypress)
|
||||
- User is only asking conceptual questions without code context
|
||||
|
||||
## Quick Reference
|
||||
|
||||
### Tech Stack
|
||||
|
||||
| Tool | Version | Purpose |
|
||||
|------|---------|---------|
|
||||
| Vitest | 4.0.16 | Test runner |
|
||||
| React Testing Library | 16.0 | Component testing |
|
||||
| jsdom | - | Test environment |
|
||||
| nock | 14.0 | HTTP mocking |
|
||||
| TypeScript | 5.x | Type safety |
|
||||
|
||||
### Key Commands
|
||||
|
||||
```bash
|
||||
# Run all tests
|
||||
pnpm test
|
||||
|
||||
# Watch mode
|
||||
pnpm test:watch
|
||||
|
||||
# Run specific file
|
||||
pnpm test -- path/to/file.spec.tsx
|
||||
|
||||
# Generate coverage report
|
||||
pnpm test -- --coverage
|
||||
|
||||
# Analyze component complexity
|
||||
pnpm analyze-component <path>
|
||||
|
||||
# Review existing test
|
||||
pnpm analyze-component <path> --review
|
||||
```
|
||||
|
||||
### File Naming
|
||||
|
||||
- Test files: `ComponentName.spec.tsx` (same directory as component)
|
||||
- Integration tests: `web/__tests__/` directory
|
||||
|
||||
## Test Structure Template
|
||||
|
||||
```typescript
|
||||
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
|
||||
import Component from './index'
|
||||
|
||||
// ✅ Import real project components (DO NOT mock these)
|
||||
// import Loading from '@/app/components/base/loading'
|
||||
// import { ChildComponent } from './child-component'
|
||||
|
||||
// ✅ Mock external dependencies only
|
||||
vi.mock('@/service/api')
|
||||
vi.mock('next/navigation', () => ({
|
||||
useRouter: () => ({ push: vi.fn() }),
|
||||
usePathname: () => '/test',
|
||||
}))
|
||||
|
||||
// Shared state for mocks (if needed)
|
||||
let mockSharedState = false
|
||||
|
||||
describe('ComponentName', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks() // ✅ Reset mocks BEFORE each test
|
||||
mockSharedState = false // ✅ Reset shared state
|
||||
})
|
||||
|
||||
// Rendering tests (REQUIRED)
|
||||
describe('Rendering', () => {
|
||||
it('should render without crashing', () => {
|
||||
// Arrange
|
||||
const props = { title: 'Test' }
|
||||
|
||||
// Act
|
||||
render(<Component {...props} />)
|
||||
|
||||
// Assert
|
||||
expect(screen.getByText('Test')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
// Props tests (REQUIRED)
|
||||
describe('Props', () => {
|
||||
it('should apply custom className', () => {
|
||||
render(<Component className="custom" />)
|
||||
expect(screen.getByRole('button')).toHaveClass('custom')
|
||||
})
|
||||
})
|
||||
|
||||
// User Interactions
|
||||
describe('User Interactions', () => {
|
||||
it('should handle click events', () => {
|
||||
const handleClick = vi.fn()
|
||||
render(<Component onClick={handleClick} />)
|
||||
|
||||
fireEvent.click(screen.getByRole('button'))
|
||||
|
||||
expect(handleClick).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
|
||||
// Edge Cases (REQUIRED)
|
||||
describe('Edge Cases', () => {
|
||||
it('should handle null data', () => {
|
||||
render(<Component data={null} />)
|
||||
expect(screen.getByText(/no data/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle empty array', () => {
|
||||
render(<Component items={[]} />)
|
||||
expect(screen.getByText(/empty/i)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
## Testing Workflow (CRITICAL)
|
||||
|
||||
### ⚠️ Incremental Approach Required
|
||||
|
||||
**NEVER generate all test files at once.** For complex components or multi-file directories:
|
||||
|
||||
1. **Analyze & Plan**: List all files, order by complexity (simple → complex)
|
||||
1. **Process ONE at a time**: Write test → Run test → Fix if needed → Next
|
||||
1. **Verify before proceeding**: Do NOT continue to next file until current passes
|
||||
|
||||
```
|
||||
For each file:
|
||||
┌────────────────────────────────────────┐
|
||||
│ 1. Write test │
|
||||
│ 2. Run: pnpm test -- <file>.spec.tsx │
|
||||
│ 3. PASS? → Mark complete, next file │
|
||||
│ FAIL? → Fix first, then continue │
|
||||
└────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### Complexity-Based Order
|
||||
|
||||
Process in this order for multi-file testing:
|
||||
|
||||
1. 🟢 Utility functions (simplest)
|
||||
1. 🟢 Custom hooks
|
||||
1. 🟡 Simple components (presentational)
|
||||
1. 🟡 Medium components (state, effects)
|
||||
1. 🔴 Complex components (API, routing)
|
||||
1. 🔴 Integration tests (index files - last)
|
||||
|
||||
### When to Refactor First
|
||||
|
||||
- **Complexity > 50**: Break into smaller pieces before testing
|
||||
- **500+ lines**: Consider splitting before testing
|
||||
- **Many dependencies**: Extract logic into hooks first
|
||||
|
||||
> 📖 See `references/workflow.md` for complete workflow details and todo list format.
|
||||
|
||||
## Testing Strategy
|
||||
|
||||
### Path-Level Testing (Directory Testing)
|
||||
|
||||
When assigned to test a directory/path, test **ALL content** within that path:
|
||||
|
||||
- Test all components, hooks, utilities in the directory (not just `index` file)
|
||||
- Use incremental approach: one file at a time, verify each before proceeding
|
||||
- Goal: 100% coverage of ALL files in the directory
|
||||
|
||||
### Integration Testing First
|
||||
|
||||
**Prefer integration testing** when writing tests for a directory:
|
||||
|
||||
- ✅ **Import real project components** directly (including base components and siblings)
|
||||
- ✅ **Only mock**: API services (`@/service/*`), `next/navigation`, complex context providers
|
||||
- ❌ **DO NOT mock** base components (`@/app/components/base/*`)
|
||||
- ❌ **DO NOT mock** sibling/child components in the same directory
|
||||
|
||||
> See [Test Structure Template](#test-structure-template) for correct import/mock patterns.
|
||||
|
||||
## Core Principles
|
||||
|
||||
### 1. AAA Pattern (Arrange-Act-Assert)
|
||||
|
||||
Every test should clearly separate:
|
||||
|
||||
- **Arrange**: Setup test data and render component
|
||||
- **Act**: Perform user actions
|
||||
- **Assert**: Verify expected outcomes
|
||||
|
||||
### 2. Black-Box Testing
|
||||
|
||||
- Test observable behavior, not implementation details
|
||||
- Use semantic queries (getByRole, getByLabelText)
|
||||
- Avoid testing internal state directly
|
||||
- **Prefer pattern matching over hardcoded strings** in assertions:
|
||||
|
||||
```typescript
|
||||
// ❌ Avoid: hardcoded text assertions
|
||||
expect(screen.getByText('Loading...')).toBeInTheDocument()
|
||||
|
||||
// ✅ Better: role-based queries
|
||||
expect(screen.getByRole('status')).toBeInTheDocument()
|
||||
|
||||
// ✅ Better: pattern matching
|
||||
expect(screen.getByText(/loading/i)).toBeInTheDocument()
|
||||
```
|
||||
|
||||
### 3. Single Behavior Per Test
|
||||
|
||||
Each test verifies ONE user-observable behavior:
|
||||
|
||||
```typescript
|
||||
// ✅ Good: One behavior
|
||||
it('should disable button when loading', () => {
|
||||
render(<Button loading />)
|
||||
expect(screen.getByRole('button')).toBeDisabled()
|
||||
})
|
||||
|
||||
// ❌ Bad: Multiple behaviors
|
||||
it('should handle loading state', () => {
|
||||
render(<Button loading />)
|
||||
expect(screen.getByRole('button')).toBeDisabled()
|
||||
expect(screen.getByText('Loading...')).toBeInTheDocument()
|
||||
expect(screen.getByRole('button')).toHaveClass('loading')
|
||||
})
|
||||
```
|
||||
|
||||
### 4. Semantic Naming
|
||||
|
||||
Use `should <behavior> when <condition>`:
|
||||
|
||||
```typescript
|
||||
it('should show error message when validation fails')
|
||||
it('should call onSubmit when form is valid')
|
||||
it('should disable input when isReadOnly is true')
|
||||
```
|
||||
|
||||
## Required Test Scenarios
|
||||
|
||||
### Always Required (All Components)
|
||||
|
||||
1. **Rendering**: Component renders without crashing
|
||||
1. **Props**: Required props, optional props, default values
|
||||
1. **Edge Cases**: null, undefined, empty values, boundary conditions
|
||||
|
||||
### Conditional (When Present)
|
||||
|
||||
| Feature | Test Focus |
|
||||
|---------|-----------|
|
||||
| `useState` | Initial state, transitions, cleanup |
|
||||
| `useEffect` | Execution, dependencies, cleanup |
|
||||
| Event handlers | All onClick, onChange, onSubmit, keyboard |
|
||||
| API calls | Loading, success, error states |
|
||||
| Routing | Navigation, params, query strings |
|
||||
| `useCallback`/`useMemo` | Referential equality |
|
||||
| Context | Provider values, consumer behavior |
|
||||
| Forms | Validation, submission, error display |
|
||||
|
||||
## Coverage Goals (Per File)
|
||||
|
||||
For each test file generated, aim for:
|
||||
|
||||
- ✅ **100%** function coverage
|
||||
- ✅ **100%** statement coverage
|
||||
- ✅ **>95%** branch coverage
|
||||
- ✅ **>95%** line coverage
|
||||
|
||||
> **Note**: For multi-file directories, process one file at a time with full coverage each. See `references/workflow.md`.
|
||||
|
||||
## Detailed Guides
|
||||
|
||||
For more detailed information, refer to:
|
||||
|
||||
- `references/workflow.md` - **Incremental testing workflow** (MUST READ for multi-file testing)
|
||||
- `references/mocking.md` - Mock patterns and best practices
|
||||
- `references/async-testing.md` - Async operations and API calls
|
||||
- `references/domain-components.md` - Workflow, Dataset, Configuration testing
|
||||
- `references/common-patterns.md` - Frequently used testing patterns
|
||||
- `references/checklist.md` - Test generation checklist and validation steps
|
||||
|
||||
## Authoritative References
|
||||
|
||||
### Primary Specification (MUST follow)
|
||||
|
||||
- **`web/testing/testing.md`** - The canonical testing specification. This skill is derived from this document.
|
||||
|
||||
### Reference Examples in Codebase
|
||||
|
||||
- `web/utils/classnames.spec.ts` - Utility function tests
|
||||
- `web/app/components/base/button/index.spec.tsx` - Component tests
|
||||
- `web/__mocks__/provider-context.ts` - Mock factory example
|
||||
|
||||
### Project Configuration
|
||||
|
||||
- `web/vitest.config.ts` - Vitest configuration
|
||||
- `web/vitest.setup.ts` - Test environment setup
|
||||
- `web/testing/analyze-component.js` - Component analysis tool
|
||||
- Modules are not mocked automatically. Global mocks live in `web/vitest.setup.ts` (for example `react-i18next`, `next/image`); mock other modules like `ky` or `mime` locally in test files.
|
||||
|
|
@ -0,0 +1,296 @@
|
|||
/**
|
||||
* Test Template for React Components
|
||||
*
|
||||
* WHY THIS STRUCTURE?
|
||||
* - Organized sections make tests easy to navigate and maintain
|
||||
* - Mocks at top ensure consistent test isolation
|
||||
* - Factory functions reduce duplication and improve readability
|
||||
* - describe blocks group related scenarios for better debugging
|
||||
*
|
||||
* INSTRUCTIONS:
|
||||
* 1. Replace `ComponentName` with your component name
|
||||
* 2. Update import path
|
||||
* 3. Add/remove test sections based on component features (use analyze-component)
|
||||
* 4. Follow AAA pattern: Arrange → Act → Assert
|
||||
*
|
||||
* RUN FIRST: pnpm analyze-component <path> to identify required test scenarios
|
||||
*/
|
||||
|
||||
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
// import ComponentName from './index'
|
||||
|
||||
// ============================================================================
|
||||
// Mocks
|
||||
// ============================================================================
|
||||
// WHY: Mocks must be hoisted to top of file (Vitest requirement).
|
||||
// They run BEFORE imports, so keep them before component imports.
|
||||
|
||||
// i18n (automatically mocked)
|
||||
// WHY: Global mock in web/vitest.setup.ts is auto-loaded by Vitest setup
|
||||
// No explicit mock needed - it returns translation keys as-is
|
||||
// Override only if custom translations are required:
|
||||
// vi.mock('react-i18next', () => ({
|
||||
// useTranslation: () => ({
|
||||
// t: (key: string) => {
|
||||
// const customTranslations: Record<string, string> = {
|
||||
// 'my.custom.key': 'Custom Translation',
|
||||
// }
|
||||
// return customTranslations[key] || key
|
||||
// },
|
||||
// }),
|
||||
// }))
|
||||
|
||||
// Router (if component uses useRouter, usePathname, useSearchParams)
|
||||
// WHY: Isolates tests from Next.js routing, enables testing navigation behavior
|
||||
// const mockPush = vi.fn()
|
||||
// vi.mock('next/navigation', () => ({
|
||||
// useRouter: () => ({ push: mockPush }),
|
||||
// usePathname: () => '/test-path',
|
||||
// }))
|
||||
|
||||
// API services (if component fetches data)
|
||||
// WHY: Prevents real network calls, enables testing all states (loading/success/error)
|
||||
// vi.mock('@/service/api')
|
||||
// import * as api from '@/service/api'
|
||||
// const mockedApi = vi.mocked(api)
|
||||
|
||||
// Shared mock state (for portal/dropdown components)
|
||||
// WHY: Portal components like PortalToFollowElem need shared state between
|
||||
// parent and child mocks to correctly simulate open/close behavior
|
||||
// let mockOpenState = false
|
||||
|
||||
// ============================================================================
|
||||
// Test Data Factories
|
||||
// ============================================================================
|
||||
// WHY FACTORIES?
|
||||
// - Avoid hard-coded test data scattered across tests
|
||||
// - Easy to create variations with overrides
|
||||
// - Type-safe when using actual types from source
|
||||
// - Single source of truth for default test values
|
||||
|
||||
// const createMockProps = (overrides = {}) => ({
|
||||
// // Default props that make component render successfully
|
||||
// ...overrides,
|
||||
// })
|
||||
|
||||
// const createMockItem = (overrides = {}) => ({
|
||||
// id: 'item-1',
|
||||
// name: 'Test Item',
|
||||
// ...overrides,
|
||||
// })
|
||||
|
||||
// ============================================================================
|
||||
// Test Helpers
|
||||
// ============================================================================
|
||||
|
||||
// const renderComponent = (props = {}) => {
|
||||
// return render(<ComponentName {...createMockProps(props)} />)
|
||||
// }
|
||||
|
||||
// ============================================================================
|
||||
// Tests
|
||||
// ============================================================================
|
||||
|
||||
describe('ComponentName', () => {
|
||||
// WHY beforeEach with clearAllMocks?
|
||||
// - Ensures each test starts with clean slate
|
||||
// - Prevents mock call history from leaking between tests
|
||||
// - MUST be beforeEach (not afterEach) to reset BEFORE assertions like toHaveBeenCalledTimes
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
// Reset shared mock state if used (CRITICAL for portal/dropdown tests)
|
||||
// mockOpenState = false
|
||||
})
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Rendering Tests (REQUIRED - Every component MUST have these)
|
||||
// --------------------------------------------------------------------------
|
||||
// WHY: Catches import errors, missing providers, and basic render issues
|
||||
describe('Rendering', () => {
|
||||
it('should render without crashing', () => {
|
||||
// Arrange - Setup data and mocks
|
||||
// const props = createMockProps()
|
||||
|
||||
// Act - Render the component
|
||||
// render(<ComponentName {...props} />)
|
||||
|
||||
// Assert - Verify expected output
|
||||
// Prefer getByRole for accessibility; it's what users "see"
|
||||
// expect(screen.getByRole('...')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render with default props', () => {
|
||||
// WHY: Verifies component works without optional props
|
||||
// render(<ComponentName />)
|
||||
// expect(screen.getByText('...')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Props Tests (REQUIRED - Every component MUST test prop behavior)
|
||||
// --------------------------------------------------------------------------
|
||||
// WHY: Props are the component's API contract. Test them thoroughly.
|
||||
describe('Props', () => {
|
||||
it('should apply custom className', () => {
|
||||
// WHY: Common pattern in Dify - components should merge custom classes
|
||||
// render(<ComponentName className="custom-class" />)
|
||||
// expect(screen.getByTestId('component')).toHaveClass('custom-class')
|
||||
})
|
||||
|
||||
it('should use default values for optional props', () => {
|
||||
// WHY: Verifies TypeScript defaults work at runtime
|
||||
// render(<ComponentName />)
|
||||
// expect(screen.getByRole('...')).toHaveAttribute('...', 'default-value')
|
||||
})
|
||||
})
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// User Interactions (if component has event handlers - on*, handle*)
|
||||
// --------------------------------------------------------------------------
|
||||
// WHY: Event handlers are core functionality. Test from user's perspective.
|
||||
describe('User Interactions', () => {
|
||||
it('should call onClick when clicked', async () => {
|
||||
// WHY userEvent over fireEvent?
|
||||
// - userEvent simulates real user behavior (focus, hover, then click)
|
||||
// - fireEvent is lower-level, doesn't trigger all browser events
|
||||
// const user = userEvent.setup()
|
||||
// const handleClick = vi.fn()
|
||||
// render(<ComponentName onClick={handleClick} />)
|
||||
//
|
||||
// await user.click(screen.getByRole('button'))
|
||||
//
|
||||
// expect(handleClick).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should call onChange when value changes', async () => {
|
||||
// const user = userEvent.setup()
|
||||
// const handleChange = vi.fn()
|
||||
// render(<ComponentName onChange={handleChange} />)
|
||||
//
|
||||
// await user.type(screen.getByRole('textbox'), 'new value')
|
||||
//
|
||||
// expect(handleChange).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// State Management (if component uses useState/useReducer)
|
||||
// --------------------------------------------------------------------------
|
||||
// WHY: Test state through observable UI changes, not internal state values
|
||||
describe('State Management', () => {
|
||||
it('should update state on interaction', async () => {
|
||||
// WHY test via UI, not state?
|
||||
// - State is implementation detail; UI is what users see
|
||||
// - If UI works correctly, state must be correct
|
||||
// const user = userEvent.setup()
|
||||
// render(<ComponentName />)
|
||||
//
|
||||
// // Initial state - verify what user sees
|
||||
// expect(screen.getByText('Initial')).toBeInTheDocument()
|
||||
//
|
||||
// // Trigger state change via user action
|
||||
// await user.click(screen.getByRole('button'))
|
||||
//
|
||||
// // New state - verify UI updated
|
||||
// expect(screen.getByText('Updated')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Async Operations (if component fetches data - useSWR, useQuery, fetch)
|
||||
// --------------------------------------------------------------------------
|
||||
// WHY: Async operations have 3 states users experience: loading, success, error
|
||||
describe('Async Operations', () => {
|
||||
it('should show loading state', () => {
|
||||
// WHY never-resolving promise?
|
||||
// - Keeps component in loading state for assertion
|
||||
// - Alternative: use fake timers
|
||||
// mockedApi.fetchData.mockImplementation(() => new Promise(() => {}))
|
||||
// render(<ComponentName />)
|
||||
//
|
||||
// expect(screen.getByText(/loading/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show data on success', async () => {
|
||||
// WHY waitFor?
|
||||
// - Component updates asynchronously after fetch resolves
|
||||
// - waitFor retries assertion until it passes or times out
|
||||
// mockedApi.fetchData.mockResolvedValue({ items: ['Item 1'] })
|
||||
// render(<ComponentName />)
|
||||
//
|
||||
// await waitFor(() => {
|
||||
// expect(screen.getByText('Item 1')).toBeInTheDocument()
|
||||
// })
|
||||
})
|
||||
|
||||
it('should show error on failure', async () => {
|
||||
// mockedApi.fetchData.mockRejectedValue(new Error('Network error'))
|
||||
// render(<ComponentName />)
|
||||
//
|
||||
// await waitFor(() => {
|
||||
// expect(screen.getByText(/error/i)).toBeInTheDocument()
|
||||
// })
|
||||
})
|
||||
})
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Edge Cases (REQUIRED - Every component MUST handle edge cases)
|
||||
// --------------------------------------------------------------------------
|
||||
// WHY: Real-world data is messy. Components must handle:
|
||||
// - Null/undefined from API failures or optional fields
|
||||
// - Empty arrays/strings from user clearing data
|
||||
// - Boundary values (0, MAX_INT, special characters)
|
||||
describe('Edge Cases', () => {
|
||||
it('should handle null value', () => {
|
||||
// WHY test null specifically?
|
||||
// - API might return null for missing data
|
||||
// - Prevents "Cannot read property of null" in production
|
||||
// render(<ComponentName value={null} />)
|
||||
// expect(screen.getByText(/no data/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle undefined value', () => {
|
||||
// WHY test undefined separately from null?
|
||||
// - TypeScript treats them differently
|
||||
// - Optional props are undefined, not null
|
||||
// render(<ComponentName value={undefined} />)
|
||||
// expect(screen.getByText(/no data/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle empty array', () => {
|
||||
// WHY: Empty state often needs special UI (e.g., "No items yet")
|
||||
// render(<ComponentName items={[]} />)
|
||||
// expect(screen.getByText(/empty/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle empty string', () => {
|
||||
// WHY: Empty strings are truthy in JS but visually empty
|
||||
// render(<ComponentName text="" />)
|
||||
// expect(screen.getByText(/placeholder/i)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Accessibility (optional but recommended for Dify's enterprise users)
|
||||
// --------------------------------------------------------------------------
|
||||
// WHY: Dify has enterprise customers who may require accessibility compliance
|
||||
describe('Accessibility', () => {
|
||||
it('should have accessible name', () => {
|
||||
// WHY getByRole with name?
|
||||
// - Tests that screen readers can identify the element
|
||||
// - Enforces proper labeling practices
|
||||
// render(<ComponentName label="Test Label" />)
|
||||
// expect(screen.getByRole('button', { name: /test label/i })).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should support keyboard navigation', async () => {
|
||||
// WHY: Some users can't use a mouse
|
||||
// const user = userEvent.setup()
|
||||
// render(<ComponentName />)
|
||||
//
|
||||
// await user.tab()
|
||||
// expect(screen.getByRole('button')).toHaveFocus()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
@ -0,0 +1,207 @@
|
|||
/**
|
||||
* Test Template for Custom Hooks
|
||||
*
|
||||
* Instructions:
|
||||
* 1. Replace `useHookName` with your hook name
|
||||
* 2. Update import path
|
||||
* 3. Add/remove test sections based on hook features
|
||||
*/
|
||||
|
||||
import { renderHook, act, waitFor } from '@testing-library/react'
|
||||
// import { useHookName } from './use-hook-name'
|
||||
|
||||
// ============================================================================
|
||||
// Mocks
|
||||
// ============================================================================
|
||||
|
||||
// API services (if hook fetches data)
|
||||
// vi.mock('@/service/api')
|
||||
// import * as api from '@/service/api'
|
||||
// const mockedApi = vi.mocked(api)
|
||||
|
||||
// ============================================================================
|
||||
// Test Helpers
|
||||
// ============================================================================
|
||||
|
||||
// Wrapper for hooks that need context
|
||||
// const createWrapper = (contextValue = {}) => {
|
||||
// return ({ children }: { children: React.ReactNode }) => (
|
||||
// <SomeContext.Provider value={contextValue}>
|
||||
// {children}
|
||||
// </SomeContext.Provider>
|
||||
// )
|
||||
// }
|
||||
|
||||
// ============================================================================
|
||||
// Tests
|
||||
// ============================================================================
|
||||
|
||||
describe('useHookName', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Initial State
|
||||
// --------------------------------------------------------------------------
|
||||
describe('Initial State', () => {
|
||||
it('should return initial state', () => {
|
||||
// const { result } = renderHook(() => useHookName())
|
||||
//
|
||||
// expect(result.current.value).toBe(initialValue)
|
||||
// expect(result.current.isLoading).toBe(false)
|
||||
})
|
||||
|
||||
it('should accept initial value from props', () => {
|
||||
// const { result } = renderHook(() => useHookName({ initialValue: 'custom' }))
|
||||
//
|
||||
// expect(result.current.value).toBe('custom')
|
||||
})
|
||||
})
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// State Updates
|
||||
// --------------------------------------------------------------------------
|
||||
describe('State Updates', () => {
|
||||
it('should update value when setValue is called', () => {
|
||||
// const { result } = renderHook(() => useHookName())
|
||||
//
|
||||
// act(() => {
|
||||
// result.current.setValue('new value')
|
||||
// })
|
||||
//
|
||||
// expect(result.current.value).toBe('new value')
|
||||
})
|
||||
|
||||
it('should reset to initial value', () => {
|
||||
// const { result } = renderHook(() => useHookName({ initialValue: 'initial' }))
|
||||
//
|
||||
// act(() => {
|
||||
// result.current.setValue('changed')
|
||||
// })
|
||||
// expect(result.current.value).toBe('changed')
|
||||
//
|
||||
// act(() => {
|
||||
// result.current.reset()
|
||||
// })
|
||||
// expect(result.current.value).toBe('initial')
|
||||
})
|
||||
})
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Async Operations
|
||||
// --------------------------------------------------------------------------
|
||||
describe('Async Operations', () => {
|
||||
it('should fetch data on mount', async () => {
|
||||
// mockedApi.fetchData.mockResolvedValue({ data: 'test' })
|
||||
//
|
||||
// const { result } = renderHook(() => useHookName())
|
||||
//
|
||||
// // Initially loading
|
||||
// expect(result.current.isLoading).toBe(true)
|
||||
//
|
||||
// // Wait for data
|
||||
// await waitFor(() => {
|
||||
// expect(result.current.isLoading).toBe(false)
|
||||
// })
|
||||
//
|
||||
// expect(result.current.data).toEqual({ data: 'test' })
|
||||
})
|
||||
|
||||
it('should handle fetch error', async () => {
|
||||
// mockedApi.fetchData.mockRejectedValue(new Error('Network error'))
|
||||
//
|
||||
// const { result } = renderHook(() => useHookName())
|
||||
//
|
||||
// await waitFor(() => {
|
||||
// expect(result.current.error).toBeTruthy()
|
||||
// })
|
||||
//
|
||||
// expect(result.current.error?.message).toBe('Network error')
|
||||
})
|
||||
|
||||
it('should refetch when dependency changes', async () => {
|
||||
// mockedApi.fetchData.mockResolvedValue({ data: 'test' })
|
||||
//
|
||||
// const { result, rerender } = renderHook(
|
||||
// ({ id }) => useHookName(id),
|
||||
// { initialProps: { id: '1' } }
|
||||
// )
|
||||
//
|
||||
// await waitFor(() => {
|
||||
// expect(mockedApi.fetchData).toHaveBeenCalledWith('1')
|
||||
// })
|
||||
//
|
||||
// rerender({ id: '2' })
|
||||
//
|
||||
// await waitFor(() => {
|
||||
// expect(mockedApi.fetchData).toHaveBeenCalledWith('2')
|
||||
// })
|
||||
})
|
||||
})
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Side Effects
|
||||
// --------------------------------------------------------------------------
|
||||
describe('Side Effects', () => {
|
||||
it('should call callback when value changes', () => {
|
||||
// const callback = vi.fn()
|
||||
// const { result } = renderHook(() => useHookName({ onChange: callback }))
|
||||
//
|
||||
// act(() => {
|
||||
// result.current.setValue('new value')
|
||||
// })
|
||||
//
|
||||
// expect(callback).toHaveBeenCalledWith('new value')
|
||||
})
|
||||
|
||||
it('should cleanup on unmount', () => {
|
||||
// const cleanup = vi.fn()
|
||||
// vi.spyOn(window, 'addEventListener')
|
||||
// vi.spyOn(window, 'removeEventListener')
|
||||
//
|
||||
// const { unmount } = renderHook(() => useHookName())
|
||||
//
|
||||
// expect(window.addEventListener).toHaveBeenCalled()
|
||||
//
|
||||
// unmount()
|
||||
//
|
||||
// expect(window.removeEventListener).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Edge Cases
|
||||
// --------------------------------------------------------------------------
|
||||
describe('Edge Cases', () => {
|
||||
it('should handle null input', () => {
|
||||
// const { result } = renderHook(() => useHookName(null))
|
||||
//
|
||||
// expect(result.current.value).toBeNull()
|
||||
})
|
||||
|
||||
it('should handle rapid updates', () => {
|
||||
// const { result } = renderHook(() => useHookName())
|
||||
//
|
||||
// act(() => {
|
||||
// result.current.setValue('1')
|
||||
// result.current.setValue('2')
|
||||
// result.current.setValue('3')
|
||||
// })
|
||||
//
|
||||
// expect(result.current.value).toBe('3')
|
||||
})
|
||||
})
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// With Context (if hook uses context)
|
||||
// --------------------------------------------------------------------------
|
||||
describe('With Context', () => {
|
||||
it('should use context value', () => {
|
||||
// const wrapper = createWrapper({ someValue: 'context-value' })
|
||||
// const { result } = renderHook(() => useHookName(), { wrapper })
|
||||
//
|
||||
// expect(result.current.contextValue).toBe('context-value')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
@ -0,0 +1,154 @@
|
|||
/**
|
||||
* Test Template for Utility Functions
|
||||
*
|
||||
* Instructions:
|
||||
* 1. Replace `utilityFunction` with your function name
|
||||
* 2. Update import path
|
||||
* 3. Use test.each for data-driven tests
|
||||
*/
|
||||
|
||||
// import { utilityFunction } from './utility'
|
||||
|
||||
// ============================================================================
|
||||
// Tests
|
||||
// ============================================================================
|
||||
|
||||
describe('utilityFunction', () => {
|
||||
// --------------------------------------------------------------------------
|
||||
// Basic Functionality
|
||||
// --------------------------------------------------------------------------
|
||||
describe('Basic Functionality', () => {
|
||||
it('should return expected result for valid input', () => {
|
||||
// expect(utilityFunction('input')).toBe('expected-output')
|
||||
})
|
||||
|
||||
it('should handle multiple arguments', () => {
|
||||
// expect(utilityFunction('a', 'b', 'c')).toBe('abc')
|
||||
})
|
||||
})
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Data-Driven Tests
|
||||
// --------------------------------------------------------------------------
|
||||
describe('Input/Output Mapping', () => {
|
||||
test.each([
|
||||
// [input, expected]
|
||||
['input1', 'output1'],
|
||||
['input2', 'output2'],
|
||||
['input3', 'output3'],
|
||||
])('should return %s for input %s', (input, expected) => {
|
||||
// expect(utilityFunction(input)).toBe(expected)
|
||||
})
|
||||
})
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Edge Cases
|
||||
// --------------------------------------------------------------------------
|
||||
describe('Edge Cases', () => {
|
||||
it('should handle empty string', () => {
|
||||
// expect(utilityFunction('')).toBe('')
|
||||
})
|
||||
|
||||
it('should handle null', () => {
|
||||
// expect(utilityFunction(null)).toBe(null)
|
||||
// or
|
||||
// expect(() => utilityFunction(null)).toThrow()
|
||||
})
|
||||
|
||||
it('should handle undefined', () => {
|
||||
// expect(utilityFunction(undefined)).toBe(undefined)
|
||||
// or
|
||||
// expect(() => utilityFunction(undefined)).toThrow()
|
||||
})
|
||||
|
||||
it('should handle empty array', () => {
|
||||
// expect(utilityFunction([])).toEqual([])
|
||||
})
|
||||
|
||||
it('should handle empty object', () => {
|
||||
// expect(utilityFunction({})).toEqual({})
|
||||
})
|
||||
})
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Boundary Conditions
|
||||
// --------------------------------------------------------------------------
|
||||
describe('Boundary Conditions', () => {
|
||||
it('should handle minimum value', () => {
|
||||
// expect(utilityFunction(0)).toBe(0)
|
||||
})
|
||||
|
||||
it('should handle maximum value', () => {
|
||||
// expect(utilityFunction(Number.MAX_SAFE_INTEGER)).toBe(...)
|
||||
})
|
||||
|
||||
it('should handle negative numbers', () => {
|
||||
// expect(utilityFunction(-1)).toBe(...)
|
||||
})
|
||||
})
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Type Coercion (if applicable)
|
||||
// --------------------------------------------------------------------------
|
||||
describe('Type Handling', () => {
|
||||
it('should handle numeric string', () => {
|
||||
// expect(utilityFunction('123')).toBe(123)
|
||||
})
|
||||
|
||||
it('should handle boolean', () => {
|
||||
// expect(utilityFunction(true)).toBe(...)
|
||||
})
|
||||
})
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Error Cases
|
||||
// --------------------------------------------------------------------------
|
||||
describe('Error Handling', () => {
|
||||
it('should throw for invalid input', () => {
|
||||
// expect(() => utilityFunction('invalid')).toThrow('Error message')
|
||||
})
|
||||
|
||||
it('should throw with specific error type', () => {
|
||||
// expect(() => utilityFunction('invalid')).toThrow(ValidationError)
|
||||
})
|
||||
})
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Complex Objects (if applicable)
|
||||
// --------------------------------------------------------------------------
|
||||
describe('Object Handling', () => {
|
||||
it('should preserve object structure', () => {
|
||||
// const input = { a: 1, b: 2 }
|
||||
// expect(utilityFunction(input)).toEqual({ a: 1, b: 2 })
|
||||
})
|
||||
|
||||
it('should handle nested objects', () => {
|
||||
// const input = { nested: { deep: 'value' } }
|
||||
// expect(utilityFunction(input)).toEqual({ nested: { deep: 'transformed' } })
|
||||
})
|
||||
|
||||
it('should not mutate input', () => {
|
||||
// const input = { a: 1 }
|
||||
// const inputCopy = { ...input }
|
||||
// utilityFunction(input)
|
||||
// expect(input).toEqual(inputCopy)
|
||||
})
|
||||
})
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Array Handling (if applicable)
|
||||
// --------------------------------------------------------------------------
|
||||
describe('Array Handling', () => {
|
||||
it('should process all elements', () => {
|
||||
// expect(utilityFunction([1, 2, 3])).toEqual([2, 4, 6])
|
||||
})
|
||||
|
||||
it('should handle single element array', () => {
|
||||
// expect(utilityFunction([1])).toEqual([2])
|
||||
})
|
||||
|
||||
it('should preserve order', () => {
|
||||
// expect(utilityFunction(['c', 'a', 'b'])).toEqual(['c', 'a', 'b'])
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
@ -0,0 +1,345 @@
|
|||
# Async Testing Guide
|
||||
|
||||
## Core Async Patterns
|
||||
|
||||
### 1. waitFor - Wait for Condition
|
||||
|
||||
```typescript
|
||||
import { render, screen, waitFor } from '@testing-library/react'
|
||||
|
||||
it('should load and display data', async () => {
|
||||
render(<DataComponent />)
|
||||
|
||||
// Wait for element to appear
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('Loaded Data')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should hide loading spinner after load', async () => {
|
||||
render(<DataComponent />)
|
||||
|
||||
// Wait for element to disappear
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByText('Loading...')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
### 2. findBy\* - Async Queries
|
||||
|
||||
```typescript
|
||||
it('should show user name after fetch', async () => {
|
||||
render(<UserProfile />)
|
||||
|
||||
// findBy returns a promise, auto-waits up to 1000ms
|
||||
const userName = await screen.findByText('John Doe')
|
||||
expect(userName).toBeInTheDocument()
|
||||
|
||||
// findByRole with options
|
||||
const button = await screen.findByRole('button', { name: /submit/i })
|
||||
expect(button).toBeEnabled()
|
||||
})
|
||||
```
|
||||
|
||||
### 3. userEvent for Async Interactions
|
||||
|
||||
```typescript
|
||||
import userEvent from '@testing-library/user-event'
|
||||
|
||||
it('should submit form', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onSubmit = vi.fn()
|
||||
|
||||
render(<Form onSubmit={onSubmit} />)
|
||||
|
||||
// userEvent methods are async
|
||||
await user.type(screen.getByLabelText('Email'), 'test@example.com')
|
||||
await user.click(screen.getByRole('button', { name: /submit/i }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(onSubmit).toHaveBeenCalledWith({ email: 'test@example.com' })
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
## Fake Timers
|
||||
|
||||
### When to Use Fake Timers
|
||||
|
||||
- Testing components with `setTimeout`/`setInterval`
|
||||
- Testing debounce/throttle behavior
|
||||
- Testing animations or delayed transitions
|
||||
- Testing polling or retry logic
|
||||
|
||||
### Basic Fake Timer Setup
|
||||
|
||||
```typescript
|
||||
describe('Debounced Search', () => {
|
||||
beforeEach(() => {
|
||||
vi.useFakeTimers()
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.useRealTimers()
|
||||
})
|
||||
|
||||
it('should debounce search input', async () => {
|
||||
const onSearch = vi.fn()
|
||||
render(<SearchInput onSearch={onSearch} debounceMs={300} />)
|
||||
|
||||
// Type in the input
|
||||
fireEvent.change(screen.getByRole('textbox'), { target: { value: 'query' } })
|
||||
|
||||
// Search not called immediately
|
||||
expect(onSearch).not.toHaveBeenCalled()
|
||||
|
||||
// Advance timers
|
||||
vi.advanceTimersByTime(300)
|
||||
|
||||
// Now search is called
|
||||
expect(onSearch).toHaveBeenCalledWith('query')
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
### Fake Timers with Async Code
|
||||
|
||||
```typescript
|
||||
it('should retry on failure', async () => {
|
||||
vi.useFakeTimers()
|
||||
const fetchData = vi.fn()
|
||||
.mockRejectedValueOnce(new Error('Network error'))
|
||||
.mockResolvedValueOnce({ data: 'success' })
|
||||
|
||||
render(<RetryComponent fetchData={fetchData} retryDelayMs={1000} />)
|
||||
|
||||
// First call fails
|
||||
await waitFor(() => {
|
||||
expect(fetchData).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
// Advance timer for retry
|
||||
vi.advanceTimersByTime(1000)
|
||||
|
||||
// Second call succeeds
|
||||
await waitFor(() => {
|
||||
expect(fetchData).toHaveBeenCalledTimes(2)
|
||||
expect(screen.getByText('success')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
vi.useRealTimers()
|
||||
})
|
||||
```
|
||||
|
||||
### Common Fake Timer Utilities
|
||||
|
||||
```typescript
|
||||
// Run all pending timers
|
||||
vi.runAllTimers()
|
||||
|
||||
// Run only pending timers (not new ones created during execution)
|
||||
vi.runOnlyPendingTimers()
|
||||
|
||||
// Advance by specific time
|
||||
vi.advanceTimersByTime(1000)
|
||||
|
||||
// Get current fake time
|
||||
Date.now()
|
||||
|
||||
// Clear all timers
|
||||
vi.clearAllTimers()
|
||||
```
|
||||
|
||||
## API Testing Patterns
|
||||
|
||||
### Loading → Success → Error States
|
||||
|
||||
```typescript
|
||||
describe('DataFetcher', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('should show loading state', () => {
|
||||
mockedApi.fetchData.mockImplementation(() => new Promise(() => {})) // Never resolves
|
||||
|
||||
render(<DataFetcher />)
|
||||
|
||||
expect(screen.getByTestId('loading-spinner')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show data on success', async () => {
|
||||
mockedApi.fetchData.mockResolvedValue({ items: ['Item 1', 'Item 2'] })
|
||||
|
||||
render(<DataFetcher />)
|
||||
|
||||
// Use findBy* for multiple async elements (better error messages than waitFor with multiple assertions)
|
||||
const item1 = await screen.findByText('Item 1')
|
||||
const item2 = await screen.findByText('Item 2')
|
||||
expect(item1).toBeInTheDocument()
|
||||
expect(item2).toBeInTheDocument()
|
||||
|
||||
expect(screen.queryByTestId('loading-spinner')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show error on failure', async () => {
|
||||
mockedApi.fetchData.mockRejectedValue(new Error('Failed to fetch'))
|
||||
|
||||
render(<DataFetcher />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText(/failed to fetch/i)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should retry on error', async () => {
|
||||
mockedApi.fetchData.mockRejectedValue(new Error('Network error'))
|
||||
|
||||
render(<DataFetcher />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByRole('button', { name: /retry/i })).toBeInTheDocument()
|
||||
})
|
||||
|
||||
mockedApi.fetchData.mockResolvedValue({ items: ['Item 1'] })
|
||||
fireEvent.click(screen.getByRole('button', { name: /retry/i }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('Item 1')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
### Testing Mutations
|
||||
|
||||
```typescript
|
||||
it('should submit form and show success', async () => {
|
||||
const user = userEvent.setup()
|
||||
mockedApi.createItem.mockResolvedValue({ id: '1', name: 'New Item' })
|
||||
|
||||
render(<CreateItemForm />)
|
||||
|
||||
await user.type(screen.getByLabelText('Name'), 'New Item')
|
||||
await user.click(screen.getByRole('button', { name: /create/i }))
|
||||
|
||||
// Button should be disabled during submission
|
||||
expect(screen.getByRole('button', { name: /creating/i })).toBeDisabled()
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText(/created successfully/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
expect(mockedApi.createItem).toHaveBeenCalledWith({ name: 'New Item' })
|
||||
})
|
||||
```
|
||||
|
||||
## useEffect Testing
|
||||
|
||||
### Testing Effect Execution
|
||||
|
||||
```typescript
|
||||
it('should fetch data on mount', async () => {
|
||||
const fetchData = vi.fn().mockResolvedValue({ data: 'test' })
|
||||
|
||||
render(<ComponentWithEffect fetchData={fetchData} />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(fetchData).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
### Testing Effect Dependencies
|
||||
|
||||
```typescript
|
||||
it('should refetch when id changes', async () => {
|
||||
const fetchData = vi.fn().mockResolvedValue({ data: 'test' })
|
||||
|
||||
const { rerender } = render(<ComponentWithEffect id="1" fetchData={fetchData} />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(fetchData).toHaveBeenCalledWith('1')
|
||||
})
|
||||
|
||||
rerender(<ComponentWithEffect id="2" fetchData={fetchData} />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(fetchData).toHaveBeenCalledWith('2')
|
||||
expect(fetchData).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
### Testing Effect Cleanup
|
||||
|
||||
```typescript
|
||||
it('should cleanup subscription on unmount', () => {
|
||||
const subscribe = vi.fn()
|
||||
const unsubscribe = vi.fn()
|
||||
subscribe.mockReturnValue(unsubscribe)
|
||||
|
||||
const { unmount } = render(<SubscriptionComponent subscribe={subscribe} />)
|
||||
|
||||
expect(subscribe).toHaveBeenCalledTimes(1)
|
||||
|
||||
unmount()
|
||||
|
||||
expect(unsubscribe).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
```
|
||||
|
||||
## Common Async Pitfalls
|
||||
|
||||
### ❌ Don't: Forget to await
|
||||
|
||||
```typescript
|
||||
// Bad - test may pass even if assertion fails
|
||||
it('should load data', () => {
|
||||
render(<Component />)
|
||||
waitFor(() => {
|
||||
expect(screen.getByText('Data')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
// Good - properly awaited
|
||||
it('should load data', async () => {
|
||||
render(<Component />)
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('Data')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
### ❌ Don't: Use multiple assertions in single waitFor
|
||||
|
||||
```typescript
|
||||
// Bad - if first assertion fails, won't know about second
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('Title')).toBeInTheDocument()
|
||||
expect(screen.getByText('Description')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
// Good - separate waitFor or use findBy
|
||||
const title = await screen.findByText('Title')
|
||||
const description = await screen.findByText('Description')
|
||||
expect(title).toBeInTheDocument()
|
||||
expect(description).toBeInTheDocument()
|
||||
```
|
||||
|
||||
### ❌ Don't: Mix fake timers with real async
|
||||
|
||||
```typescript
|
||||
// Bad - fake timers don't work well with real Promises
|
||||
vi.useFakeTimers()
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('Data')).toBeInTheDocument()
|
||||
}) // May timeout!
|
||||
|
||||
// Good - use runAllTimers or advanceTimersByTime
|
||||
vi.useFakeTimers()
|
||||
render(<Component />)
|
||||
vi.runAllTimers()
|
||||
expect(screen.getByText('Data')).toBeInTheDocument()
|
||||
```
|
||||
|
|
@ -0,0 +1,205 @@
|
|||
# Test Generation Checklist
|
||||
|
||||
Use this checklist when generating or reviewing tests for Dify frontend components.
|
||||
|
||||
## Pre-Generation
|
||||
|
||||
- [ ] Read the component source code completely
|
||||
- [ ] Identify component type (component, hook, utility, page)
|
||||
- [ ] Run `pnpm analyze-component <path>` if available
|
||||
- [ ] Note complexity score and features detected
|
||||
- [ ] Check for existing tests in the same directory
|
||||
- [ ] **Identify ALL files in the directory** that need testing (not just index)
|
||||
|
||||
## Testing Strategy
|
||||
|
||||
### ⚠️ Incremental Workflow (CRITICAL for Multi-File)
|
||||
|
||||
- [ ] **NEVER generate all tests at once** - process one file at a time
|
||||
- [ ] Order files by complexity: utilities → hooks → simple → complex → integration
|
||||
- [ ] Create a todo list to track progress before starting
|
||||
- [ ] For EACH file: write → run test → verify pass → then next
|
||||
- [ ] **DO NOT proceed** to next file until current one passes
|
||||
|
||||
### Path-Level Coverage
|
||||
|
||||
- [ ] **Test ALL files** in the assigned directory/path
|
||||
- [ ] List all components, hooks, utilities that need coverage
|
||||
- [ ] Decide: single spec file (integration) or multiple spec files (unit)
|
||||
|
||||
### Complexity Assessment
|
||||
|
||||
- [ ] Run `pnpm analyze-component <path>` for complexity score
|
||||
- [ ] **Complexity > 50**: Consider refactoring before testing
|
||||
- [ ] **500+ lines**: Consider splitting before testing
|
||||
- [ ] **30-50 complexity**: Use multiple describe blocks, organized structure
|
||||
|
||||
### Integration vs Mocking
|
||||
|
||||
- [ ] **DO NOT mock base components** (`Loading`, `Button`, `Tooltip`, etc.)
|
||||
- [ ] Import real project components instead of mocking
|
||||
- [ ] Only mock: API calls, complex context providers, third-party libs with side effects
|
||||
- [ ] Prefer integration testing when using single spec file
|
||||
|
||||
## Required Test Sections
|
||||
|
||||
### All Components MUST Have
|
||||
|
||||
- [ ] **Rendering tests** - Component renders without crashing
|
||||
- [ ] **Props tests** - Required props, optional props, default values
|
||||
- [ ] **Edge cases** - null, undefined, empty values, boundaries
|
||||
|
||||
### Conditional Sections (Add When Feature Present)
|
||||
|
||||
| Feature | Add Tests For |
|
||||
|---------|---------------|
|
||||
| `useState` | Initial state, transitions, cleanup |
|
||||
| `useEffect` | Execution, dependencies, cleanup |
|
||||
| Event handlers | onClick, onChange, onSubmit, keyboard |
|
||||
| API calls | Loading, success, error states |
|
||||
| Routing | Navigation, params, query strings |
|
||||
| `useCallback`/`useMemo` | Referential equality |
|
||||
| Context | Provider values, consumer behavior |
|
||||
| Forms | Validation, submission, error display |
|
||||
|
||||
## Code Quality Checklist
|
||||
|
||||
### Structure
|
||||
|
||||
- [ ] Uses `describe` blocks to group related tests
|
||||
- [ ] Test names follow `should <behavior> when <condition>` pattern
|
||||
- [ ] AAA pattern (Arrange-Act-Assert) is clear
|
||||
- [ ] Comments explain complex test scenarios
|
||||
|
||||
### Mocks
|
||||
|
||||
- [ ] **DO NOT mock base components** (`@/app/components/base/*`)
|
||||
- [ ] `vi.clearAllMocks()` in `beforeEach` (not `afterEach`)
|
||||
- [ ] Shared mock state reset in `beforeEach`
|
||||
- [ ] i18n uses global mock (auto-loaded in `web/vitest.setup.ts`); only override locally for custom translations
|
||||
- [ ] Router mocks match actual Next.js API
|
||||
- [ ] Mocks reflect actual component conditional behavior
|
||||
- [ ] Only mock: API services, complex context providers, third-party libs
|
||||
|
||||
### Queries
|
||||
|
||||
- [ ] Prefer semantic queries (`getByRole`, `getByLabelText`)
|
||||
- [ ] Use `queryBy*` for absence assertions
|
||||
- [ ] Use `findBy*` for async elements
|
||||
- [ ] `getByTestId` only as last resort
|
||||
|
||||
### Async
|
||||
|
||||
- [ ] All async tests use `async/await`
|
||||
- [ ] `waitFor` wraps async assertions
|
||||
- [ ] Fake timers properly setup/teardown
|
||||
- [ ] No floating promises
|
||||
|
||||
### TypeScript
|
||||
|
||||
- [ ] No `any` types without justification
|
||||
- [ ] Mock data uses actual types from source
|
||||
- [ ] Factory functions have proper return types
|
||||
|
||||
## Coverage Goals (Per File)
|
||||
|
||||
For the current file being tested:
|
||||
|
||||
- [ ] 100% function coverage
|
||||
- [ ] 100% statement coverage
|
||||
- [ ] >95% branch coverage
|
||||
- [ ] >95% line coverage
|
||||
|
||||
## Post-Generation (Per File)
|
||||
|
||||
**Run these checks after EACH test file, not just at the end:**
|
||||
|
||||
- [ ] Run `pnpm test -- path/to/file.spec.tsx` - **MUST PASS before next file**
|
||||
- [ ] Fix any failures immediately
|
||||
- [ ] Mark file as complete in todo list
|
||||
- [ ] Only then proceed to next file
|
||||
|
||||
### After All Files Complete
|
||||
|
||||
- [ ] Run full directory test: `pnpm test -- path/to/directory/`
|
||||
- [ ] Check coverage report: `pnpm test -- --coverage`
|
||||
- [ ] Run `pnpm lint:fix` on all test files
|
||||
- [ ] Run `pnpm type-check:tsgo`
|
||||
|
||||
## Common Issues to Watch
|
||||
|
||||
### False Positives
|
||||
|
||||
```typescript
|
||||
// ❌ Mock doesn't match actual behavior
|
||||
vi.mock('./Component', () => () => <div>Mocked</div>)
|
||||
|
||||
// ✅ Mock matches actual conditional logic
|
||||
vi.mock('./Component', () => ({ isOpen }: any) =>
|
||||
isOpen ? <div>Content</div> : null
|
||||
)
|
||||
```
|
||||
|
||||
### State Leakage
|
||||
|
||||
```typescript
|
||||
// ❌ Shared state not reset
|
||||
let mockState = false
|
||||
vi.mock('./useHook', () => () => mockState)
|
||||
|
||||
// ✅ Reset in beforeEach
|
||||
beforeEach(() => {
|
||||
mockState = false
|
||||
})
|
||||
```
|
||||
|
||||
### Async Race Conditions
|
||||
|
||||
```typescript
|
||||
// ❌ Not awaited
|
||||
it('loads data', () => {
|
||||
render(<Component />)
|
||||
expect(screen.getByText('Data')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
// ✅ Properly awaited
|
||||
it('loads data', async () => {
|
||||
render(<Component />)
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('Data')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
### Missing Edge Cases
|
||||
|
||||
Always test these scenarios:
|
||||
|
||||
- `null` / `undefined` inputs
|
||||
- Empty strings / arrays / objects
|
||||
- Boundary values (0, -1, MAX_INT)
|
||||
- Error states
|
||||
- Loading states
|
||||
- Disabled states
|
||||
|
||||
## Quick Commands
|
||||
|
||||
```bash
|
||||
# Run specific test
|
||||
pnpm test -- path/to/file.spec.tsx
|
||||
|
||||
# Run with coverage
|
||||
pnpm test -- --coverage path/to/file.spec.tsx
|
||||
|
||||
# Watch mode
|
||||
pnpm test:watch -- path/to/file.spec.tsx
|
||||
|
||||
# Update snapshots (use sparingly)
|
||||
pnpm test -- -u path/to/file.spec.tsx
|
||||
|
||||
# Analyze component
|
||||
pnpm analyze-component path/to/component.tsx
|
||||
|
||||
# Review existing test
|
||||
pnpm analyze-component path/to/component.tsx --review
|
||||
```
|
||||
|
|
@ -0,0 +1,449 @@
|
|||
# Common Testing Patterns
|
||||
|
||||
## Query Priority
|
||||
|
||||
Use queries in this order (most to least preferred):
|
||||
|
||||
```typescript
|
||||
// 1. getByRole - Most recommended (accessibility)
|
||||
screen.getByRole('button', { name: /submit/i })
|
||||
screen.getByRole('textbox', { name: /email/i })
|
||||
screen.getByRole('heading', { level: 1 })
|
||||
|
||||
// 2. getByLabelText - Form fields
|
||||
screen.getByLabelText('Email address')
|
||||
screen.getByLabelText(/password/i)
|
||||
|
||||
// 3. getByPlaceholderText - When no label
|
||||
screen.getByPlaceholderText('Search...')
|
||||
|
||||
// 4. getByText - Non-interactive elements
|
||||
screen.getByText('Welcome to Dify')
|
||||
screen.getByText(/loading/i)
|
||||
|
||||
// 5. getByDisplayValue - Current input value
|
||||
screen.getByDisplayValue('current value')
|
||||
|
||||
// 6. getByAltText - Images
|
||||
screen.getByAltText('Company logo')
|
||||
|
||||
// 7. getByTitle - Tooltip elements
|
||||
screen.getByTitle('Close')
|
||||
|
||||
// 8. getByTestId - Last resort only!
|
||||
screen.getByTestId('custom-element')
|
||||
```
|
||||
|
||||
## Event Handling Patterns
|
||||
|
||||
### Click Events
|
||||
|
||||
```typescript
|
||||
// Basic click
|
||||
fireEvent.click(screen.getByRole('button'))
|
||||
|
||||
// With userEvent (preferred for realistic interaction)
|
||||
const user = userEvent.setup()
|
||||
await user.click(screen.getByRole('button'))
|
||||
|
||||
// Double click
|
||||
await user.dblClick(screen.getByRole('button'))
|
||||
|
||||
// Right click
|
||||
await user.pointer({ keys: '[MouseRight]', target: screen.getByRole('button') })
|
||||
```
|
||||
|
||||
### Form Input
|
||||
|
||||
```typescript
|
||||
const user = userEvent.setup()
|
||||
|
||||
// Type in input
|
||||
await user.type(screen.getByRole('textbox'), 'Hello World')
|
||||
|
||||
// Clear and type
|
||||
await user.clear(screen.getByRole('textbox'))
|
||||
await user.type(screen.getByRole('textbox'), 'New value')
|
||||
|
||||
// Select option
|
||||
await user.selectOptions(screen.getByRole('combobox'), 'option-value')
|
||||
|
||||
// Check checkbox
|
||||
await user.click(screen.getByRole('checkbox'))
|
||||
|
||||
// Upload file
|
||||
const file = new File(['content'], 'test.pdf', { type: 'application/pdf' })
|
||||
await user.upload(screen.getByLabelText(/upload/i), file)
|
||||
```
|
||||
|
||||
### Keyboard Events
|
||||
|
||||
```typescript
|
||||
const user = userEvent.setup()
|
||||
|
||||
// Press Enter
|
||||
await user.keyboard('{Enter}')
|
||||
|
||||
// Press Escape
|
||||
await user.keyboard('{Escape}')
|
||||
|
||||
// Keyboard shortcut
|
||||
await user.keyboard('{Control>}a{/Control}') // Ctrl+A
|
||||
|
||||
// Tab navigation
|
||||
await user.tab()
|
||||
|
||||
// Arrow keys
|
||||
await user.keyboard('{ArrowDown}')
|
||||
await user.keyboard('{ArrowUp}')
|
||||
```
|
||||
|
||||
## Component State Testing
|
||||
|
||||
### Testing State Transitions
|
||||
|
||||
```typescript
|
||||
describe('Counter', () => {
|
||||
it('should increment count', async () => {
|
||||
const user = userEvent.setup()
|
||||
render(<Counter initialCount={0} />)
|
||||
|
||||
// Initial state
|
||||
expect(screen.getByText('Count: 0')).toBeInTheDocument()
|
||||
|
||||
// Trigger transition
|
||||
await user.click(screen.getByRole('button', { name: /increment/i }))
|
||||
|
||||
// New state
|
||||
expect(screen.getByText('Count: 1')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
### Testing Controlled Components
|
||||
|
||||
```typescript
|
||||
describe('ControlledInput', () => {
|
||||
it('should call onChange with new value', async () => {
|
||||
const user = userEvent.setup()
|
||||
const handleChange = vi.fn()
|
||||
|
||||
render(<ControlledInput value="" onChange={handleChange} />)
|
||||
|
||||
await user.type(screen.getByRole('textbox'), 'a')
|
||||
|
||||
expect(handleChange).toHaveBeenCalledWith('a')
|
||||
})
|
||||
|
||||
it('should display controlled value', () => {
|
||||
render(<ControlledInput value="controlled" onChange={vi.fn()} />)
|
||||
|
||||
expect(screen.getByRole('textbox')).toHaveValue('controlled')
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
## Conditional Rendering Testing
|
||||
|
||||
```typescript
|
||||
describe('ConditionalComponent', () => {
|
||||
it('should show loading state', () => {
|
||||
render(<DataDisplay isLoading={true} data={null} />)
|
||||
|
||||
expect(screen.getByText(/loading/i)).toBeInTheDocument()
|
||||
expect(screen.queryByTestId('data-content')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show error state', () => {
|
||||
render(<DataDisplay isLoading={false} data={null} error="Failed to load" />)
|
||||
|
||||
expect(screen.getByText(/failed to load/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show data when loaded', () => {
|
||||
render(<DataDisplay isLoading={false} data={{ name: 'Test' }} />)
|
||||
|
||||
expect(screen.getByText('Test')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show empty state when no data', () => {
|
||||
render(<DataDisplay isLoading={false} data={[]} />)
|
||||
|
||||
expect(screen.getByText(/no data/i)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
## List Rendering Testing
|
||||
|
||||
```typescript
|
||||
describe('ItemList', () => {
|
||||
const items = [
|
||||
{ id: '1', name: 'Item 1' },
|
||||
{ id: '2', name: 'Item 2' },
|
||||
{ id: '3', name: 'Item 3' },
|
||||
]
|
||||
|
||||
it('should render all items', () => {
|
||||
render(<ItemList items={items} />)
|
||||
|
||||
expect(screen.getAllByRole('listitem')).toHaveLength(3)
|
||||
items.forEach(item => {
|
||||
expect(screen.getByText(item.name)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle item selection', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onSelect = vi.fn()
|
||||
|
||||
render(<ItemList items={items} onSelect={onSelect} />)
|
||||
|
||||
await user.click(screen.getByText('Item 2'))
|
||||
|
||||
expect(onSelect).toHaveBeenCalledWith(items[1])
|
||||
})
|
||||
|
||||
it('should handle empty list', () => {
|
||||
render(<ItemList items={[]} />)
|
||||
|
||||
expect(screen.getByText(/no items/i)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
## Modal/Dialog Testing
|
||||
|
||||
```typescript
|
||||
describe('Modal', () => {
|
||||
it('should not render when closed', () => {
|
||||
render(<Modal isOpen={false} onClose={vi.fn()} />)
|
||||
|
||||
expect(screen.queryByRole('dialog')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render when open', () => {
|
||||
render(<Modal isOpen={true} onClose={vi.fn()} />)
|
||||
|
||||
expect(screen.getByRole('dialog')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should call onClose when clicking overlay', async () => {
|
||||
const user = userEvent.setup()
|
||||
const handleClose = vi.fn()
|
||||
|
||||
render(<Modal isOpen={true} onClose={handleClose} />)
|
||||
|
||||
await user.click(screen.getByTestId('modal-overlay'))
|
||||
|
||||
expect(handleClose).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should call onClose when pressing Escape', async () => {
|
||||
const user = userEvent.setup()
|
||||
const handleClose = vi.fn()
|
||||
|
||||
render(<Modal isOpen={true} onClose={handleClose} />)
|
||||
|
||||
await user.keyboard('{Escape}')
|
||||
|
||||
expect(handleClose).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should trap focus inside modal', async () => {
|
||||
const user = userEvent.setup()
|
||||
|
||||
render(
|
||||
<Modal isOpen={true} onClose={vi.fn()}>
|
||||
<button>First</button>
|
||||
<button>Second</button>
|
||||
</Modal>
|
||||
)
|
||||
|
||||
// Focus should cycle within modal
|
||||
await user.tab()
|
||||
expect(screen.getByText('First')).toHaveFocus()
|
||||
|
||||
await user.tab()
|
||||
expect(screen.getByText('Second')).toHaveFocus()
|
||||
|
||||
await user.tab()
|
||||
expect(screen.getByText('First')).toHaveFocus() // Cycles back
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
## Form Testing
|
||||
|
||||
```typescript
|
||||
describe('LoginForm', () => {
|
||||
it('should submit valid form', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onSubmit = vi.fn()
|
||||
|
||||
render(<LoginForm onSubmit={onSubmit} />)
|
||||
|
||||
await user.type(screen.getByLabelText(/email/i), 'test@example.com')
|
||||
await user.type(screen.getByLabelText(/password/i), 'password123')
|
||||
await user.click(screen.getByRole('button', { name: /sign in/i }))
|
||||
|
||||
expect(onSubmit).toHaveBeenCalledWith({
|
||||
email: 'test@example.com',
|
||||
password: 'password123',
|
||||
})
|
||||
})
|
||||
|
||||
it('should show validation errors', async () => {
|
||||
const user = userEvent.setup()
|
||||
|
||||
render(<LoginForm onSubmit={vi.fn()} />)
|
||||
|
||||
// Submit empty form
|
||||
await user.click(screen.getByRole('button', { name: /sign in/i }))
|
||||
|
||||
expect(screen.getByText(/email is required/i)).toBeInTheDocument()
|
||||
expect(screen.getByText(/password is required/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should validate email format', async () => {
|
||||
const user = userEvent.setup()
|
||||
|
||||
render(<LoginForm onSubmit={vi.fn()} />)
|
||||
|
||||
await user.type(screen.getByLabelText(/email/i), 'invalid-email')
|
||||
await user.click(screen.getByRole('button', { name: /sign in/i }))
|
||||
|
||||
expect(screen.getByText(/invalid email/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should disable submit button while submitting', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onSubmit = vi.fn(() => new Promise(resolve => setTimeout(resolve, 100)))
|
||||
|
||||
render(<LoginForm onSubmit={onSubmit} />)
|
||||
|
||||
await user.type(screen.getByLabelText(/email/i), 'test@example.com')
|
||||
await user.type(screen.getByLabelText(/password/i), 'password123')
|
||||
await user.click(screen.getByRole('button', { name: /sign in/i }))
|
||||
|
||||
expect(screen.getByRole('button', { name: /signing in/i })).toBeDisabled()
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByRole('button', { name: /sign in/i })).toBeEnabled()
|
||||
})
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
## Data-Driven Tests with test.each
|
||||
|
||||
```typescript
|
||||
describe('StatusBadge', () => {
|
||||
test.each([
|
||||
['success', 'bg-green-500'],
|
||||
['warning', 'bg-yellow-500'],
|
||||
['error', 'bg-red-500'],
|
||||
['info', 'bg-blue-500'],
|
||||
])('should apply correct class for %s status', (status, expectedClass) => {
|
||||
render(<StatusBadge status={status} />)
|
||||
|
||||
expect(screen.getByTestId('status-badge')).toHaveClass(expectedClass)
|
||||
})
|
||||
|
||||
test.each([
|
||||
{ input: null, expected: 'Unknown' },
|
||||
{ input: undefined, expected: 'Unknown' },
|
||||
{ input: '', expected: 'Unknown' },
|
||||
{ input: 'invalid', expected: 'Unknown' },
|
||||
])('should show "Unknown" for invalid input: $input', ({ input, expected }) => {
|
||||
render(<StatusBadge status={input} />)
|
||||
|
||||
expect(screen.getByText(expected)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
## Debugging Tips
|
||||
|
||||
```typescript
|
||||
// Print entire DOM
|
||||
screen.debug()
|
||||
|
||||
// Print specific element
|
||||
screen.debug(screen.getByRole('button'))
|
||||
|
||||
// Log testing playground URL
|
||||
screen.logTestingPlaygroundURL()
|
||||
|
||||
// Pretty print DOM
|
||||
import { prettyDOM } from '@testing-library/react'
|
||||
console.log(prettyDOM(screen.getByRole('dialog')))
|
||||
|
||||
// Check available roles
|
||||
import { getRoles } from '@testing-library/react'
|
||||
console.log(getRoles(container))
|
||||
```
|
||||
|
||||
## Common Mistakes to Avoid
|
||||
|
||||
### ❌ Don't Use Implementation Details
|
||||
|
||||
```typescript
|
||||
// Bad - testing implementation
|
||||
expect(component.state.isOpen).toBe(true)
|
||||
expect(wrapper.find('.internal-class').length).toBe(1)
|
||||
|
||||
// Good - testing behavior
|
||||
expect(screen.getByRole('dialog')).toBeInTheDocument()
|
||||
```
|
||||
|
||||
### ❌ Don't Forget Cleanup
|
||||
|
||||
```typescript
|
||||
// Bad - may leak state between tests
|
||||
it('test 1', () => {
|
||||
render(<Component />)
|
||||
})
|
||||
|
||||
// Good - cleanup is automatic with RTL, but reset mocks
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
```
|
||||
|
||||
### ❌ Don't Use Exact String Matching (Prefer Black-Box Assertions)
|
||||
|
||||
```typescript
|
||||
// ❌ Bad - hardcoded strings are brittle
|
||||
expect(screen.getByText('Submit Form')).toBeInTheDocument()
|
||||
expect(screen.getByText('Loading...')).toBeInTheDocument()
|
||||
|
||||
// ✅ Good - role-based queries (most semantic)
|
||||
expect(screen.getByRole('button', { name: /submit/i })).toBeInTheDocument()
|
||||
expect(screen.getByRole('status')).toBeInTheDocument()
|
||||
|
||||
// ✅ Good - pattern matching (flexible)
|
||||
expect(screen.getByText(/submit/i)).toBeInTheDocument()
|
||||
expect(screen.getByText(/loading/i)).toBeInTheDocument()
|
||||
|
||||
// ✅ Good - test behavior, not exact UI text
|
||||
expect(screen.getByRole('button')).toBeDisabled()
|
||||
expect(screen.getByRole('alert')).toBeInTheDocument()
|
||||
```
|
||||
|
||||
**Why prefer black-box assertions?**
|
||||
|
||||
- Text content may change (i18n, copy updates)
|
||||
- Role-based queries test accessibility
|
||||
- Pattern matching is resilient to minor changes
|
||||
- Tests focus on behavior, not implementation details
|
||||
|
||||
### ❌ Don't Assert on Absence Without Query
|
||||
|
||||
```typescript
|
||||
// Bad - throws if not found
|
||||
expect(screen.getByText('Error')).not.toBeInTheDocument() // Error!
|
||||
|
||||
// Good - use queryBy for absence assertions
|
||||
expect(screen.queryByText('Error')).not.toBeInTheDocument()
|
||||
```
|
||||
|
|
@ -0,0 +1,523 @@
|
|||
# Domain-Specific Component Testing
|
||||
|
||||
This guide covers testing patterns for Dify's domain-specific components.
|
||||
|
||||
## Workflow Components (`workflow/`)
|
||||
|
||||
Workflow components handle node configuration, data flow, and graph operations.
|
||||
|
||||
### Key Test Areas
|
||||
|
||||
1. **Node Configuration**
|
||||
1. **Data Validation**
|
||||
1. **Variable Passing**
|
||||
1. **Edge Connections**
|
||||
1. **Error Handling**
|
||||
|
||||
### Example: Node Configuration Panel
|
||||
|
||||
```typescript
|
||||
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import NodeConfigPanel from './node-config-panel'
|
||||
import { createMockNode, createMockWorkflowContext } from '@/__mocks__/workflow'
|
||||
|
||||
// Mock workflow context
|
||||
vi.mock('@/app/components/workflow/hooks', () => ({
|
||||
useWorkflowStore: () => mockWorkflowStore,
|
||||
useNodesInteractions: () => mockNodesInteractions,
|
||||
}))
|
||||
|
||||
let mockWorkflowStore = {
|
||||
nodes: [],
|
||||
edges: [],
|
||||
updateNode: vi.fn(),
|
||||
}
|
||||
|
||||
let mockNodesInteractions = {
|
||||
handleNodeSelect: vi.fn(),
|
||||
handleNodeDelete: vi.fn(),
|
||||
}
|
||||
|
||||
describe('NodeConfigPanel', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockWorkflowStore = {
|
||||
nodes: [],
|
||||
edges: [],
|
||||
updateNode: vi.fn(),
|
||||
}
|
||||
})
|
||||
|
||||
describe('Node Configuration', () => {
|
||||
it('should render node type selector', () => {
|
||||
const node = createMockNode({ type: 'llm' })
|
||||
render(<NodeConfigPanel node={node} />)
|
||||
|
||||
expect(screen.getByLabelText(/model/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should update node config on change', async () => {
|
||||
const user = userEvent.setup()
|
||||
const node = createMockNode({ type: 'llm' })
|
||||
|
||||
render(<NodeConfigPanel node={node} />)
|
||||
|
||||
await user.selectOptions(screen.getByLabelText(/model/i), 'gpt-4')
|
||||
|
||||
expect(mockWorkflowStore.updateNode).toHaveBeenCalledWith(
|
||||
node.id,
|
||||
expect.objectContaining({ model: 'gpt-4' })
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Data Validation', () => {
|
||||
it('should show error for invalid input', async () => {
|
||||
const user = userEvent.setup()
|
||||
const node = createMockNode({ type: 'code' })
|
||||
|
||||
render(<NodeConfigPanel node={node} />)
|
||||
|
||||
// Enter invalid code
|
||||
const codeInput = screen.getByLabelText(/code/i)
|
||||
await user.clear(codeInput)
|
||||
await user.type(codeInput, 'invalid syntax {{{')
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText(/syntax error/i)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should validate required fields', async () => {
|
||||
const node = createMockNode({ type: 'http', data: { url: '' } })
|
||||
|
||||
render(<NodeConfigPanel node={node} />)
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: /save/i }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText(/url is required/i)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Variable Passing', () => {
|
||||
it('should display available variables from upstream nodes', () => {
|
||||
const upstreamNode = createMockNode({
|
||||
id: 'node-1',
|
||||
type: 'start',
|
||||
data: { outputs: [{ name: 'user_input', type: 'string' }] },
|
||||
})
|
||||
const currentNode = createMockNode({
|
||||
id: 'node-2',
|
||||
type: 'llm',
|
||||
})
|
||||
|
||||
mockWorkflowStore.nodes = [upstreamNode, currentNode]
|
||||
mockWorkflowStore.edges = [{ source: 'node-1', target: 'node-2' }]
|
||||
|
||||
render(<NodeConfigPanel node={currentNode} />)
|
||||
|
||||
// Variable selector should show upstream variables
|
||||
fireEvent.click(screen.getByRole('button', { name: /add variable/i }))
|
||||
|
||||
expect(screen.getByText('user_input')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should insert variable into prompt template', async () => {
|
||||
const user = userEvent.setup()
|
||||
const node = createMockNode({ type: 'llm' })
|
||||
|
||||
render(<NodeConfigPanel node={node} />)
|
||||
|
||||
// Click variable button
|
||||
await user.click(screen.getByRole('button', { name: /insert variable/i }))
|
||||
await user.click(screen.getByText('user_input'))
|
||||
|
||||
const promptInput = screen.getByLabelText(/prompt/i)
|
||||
expect(promptInput).toHaveValue(expect.stringContaining('{{user_input}}'))
|
||||
})
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
## Dataset Components (`dataset/`)
|
||||
|
||||
Dataset components handle file uploads, data display, and search/filter operations.
|
||||
|
||||
### Key Test Areas
|
||||
|
||||
1. **File Upload**
|
||||
1. **File Type Validation**
|
||||
1. **Pagination**
|
||||
1. **Search & Filtering**
|
||||
1. **Data Format Handling**
|
||||
|
||||
### Example: Document Uploader
|
||||
|
||||
```typescript
|
||||
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import DocumentUploader from './document-uploader'
|
||||
|
||||
vi.mock('@/service/datasets', () => ({
|
||||
uploadDocument: vi.fn(),
|
||||
parseDocument: vi.fn(),
|
||||
}))
|
||||
|
||||
import * as datasetService from '@/service/datasets'
|
||||
const mockedService = vi.mocked(datasetService)
|
||||
|
||||
describe('DocumentUploader', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('File Upload', () => {
|
||||
it('should accept valid file types', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onUpload = vi.fn()
|
||||
mockedService.uploadDocument.mockResolvedValue({ id: 'doc-1' })
|
||||
|
||||
render(<DocumentUploader onUpload={onUpload} />)
|
||||
|
||||
const file = new File(['content'], 'test.pdf', { type: 'application/pdf' })
|
||||
const input = screen.getByLabelText(/upload/i)
|
||||
|
||||
await user.upload(input, file)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockedService.uploadDocument).toHaveBeenCalledWith(
|
||||
expect.any(FormData)
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
it('should reject invalid file types', async () => {
|
||||
const user = userEvent.setup()
|
||||
|
||||
render(<DocumentUploader />)
|
||||
|
||||
const file = new File(['content'], 'test.exe', { type: 'application/x-msdownload' })
|
||||
const input = screen.getByLabelText(/upload/i)
|
||||
|
||||
await user.upload(input, file)
|
||||
|
||||
expect(screen.getByText(/unsupported file type/i)).toBeInTheDocument()
|
||||
expect(mockedService.uploadDocument).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should show upload progress', async () => {
|
||||
const user = userEvent.setup()
|
||||
|
||||
// Mock upload with progress
|
||||
mockedService.uploadDocument.mockImplementation(() => {
|
||||
return new Promise((resolve) => {
|
||||
setTimeout(() => resolve({ id: 'doc-1' }), 100)
|
||||
})
|
||||
})
|
||||
|
||||
render(<DocumentUploader />)
|
||||
|
||||
const file = new File(['content'], 'test.pdf', { type: 'application/pdf' })
|
||||
await user.upload(screen.getByLabelText(/upload/i), file)
|
||||
|
||||
expect(screen.getByRole('progressbar')).toBeInTheDocument()
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByRole('progressbar')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Error Handling', () => {
|
||||
it('should handle upload failure', async () => {
|
||||
const user = userEvent.setup()
|
||||
mockedService.uploadDocument.mockRejectedValue(new Error('Upload failed'))
|
||||
|
||||
render(<DocumentUploader />)
|
||||
|
||||
const file = new File(['content'], 'test.pdf', { type: 'application/pdf' })
|
||||
await user.upload(screen.getByLabelText(/upload/i), file)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText(/upload failed/i)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should allow retry after failure', async () => {
|
||||
const user = userEvent.setup()
|
||||
mockedService.uploadDocument
|
||||
.mockRejectedValueOnce(new Error('Network error'))
|
||||
.mockResolvedValueOnce({ id: 'doc-1' })
|
||||
|
||||
render(<DocumentUploader />)
|
||||
|
||||
const file = new File(['content'], 'test.pdf', { type: 'application/pdf' })
|
||||
await user.upload(screen.getByLabelText(/upload/i), file)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByRole('button', { name: /retry/i })).toBeInTheDocument()
|
||||
})
|
||||
|
||||
await user.click(screen.getByRole('button', { name: /retry/i }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText(/uploaded successfully/i)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
### Example: Document List with Pagination
|
||||
|
||||
```typescript
|
||||
describe('DocumentList', () => {
|
||||
describe('Pagination', () => {
|
||||
it('should load first page on mount', async () => {
|
||||
mockedService.getDocuments.mockResolvedValue({
|
||||
data: [{ id: '1', name: 'Doc 1' }],
|
||||
total: 50,
|
||||
page: 1,
|
||||
pageSize: 10,
|
||||
})
|
||||
|
||||
render(<DocumentList datasetId="ds-1" />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('Doc 1')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
expect(mockedService.getDocuments).toHaveBeenCalledWith('ds-1', { page: 1 })
|
||||
})
|
||||
|
||||
it('should navigate to next page', async () => {
|
||||
const user = userEvent.setup()
|
||||
mockedService.getDocuments.mockResolvedValue({
|
||||
data: [{ id: '1', name: 'Doc 1' }],
|
||||
total: 50,
|
||||
page: 1,
|
||||
pageSize: 10,
|
||||
})
|
||||
|
||||
render(<DocumentList datasetId="ds-1" />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('Doc 1')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
mockedService.getDocuments.mockResolvedValue({
|
||||
data: [{ id: '11', name: 'Doc 11' }],
|
||||
total: 50,
|
||||
page: 2,
|
||||
pageSize: 10,
|
||||
})
|
||||
|
||||
await user.click(screen.getByRole('button', { name: /next/i }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('Doc 11')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Search & Filtering', () => {
|
||||
it('should filter by search query', async () => {
|
||||
const user = userEvent.setup()
|
||||
vi.useFakeTimers()
|
||||
|
||||
render(<DocumentList datasetId="ds-1" />)
|
||||
|
||||
await user.type(screen.getByPlaceholderText(/search/i), 'test query')
|
||||
|
||||
// Debounce
|
||||
vi.advanceTimersByTime(300)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockedService.getDocuments).toHaveBeenCalledWith(
|
||||
'ds-1',
|
||||
expect.objectContaining({ search: 'test query' })
|
||||
)
|
||||
})
|
||||
|
||||
vi.useRealTimers()
|
||||
})
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
## Configuration Components (`app/configuration/`, `config/`)
|
||||
|
||||
Configuration components handle forms, validation, and data persistence.
|
||||
|
||||
### Key Test Areas
|
||||
|
||||
1. **Form Validation**
|
||||
1. **Save/Reset**
|
||||
1. **Required vs Optional Fields**
|
||||
1. **Configuration Persistence**
|
||||
1. **Error Feedback**
|
||||
|
||||
### Example: App Configuration Form
|
||||
|
||||
```typescript
|
||||
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import AppConfigForm from './app-config-form'
|
||||
|
||||
vi.mock('@/service/apps', () => ({
|
||||
updateAppConfig: vi.fn(),
|
||||
getAppConfig: vi.fn(),
|
||||
}))
|
||||
|
||||
import * as appService from '@/service/apps'
|
||||
const mockedService = vi.mocked(appService)
|
||||
|
||||
describe('AppConfigForm', () => {
|
||||
const defaultConfig = {
|
||||
name: 'My App',
|
||||
description: '',
|
||||
icon: 'default',
|
||||
openingStatement: '',
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockedService.getAppConfig.mockResolvedValue(defaultConfig)
|
||||
})
|
||||
|
||||
describe('Form Validation', () => {
|
||||
it('should require app name', async () => {
|
||||
const user = userEvent.setup()
|
||||
|
||||
render(<AppConfigForm appId="app-1" />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
|
||||
})
|
||||
|
||||
// Clear name field
|
||||
await user.clear(screen.getByLabelText(/name/i))
|
||||
await user.click(screen.getByRole('button', { name: /save/i }))
|
||||
|
||||
expect(screen.getByText(/name is required/i)).toBeInTheDocument()
|
||||
expect(mockedService.updateAppConfig).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should validate name length', async () => {
|
||||
const user = userEvent.setup()
|
||||
|
||||
render(<AppConfigForm appId="app-1" />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByLabelText(/name/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
// Enter very long name
|
||||
await user.clear(screen.getByLabelText(/name/i))
|
||||
await user.type(screen.getByLabelText(/name/i), 'a'.repeat(101))
|
||||
|
||||
expect(screen.getByText(/name must be less than 100 characters/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should allow empty optional fields', async () => {
|
||||
const user = userEvent.setup()
|
||||
mockedService.updateAppConfig.mockResolvedValue({ success: true })
|
||||
|
||||
render(<AppConfigForm appId="app-1" />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
|
||||
})
|
||||
|
||||
// Leave description empty (optional)
|
||||
await user.click(screen.getByRole('button', { name: /save/i }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockedService.updateAppConfig).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Save/Reset Functionality', () => {
|
||||
it('should save configuration', async () => {
|
||||
const user = userEvent.setup()
|
||||
mockedService.updateAppConfig.mockResolvedValue({ success: true })
|
||||
|
||||
render(<AppConfigForm appId="app-1" />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
|
||||
})
|
||||
|
||||
await user.clear(screen.getByLabelText(/name/i))
|
||||
await user.type(screen.getByLabelText(/name/i), 'Updated App')
|
||||
await user.click(screen.getByRole('button', { name: /save/i }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockedService.updateAppConfig).toHaveBeenCalledWith(
|
||||
'app-1',
|
||||
expect.objectContaining({ name: 'Updated App' })
|
||||
)
|
||||
})
|
||||
|
||||
expect(screen.getByText(/saved successfully/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should reset to default values', async () => {
|
||||
const user = userEvent.setup()
|
||||
|
||||
render(<AppConfigForm appId="app-1" />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
|
||||
})
|
||||
|
||||
// Make changes
|
||||
await user.clear(screen.getByLabelText(/name/i))
|
||||
await user.type(screen.getByLabelText(/name/i), 'Changed Name')
|
||||
|
||||
// Reset
|
||||
await user.click(screen.getByRole('button', { name: /reset/i }))
|
||||
|
||||
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
|
||||
})
|
||||
|
||||
it('should show unsaved changes warning', async () => {
|
||||
const user = userEvent.setup()
|
||||
|
||||
render(<AppConfigForm appId="app-1" />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
|
||||
})
|
||||
|
||||
// Make changes
|
||||
await user.type(screen.getByLabelText(/name/i), ' Updated')
|
||||
|
||||
expect(screen.getByText(/unsaved changes/i)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Error Handling', () => {
|
||||
it('should show error on save failure', async () => {
|
||||
const user = userEvent.setup()
|
||||
mockedService.updateAppConfig.mockRejectedValue(new Error('Server error'))
|
||||
|
||||
render(<AppConfigForm appId="app-1" />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
|
||||
})
|
||||
|
||||
await user.click(screen.getByRole('button', { name: /save/i }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText(/failed to save/i)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
```
|
||||
|
|
@ -0,0 +1,366 @@
|
|||
# Mocking Guide for Dify Frontend Tests
|
||||
|
||||
## ⚠️ Important: What NOT to Mock
|
||||
|
||||
### DO NOT Mock Base Components
|
||||
|
||||
**Never mock components from `@/app/components/base/`** such as:
|
||||
|
||||
- `Loading`, `Spinner`
|
||||
- `Button`, `Input`, `Select`
|
||||
- `Tooltip`, `Modal`, `Dropdown`
|
||||
- `Icon`, `Badge`, `Tag`
|
||||
|
||||
**Why?**
|
||||
|
||||
- Base components will have their own dedicated tests
|
||||
- Mocking them creates false positives (tests pass but real integration fails)
|
||||
- Using real components tests actual integration behavior
|
||||
|
||||
```typescript
|
||||
// ❌ WRONG: Don't mock base components
|
||||
vi.mock('@/app/components/base/loading', () => () => <div>Loading</div>)
|
||||
vi.mock('@/app/components/base/button', () => ({ children }: any) => <button>{children}</button>)
|
||||
|
||||
// ✅ CORRECT: Import and use real base components
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import Button from '@/app/components/base/button'
|
||||
// They will render normally in tests
|
||||
```
|
||||
|
||||
### What TO Mock
|
||||
|
||||
Only mock these categories:
|
||||
|
||||
1. **API services** (`@/service/*`) - Network calls
|
||||
1. **Complex context providers** - When setup is too difficult
|
||||
1. **Third-party libraries with side effects** - `next/navigation`, external SDKs
|
||||
1. **i18n** - Always mock to return keys
|
||||
|
||||
## Mock Placement
|
||||
|
||||
| Location | Purpose |
|
||||
|----------|---------|
|
||||
| `web/vitest.setup.ts` | Global mocks shared by all tests (for example `react-i18next`, `next/image`) |
|
||||
| `web/__mocks__/` | Reusable mock factories shared across multiple test files |
|
||||
| Test file | Test-specific mocks, inline with `vi.mock()` |
|
||||
|
||||
Modules are not mocked automatically. Use `vi.mock` in test files, or add global mocks in `web/vitest.setup.ts`.
|
||||
|
||||
## Essential Mocks
|
||||
|
||||
### 1. i18n (Auto-loaded via Global Mock)
|
||||
|
||||
A global mock is defined in `web/vitest.setup.ts` and is auto-loaded by Vitest setup.
|
||||
**No explicit mock needed** for most tests - it returns translation keys as-is.
|
||||
|
||||
For tests requiring custom translations, override the mock:
|
||||
|
||||
```typescript
|
||||
vi.mock('react-i18next', () => ({
|
||||
useTranslation: () => ({
|
||||
t: (key: string) => {
|
||||
const translations: Record<string, string> = {
|
||||
'my.custom.key': 'Custom translation',
|
||||
}
|
||||
return translations[key] || key
|
||||
},
|
||||
}),
|
||||
}))
|
||||
```
|
||||
|
||||
### 2. Next.js Router
|
||||
|
||||
```typescript
|
||||
const mockPush = vi.fn()
|
||||
const mockReplace = vi.fn()
|
||||
|
||||
vi.mock('next/navigation', () => ({
|
||||
useRouter: () => ({
|
||||
push: mockPush,
|
||||
replace: mockReplace,
|
||||
back: vi.fn(),
|
||||
prefetch: vi.fn(),
|
||||
}),
|
||||
usePathname: () => '/current-path',
|
||||
useSearchParams: () => new URLSearchParams('?key=value'),
|
||||
}))
|
||||
|
||||
describe('Component', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('should navigate on click', () => {
|
||||
render(<Component />)
|
||||
fireEvent.click(screen.getByRole('button'))
|
||||
expect(mockPush).toHaveBeenCalledWith('/expected-path')
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
### 3. Portal Components (with Shared State)
|
||||
|
||||
```typescript
|
||||
// ⚠️ Important: Use shared state for components that depend on each other
|
||||
let mockPortalOpenState = false
|
||||
|
||||
vi.mock('@/app/components/base/portal-to-follow-elem', () => ({
|
||||
PortalToFollowElem: ({ children, open, ...props }: any) => {
|
||||
mockPortalOpenState = open || false // Update shared state
|
||||
return <div data-testid="portal" data-open={open}>{children}</div>
|
||||
},
|
||||
PortalToFollowElemContent: ({ children }: any) => {
|
||||
// ✅ Matches actual: returns null when portal is closed
|
||||
if (!mockPortalOpenState) return null
|
||||
return <div data-testid="portal-content">{children}</div>
|
||||
},
|
||||
PortalToFollowElemTrigger: ({ children }: any) => (
|
||||
<div data-testid="portal-trigger">{children}</div>
|
||||
),
|
||||
}))
|
||||
|
||||
describe('Component', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockPortalOpenState = false // ✅ Reset shared state
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
### 4. API Service Mocks
|
||||
|
||||
```typescript
|
||||
import * as api from '@/service/api'
|
||||
|
||||
vi.mock('@/service/api')
|
||||
|
||||
const mockedApi = vi.mocked(api)
|
||||
|
||||
describe('Component', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
|
||||
// Setup default mock implementation
|
||||
mockedApi.fetchData.mockResolvedValue({ data: [] })
|
||||
})
|
||||
|
||||
it('should show data on success', async () => {
|
||||
mockedApi.fetchData.mockResolvedValue({ data: [{ id: 1 }] })
|
||||
|
||||
render(<Component />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('1')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should show error on failure', async () => {
|
||||
mockedApi.fetchData.mockRejectedValue(new Error('Network error'))
|
||||
|
||||
render(<Component />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText(/error/i)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
### 5. HTTP Mocking with Nock
|
||||
|
||||
```typescript
|
||||
import nock from 'nock'
|
||||
|
||||
const GITHUB_HOST = 'https://api.github.com'
|
||||
const GITHUB_PATH = '/repos/owner/repo'
|
||||
|
||||
const mockGithubApi = (status: number, body: Record<string, unknown>, delayMs = 0) => {
|
||||
return nock(GITHUB_HOST)
|
||||
.get(GITHUB_PATH)
|
||||
.delay(delayMs)
|
||||
.reply(status, body)
|
||||
}
|
||||
|
||||
describe('GithubComponent', () => {
|
||||
afterEach(() => {
|
||||
nock.cleanAll()
|
||||
})
|
||||
|
||||
it('should display repo info', async () => {
|
||||
mockGithubApi(200, { name: 'dify', stars: 1000 })
|
||||
|
||||
render(<GithubComponent />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('dify')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle API error', async () => {
|
||||
mockGithubApi(500, { message: 'Server error' })
|
||||
|
||||
render(<GithubComponent />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText(/error/i)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
### 6. Context Providers
|
||||
|
||||
```typescript
|
||||
import { ProviderContext } from '@/context/provider-context'
|
||||
import { createMockProviderContextValue, createMockPlan } from '@/__mocks__/provider-context'
|
||||
|
||||
describe('Component with Context', () => {
|
||||
it('should render for free plan', () => {
|
||||
const mockContext = createMockPlan('sandbox')
|
||||
|
||||
render(
|
||||
<ProviderContext.Provider value={mockContext}>
|
||||
<Component />
|
||||
</ProviderContext.Provider>
|
||||
)
|
||||
|
||||
expect(screen.getByText('Upgrade')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render for pro plan', () => {
|
||||
const mockContext = createMockPlan('professional')
|
||||
|
||||
render(
|
||||
<ProviderContext.Provider value={mockContext}>
|
||||
<Component />
|
||||
</ProviderContext.Provider>
|
||||
)
|
||||
|
||||
expect(screen.queryByText('Upgrade')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
### 7. SWR / React Query
|
||||
|
||||
```typescript
|
||||
// SWR
|
||||
vi.mock('swr', () => ({
|
||||
__esModule: true,
|
||||
default: vi.fn(),
|
||||
}))
|
||||
|
||||
import useSWR from 'swr'
|
||||
const mockedUseSWR = vi.mocked(useSWR)
|
||||
|
||||
describe('Component with SWR', () => {
|
||||
it('should show loading state', () => {
|
||||
mockedUseSWR.mockReturnValue({
|
||||
data: undefined,
|
||||
error: undefined,
|
||||
isLoading: true,
|
||||
})
|
||||
|
||||
render(<Component />)
|
||||
expect(screen.getByText(/loading/i)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
// React Query
|
||||
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
|
||||
|
||||
const createTestQueryClient = () => new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: { retry: false },
|
||||
mutations: { retry: false },
|
||||
},
|
||||
})
|
||||
|
||||
const renderWithQueryClient = (ui: React.ReactElement) => {
|
||||
const queryClient = createTestQueryClient()
|
||||
return render(
|
||||
<QueryClientProvider client={queryClient}>
|
||||
{ui}
|
||||
</QueryClientProvider>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
## Mock Best Practices
|
||||
|
||||
### ✅ DO
|
||||
|
||||
1. **Use real base components** - Import from `@/app/components/base/` directly
|
||||
1. **Use real project components** - Prefer importing over mocking
|
||||
1. **Reset mocks in `beforeEach`**, not `afterEach`
|
||||
1. **Match actual component behavior** in mocks (when mocking is necessary)
|
||||
1. **Use factory functions** for complex mock data
|
||||
1. **Import actual types** for type safety
|
||||
1. **Reset shared mock state** in `beforeEach`
|
||||
|
||||
### ❌ DON'T
|
||||
|
||||
1. **Don't mock base components** (`Loading`, `Button`, `Tooltip`, etc.)
|
||||
1. Don't mock components you can import directly
|
||||
1. Don't create overly simplified mocks that miss conditional logic
|
||||
1. Don't forget to clean up nock after each test
|
||||
1. Don't use `any` types in mocks without necessity
|
||||
|
||||
### Mock Decision Tree
|
||||
|
||||
```
|
||||
Need to use a component in test?
|
||||
│
|
||||
├─ Is it from @/app/components/base/*?
|
||||
│ └─ YES → Import real component, DO NOT mock
|
||||
│
|
||||
├─ Is it a project component?
|
||||
│ └─ YES → Prefer importing real component
|
||||
│ Only mock if setup is extremely complex
|
||||
│
|
||||
├─ Is it an API service (@/service/*)?
|
||||
│ └─ YES → Mock it
|
||||
│
|
||||
├─ Is it a third-party lib with side effects?
|
||||
│ └─ YES → Mock it (next/navigation, external SDKs)
|
||||
│
|
||||
└─ Is it i18n?
|
||||
└─ YES → Uses shared mock (auto-loaded). Override only for custom translations
|
||||
```
|
||||
|
||||
## Factory Function Pattern
|
||||
|
||||
```typescript
|
||||
// __mocks__/data-factories.ts
|
||||
import type { User, Project } from '@/types'
|
||||
|
||||
export const createMockUser = (overrides: Partial<User> = {}): User => ({
|
||||
id: 'user-1',
|
||||
name: 'Test User',
|
||||
email: 'test@example.com',
|
||||
role: 'member',
|
||||
createdAt: new Date().toISOString(),
|
||||
...overrides,
|
||||
})
|
||||
|
||||
export const createMockProject = (overrides: Partial<Project> = {}): Project => ({
|
||||
id: 'project-1',
|
||||
name: 'Test Project',
|
||||
description: 'A test project',
|
||||
owner: createMockUser(),
|
||||
members: [],
|
||||
createdAt: new Date().toISOString(),
|
||||
...overrides,
|
||||
})
|
||||
|
||||
// Usage in tests
|
||||
it('should display project owner', () => {
|
||||
const project = createMockProject({
|
||||
owner: createMockUser({ name: 'John Doe' }),
|
||||
})
|
||||
|
||||
render(<ProjectCard project={project} />)
|
||||
expect(screen.getByText('John Doe')).toBeInTheDocument()
|
||||
})
|
||||
```
|
||||
|
|
@ -0,0 +1,269 @@
|
|||
# Testing Workflow Guide
|
||||
|
||||
This guide defines the workflow for generating tests, especially for complex components or directories with multiple files.
|
||||
|
||||
## Scope Clarification
|
||||
|
||||
This guide addresses **multi-file workflow** (how to process multiple test files). For coverage requirements within a single test file, see `web/testing/testing.md` § Coverage Goals.
|
||||
|
||||
| Scope | Rule |
|
||||
|-------|------|
|
||||
| **Single file** | Complete coverage in one generation (100% function, >95% branch) |
|
||||
| **Multi-file directory** | Process one file at a time, verify each before proceeding |
|
||||
|
||||
## ⚠️ Critical Rule: Incremental Approach for Multi-File Testing
|
||||
|
||||
When testing a **directory with multiple files**, **NEVER generate all test files at once.** Use an incremental, verify-as-you-go approach.
|
||||
|
||||
### Why Incremental?
|
||||
|
||||
| Batch Approach (❌) | Incremental Approach (✅) |
|
||||
|---------------------|---------------------------|
|
||||
| Generate 5+ tests at once | Generate 1 test at a time |
|
||||
| Run tests only at the end | Run test immediately after each file |
|
||||
| Multiple failures compound | Single point of failure, easy to debug |
|
||||
| Hard to identify root cause | Clear cause-effect relationship |
|
||||
| Mock issues affect many files | Mock issues caught early |
|
||||
| Messy git history | Clean, atomic commits possible |
|
||||
|
||||
## Single File Workflow
|
||||
|
||||
When testing a **single component, hook, or utility**:
|
||||
|
||||
```
|
||||
1. Read source code completely
|
||||
2. Run `pnpm analyze-component <path>` (if available)
|
||||
3. Check complexity score and features detected
|
||||
4. Write the test file
|
||||
5. Run test: `pnpm test -- <file>.spec.tsx`
|
||||
6. Fix any failures
|
||||
7. Verify coverage meets goals (100% function, >95% branch)
|
||||
```
|
||||
|
||||
## Directory/Multi-File Workflow (MUST FOLLOW)
|
||||
|
||||
When testing a **directory or multiple files**, follow this strict workflow:
|
||||
|
||||
### Step 1: Analyze and Plan
|
||||
|
||||
1. **List all files** that need tests in the directory
|
||||
1. **Categorize by complexity**:
|
||||
- 🟢 **Simple**: Utility functions, simple hooks, presentational components
|
||||
- 🟡 **Medium**: Components with state, effects, or event handlers
|
||||
- 🔴 **Complex**: Components with API calls, routing, or many dependencies
|
||||
1. **Order by dependency**: Test dependencies before dependents
|
||||
1. **Create a todo list** to track progress
|
||||
|
||||
### Step 2: Determine Processing Order
|
||||
|
||||
Process files in this recommended order:
|
||||
|
||||
```
|
||||
1. Utility functions (simplest, no React)
|
||||
2. Custom hooks (isolated logic)
|
||||
3. Simple presentational components (few/no props)
|
||||
4. Medium complexity components (state, effects)
|
||||
5. Complex components (API, routing, many deps)
|
||||
6. Container/index components (integration tests - last)
|
||||
```
|
||||
|
||||
**Rationale**:
|
||||
|
||||
- Simpler files help establish mock patterns
|
||||
- Hooks used by components should be tested first
|
||||
- Integration tests (index files) depend on child components working
|
||||
|
||||
### Step 3: Process Each File Incrementally
|
||||
|
||||
**For EACH file in the ordered list:**
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────┐
|
||||
│ 1. Write test file │
|
||||
│ 2. Run: pnpm test -- <file>.spec.tsx │
|
||||
│ 3. If FAIL → Fix immediately, re-run │
|
||||
│ 4. If PASS → Mark complete in todo list │
|
||||
│ 5. ONLY THEN proceed to next file │
|
||||
└─────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
**DO NOT proceed to the next file until the current one passes.**
|
||||
|
||||
### Step 4: Final Verification
|
||||
|
||||
After all individual tests pass:
|
||||
|
||||
```bash
|
||||
# Run all tests in the directory together
|
||||
pnpm test -- path/to/directory/
|
||||
|
||||
# Check coverage
|
||||
pnpm test -- --coverage path/to/directory/
|
||||
```
|
||||
|
||||
## Component Complexity Guidelines
|
||||
|
||||
Use `pnpm analyze-component <path>` to assess complexity before testing.
|
||||
|
||||
### 🔴 Very Complex Components (Complexity > 50)
|
||||
|
||||
**Consider refactoring BEFORE testing:**
|
||||
|
||||
- Break component into smaller, testable pieces
|
||||
- Extract complex logic into custom hooks
|
||||
- Separate container and presentational layers
|
||||
|
||||
**If testing as-is:**
|
||||
|
||||
- Use integration tests for complex workflows
|
||||
- Use `test.each()` for data-driven testing
|
||||
- Multiple `describe` blocks for organization
|
||||
- Consider testing major sections separately
|
||||
|
||||
### 🟡 Medium Complexity (Complexity 30-50)
|
||||
|
||||
- Group related tests in `describe` blocks
|
||||
- Test integration scenarios between internal parts
|
||||
- Focus on state transitions and side effects
|
||||
- Use helper functions to reduce test complexity
|
||||
|
||||
### 🟢 Simple Components (Complexity < 30)
|
||||
|
||||
- Standard test structure
|
||||
- Focus on props, rendering, and edge cases
|
||||
- Usually straightforward to test
|
||||
|
||||
### 📏 Large Files (500+ lines)
|
||||
|
||||
Regardless of complexity score:
|
||||
|
||||
- **Strongly consider refactoring** before testing
|
||||
- If testing as-is, test major sections separately
|
||||
- Create helper functions for test setup
|
||||
- May need multiple test files
|
||||
|
||||
## Todo List Format
|
||||
|
||||
When testing multiple files, use a todo list like this:
|
||||
|
||||
```
|
||||
Testing: path/to/directory/
|
||||
|
||||
Ordered by complexity (simple → complex):
|
||||
|
||||
☐ utils/helper.ts [utility, simple]
|
||||
☐ hooks/use-custom-hook.ts [hook, simple]
|
||||
☐ empty-state.tsx [component, simple]
|
||||
☐ item-card.tsx [component, medium]
|
||||
☐ list.tsx [component, complex]
|
||||
☐ index.tsx [integration]
|
||||
|
||||
Progress: 0/6 complete
|
||||
```
|
||||
|
||||
Update status as you complete each:
|
||||
|
||||
- ☐ → ⏳ (in progress)
|
||||
- ⏳ → ✅ (complete and verified)
|
||||
- ⏳ → ❌ (blocked, needs attention)
|
||||
|
||||
## When to Stop and Verify
|
||||
|
||||
**Always run tests after:**
|
||||
|
||||
- Completing a test file
|
||||
- Making changes to fix a failure
|
||||
- Modifying shared mocks
|
||||
- Updating test utilities or helpers
|
||||
|
||||
**Signs you should pause:**
|
||||
|
||||
- More than 2 consecutive test failures
|
||||
- Mock-related errors appearing
|
||||
- Unclear why a test is failing
|
||||
- Test passing but coverage unexpectedly low
|
||||
|
||||
## Common Pitfalls to Avoid
|
||||
|
||||
### ❌ Don't: Generate Everything First
|
||||
|
||||
```
|
||||
# BAD: Writing all files then testing
|
||||
Write component-a.spec.tsx
|
||||
Write component-b.spec.tsx
|
||||
Write component-c.spec.tsx
|
||||
Write component-d.spec.tsx
|
||||
Run pnpm test ← Multiple failures, hard to debug
|
||||
```
|
||||
|
||||
### ✅ Do: Verify Each Step
|
||||
|
||||
```
|
||||
# GOOD: Incremental with verification
|
||||
Write component-a.spec.tsx
|
||||
Run pnpm test -- component-a.spec.tsx ✅
|
||||
Write component-b.spec.tsx
|
||||
Run pnpm test -- component-b.spec.tsx ✅
|
||||
...continue...
|
||||
```
|
||||
|
||||
### ❌ Don't: Skip Verification for "Simple" Components
|
||||
|
||||
Even simple components can have:
|
||||
|
||||
- Import errors
|
||||
- Missing mock setup
|
||||
- Incorrect assumptions about props
|
||||
|
||||
**Always verify, regardless of perceived simplicity.**
|
||||
|
||||
### ❌ Don't: Continue When Tests Fail
|
||||
|
||||
Failing tests compound:
|
||||
|
||||
- A mock issue in file A affects files B, C, D
|
||||
- Fixing A later requires revisiting all dependent tests
|
||||
- Time wasted on debugging cascading failures
|
||||
|
||||
**Fix failures immediately before proceeding.**
|
||||
|
||||
## Integration with Claude's Todo Feature
|
||||
|
||||
When using Claude for multi-file testing:
|
||||
|
||||
1. **Ask Claude to create a todo list** before starting
|
||||
1. **Request one file at a time** or ensure Claude processes incrementally
|
||||
1. **Verify each test passes** before asking for the next
|
||||
1. **Mark todos complete** as you progress
|
||||
|
||||
Example prompt:
|
||||
|
||||
```
|
||||
Test all components in `path/to/directory/`.
|
||||
First, analyze the directory and create a todo list ordered by complexity.
|
||||
Then, process ONE file at a time, waiting for my confirmation that tests pass
|
||||
before proceeding to the next.
|
||||
```
|
||||
|
||||
## Summary Checklist
|
||||
|
||||
Before starting multi-file testing:
|
||||
|
||||
- [ ] Listed all files needing tests
|
||||
- [ ] Ordered by complexity (simple → complex)
|
||||
- [ ] Created todo list for tracking
|
||||
- [ ] Understand dependencies between files
|
||||
|
||||
During testing:
|
||||
|
||||
- [ ] Processing ONE file at a time
|
||||
- [ ] Running tests after EACH file
|
||||
- [ ] Fixing failures BEFORE proceeding
|
||||
- [ ] Updating todo list progress
|
||||
|
||||
After completion:
|
||||
|
||||
- [ ] All individual tests pass
|
||||
- [ ] Full directory test run passes
|
||||
- [ ] Coverage goals met
|
||||
- [ ] Todo list shows all complete
|
||||
|
|
@ -0,0 +1 @@
|
|||
../.claude/skills
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
[run]
|
||||
omit =
|
||||
api/tests/*
|
||||
api/migrations/*
|
||||
api/core/rag/datasource/vdb/*
|
||||
|
|
@ -6,6 +6,9 @@
|
|||
"context": "..",
|
||||
"dockerfile": "Dockerfile"
|
||||
},
|
||||
"mounts": [
|
||||
"source=dify-dev-tmp,target=/tmp,type=volume"
|
||||
],
|
||||
"features": {
|
||||
"ghcr.io/devcontainers/features/node:1": {
|
||||
"nodeGypDependencies": true,
|
||||
|
|
@ -34,19 +37,13 @@
|
|||
},
|
||||
"postStartCommand": "./.devcontainer/post_start_command.sh",
|
||||
"postCreateCommand": "./.devcontainer/post_create_command.sh"
|
||||
|
||||
// Features to add to the dev container. More info: https://containers.dev/features.
|
||||
// "features": {},
|
||||
|
||||
// Use 'forwardPorts' to make a list of ports inside the container available locally.
|
||||
// "forwardPorts": [],
|
||||
|
||||
// Use 'postCreateCommand' to run commands after the container is created.
|
||||
// "postCreateCommand": "python --version",
|
||||
|
||||
// Configure tool-specific properties.
|
||||
// "customizations": {},
|
||||
|
||||
// Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root.
|
||||
// "remoteUser": "root"
|
||||
}
|
||||
}
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
#!/bin/bash
|
||||
WORKSPACE_ROOT=$(pwd)
|
||||
|
||||
export COREPACK_ENABLE_DOWNLOAD_PROMPT=0
|
||||
corepack enable
|
||||
cd web && pnpm install
|
||||
pipx install uv
|
||||
|
|
|
|||
|
|
@ -6,229 +6,244 @@
|
|||
|
||||
* @crazywoola @laipz8200 @Yeuoly
|
||||
|
||||
# CODEOWNERS file
|
||||
/.github/CODEOWNERS @laipz8200 @crazywoola
|
||||
|
||||
# Docs
|
||||
/docs/ @crazywoola
|
||||
|
||||
# Backend (default owner, more specific rules below will override)
|
||||
api/ @QuantumGhost
|
||||
/api/ @QuantumGhost
|
||||
|
||||
# Backend - MCP
|
||||
api/core/mcp/ @Nov1c444
|
||||
api/core/entities/mcp_provider.py @Nov1c444
|
||||
api/services/tools/mcp_tools_manage_service.py @Nov1c444
|
||||
api/controllers/mcp/ @Nov1c444
|
||||
api/controllers/console/app/mcp_server.py @Nov1c444
|
||||
api/tests/**/*mcp* @Nov1c444
|
||||
/api/core/mcp/ @Nov1c444
|
||||
/api/core/entities/mcp_provider.py @Nov1c444
|
||||
/api/services/tools/mcp_tools_manage_service.py @Nov1c444
|
||||
/api/controllers/mcp/ @Nov1c444
|
||||
/api/controllers/console/app/mcp_server.py @Nov1c444
|
||||
/api/tests/**/*mcp* @Nov1c444
|
||||
|
||||
# Backend - Workflow - Engine (Core graph execution engine)
|
||||
api/core/workflow/graph_engine/ @laipz8200 @QuantumGhost
|
||||
api/core/workflow/runtime/ @laipz8200 @QuantumGhost
|
||||
api/core/workflow/graph/ @laipz8200 @QuantumGhost
|
||||
api/core/workflow/graph_events/ @laipz8200 @QuantumGhost
|
||||
api/core/workflow/node_events/ @laipz8200 @QuantumGhost
|
||||
api/core/model_runtime/ @laipz8200 @QuantumGhost
|
||||
/api/core/workflow/graph_engine/ @laipz8200 @QuantumGhost
|
||||
/api/core/workflow/runtime/ @laipz8200 @QuantumGhost
|
||||
/api/core/workflow/graph/ @laipz8200 @QuantumGhost
|
||||
/api/core/workflow/graph_events/ @laipz8200 @QuantumGhost
|
||||
/api/core/workflow/node_events/ @laipz8200 @QuantumGhost
|
||||
/api/core/model_runtime/ @laipz8200 @QuantumGhost
|
||||
|
||||
# Backend - Workflow - Nodes (Agent, Iteration, Loop, LLM)
|
||||
api/core/workflow/nodes/agent/ @Nov1c444
|
||||
api/core/workflow/nodes/iteration/ @Nov1c444
|
||||
api/core/workflow/nodes/loop/ @Nov1c444
|
||||
api/core/workflow/nodes/llm/ @Nov1c444
|
||||
/api/core/workflow/nodes/agent/ @Nov1c444
|
||||
/api/core/workflow/nodes/iteration/ @Nov1c444
|
||||
/api/core/workflow/nodes/loop/ @Nov1c444
|
||||
/api/core/workflow/nodes/llm/ @Nov1c444
|
||||
|
||||
# Backend - RAG (Retrieval Augmented Generation)
|
||||
api/core/rag/ @JohnJyong
|
||||
api/services/rag_pipeline/ @JohnJyong
|
||||
api/services/dataset_service.py @JohnJyong
|
||||
api/services/knowledge_service.py @JohnJyong
|
||||
api/services/external_knowledge_service.py @JohnJyong
|
||||
api/services/hit_testing_service.py @JohnJyong
|
||||
api/services/metadata_service.py @JohnJyong
|
||||
api/services/vector_service.py @JohnJyong
|
||||
api/services/entities/knowledge_entities/ @JohnJyong
|
||||
api/services/entities/external_knowledge_entities/ @JohnJyong
|
||||
api/controllers/console/datasets/ @JohnJyong
|
||||
api/controllers/service_api/dataset/ @JohnJyong
|
||||
api/models/dataset.py @JohnJyong
|
||||
api/tasks/rag_pipeline/ @JohnJyong
|
||||
api/tasks/add_document_to_index_task.py @JohnJyong
|
||||
api/tasks/batch_clean_document_task.py @JohnJyong
|
||||
api/tasks/clean_document_task.py @JohnJyong
|
||||
api/tasks/clean_notion_document_task.py @JohnJyong
|
||||
api/tasks/document_indexing_task.py @JohnJyong
|
||||
api/tasks/document_indexing_sync_task.py @JohnJyong
|
||||
api/tasks/document_indexing_update_task.py @JohnJyong
|
||||
api/tasks/duplicate_document_indexing_task.py @JohnJyong
|
||||
api/tasks/recover_document_indexing_task.py @JohnJyong
|
||||
api/tasks/remove_document_from_index_task.py @JohnJyong
|
||||
api/tasks/retry_document_indexing_task.py @JohnJyong
|
||||
api/tasks/sync_website_document_indexing_task.py @JohnJyong
|
||||
api/tasks/batch_create_segment_to_index_task.py @JohnJyong
|
||||
api/tasks/create_segment_to_index_task.py @JohnJyong
|
||||
api/tasks/delete_segment_from_index_task.py @JohnJyong
|
||||
api/tasks/disable_segment_from_index_task.py @JohnJyong
|
||||
api/tasks/disable_segments_from_index_task.py @JohnJyong
|
||||
api/tasks/enable_segment_to_index_task.py @JohnJyong
|
||||
api/tasks/enable_segments_to_index_task.py @JohnJyong
|
||||
api/tasks/clean_dataset_task.py @JohnJyong
|
||||
api/tasks/deal_dataset_index_update_task.py @JohnJyong
|
||||
api/tasks/deal_dataset_vector_index_task.py @JohnJyong
|
||||
/api/core/rag/ @JohnJyong
|
||||
/api/services/rag_pipeline/ @JohnJyong
|
||||
/api/services/dataset_service.py @JohnJyong
|
||||
/api/services/knowledge_service.py @JohnJyong
|
||||
/api/services/external_knowledge_service.py @JohnJyong
|
||||
/api/services/hit_testing_service.py @JohnJyong
|
||||
/api/services/metadata_service.py @JohnJyong
|
||||
/api/services/vector_service.py @JohnJyong
|
||||
/api/services/entities/knowledge_entities/ @JohnJyong
|
||||
/api/services/entities/external_knowledge_entities/ @JohnJyong
|
||||
/api/controllers/console/datasets/ @JohnJyong
|
||||
/api/controllers/service_api/dataset/ @JohnJyong
|
||||
/api/models/dataset.py @JohnJyong
|
||||
/api/tasks/rag_pipeline/ @JohnJyong
|
||||
/api/tasks/add_document_to_index_task.py @JohnJyong
|
||||
/api/tasks/batch_clean_document_task.py @JohnJyong
|
||||
/api/tasks/clean_document_task.py @JohnJyong
|
||||
/api/tasks/clean_notion_document_task.py @JohnJyong
|
||||
/api/tasks/document_indexing_task.py @JohnJyong
|
||||
/api/tasks/document_indexing_sync_task.py @JohnJyong
|
||||
/api/tasks/document_indexing_update_task.py @JohnJyong
|
||||
/api/tasks/duplicate_document_indexing_task.py @JohnJyong
|
||||
/api/tasks/recover_document_indexing_task.py @JohnJyong
|
||||
/api/tasks/remove_document_from_index_task.py @JohnJyong
|
||||
/api/tasks/retry_document_indexing_task.py @JohnJyong
|
||||
/api/tasks/sync_website_document_indexing_task.py @JohnJyong
|
||||
/api/tasks/batch_create_segment_to_index_task.py @JohnJyong
|
||||
/api/tasks/create_segment_to_index_task.py @JohnJyong
|
||||
/api/tasks/delete_segment_from_index_task.py @JohnJyong
|
||||
/api/tasks/disable_segment_from_index_task.py @JohnJyong
|
||||
/api/tasks/disable_segments_from_index_task.py @JohnJyong
|
||||
/api/tasks/enable_segment_to_index_task.py @JohnJyong
|
||||
/api/tasks/enable_segments_to_index_task.py @JohnJyong
|
||||
/api/tasks/clean_dataset_task.py @JohnJyong
|
||||
/api/tasks/deal_dataset_index_update_task.py @JohnJyong
|
||||
/api/tasks/deal_dataset_vector_index_task.py @JohnJyong
|
||||
|
||||
# Backend - Plugins
|
||||
api/core/plugin/ @Mairuis @Yeuoly @Stream29
|
||||
api/services/plugin/ @Mairuis @Yeuoly @Stream29
|
||||
api/controllers/console/workspace/plugin.py @Mairuis @Yeuoly @Stream29
|
||||
api/controllers/inner_api/plugin/ @Mairuis @Yeuoly @Stream29
|
||||
api/tasks/process_tenant_plugin_autoupgrade_check_task.py @Mairuis @Yeuoly @Stream29
|
||||
/api/core/plugin/ @Mairuis @Yeuoly @Stream29
|
||||
/api/services/plugin/ @Mairuis @Yeuoly @Stream29
|
||||
/api/controllers/console/workspace/plugin.py @Mairuis @Yeuoly @Stream29
|
||||
/api/controllers/inner_api/plugin/ @Mairuis @Yeuoly @Stream29
|
||||
/api/tasks/process_tenant_plugin_autoupgrade_check_task.py @Mairuis @Yeuoly @Stream29
|
||||
|
||||
# Backend - Trigger/Schedule/Webhook
|
||||
api/controllers/trigger/ @Mairuis @Yeuoly
|
||||
api/controllers/console/app/workflow_trigger.py @Mairuis @Yeuoly
|
||||
api/controllers/console/workspace/trigger_providers.py @Mairuis @Yeuoly
|
||||
api/core/trigger/ @Mairuis @Yeuoly
|
||||
api/core/app/layers/trigger_post_layer.py @Mairuis @Yeuoly
|
||||
api/services/trigger/ @Mairuis @Yeuoly
|
||||
api/models/trigger.py @Mairuis @Yeuoly
|
||||
api/fields/workflow_trigger_fields.py @Mairuis @Yeuoly
|
||||
api/repositories/workflow_trigger_log_repository.py @Mairuis @Yeuoly
|
||||
api/repositories/sqlalchemy_workflow_trigger_log_repository.py @Mairuis @Yeuoly
|
||||
api/libs/schedule_utils.py @Mairuis @Yeuoly
|
||||
api/services/workflow/scheduler.py @Mairuis @Yeuoly
|
||||
api/schedule/trigger_provider_refresh_task.py @Mairuis @Yeuoly
|
||||
api/schedule/workflow_schedule_task.py @Mairuis @Yeuoly
|
||||
api/tasks/trigger_processing_tasks.py @Mairuis @Yeuoly
|
||||
api/tasks/trigger_subscription_refresh_tasks.py @Mairuis @Yeuoly
|
||||
api/tasks/workflow_schedule_tasks.py @Mairuis @Yeuoly
|
||||
api/tasks/workflow_cfs_scheduler/ @Mairuis @Yeuoly
|
||||
api/events/event_handlers/sync_plugin_trigger_when_app_created.py @Mairuis @Yeuoly
|
||||
api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py @Mairuis @Yeuoly
|
||||
api/events/event_handlers/sync_workflow_schedule_when_app_published.py @Mairuis @Yeuoly
|
||||
api/events/event_handlers/sync_webhook_when_app_created.py @Mairuis @Yeuoly
|
||||
/api/controllers/trigger/ @Mairuis @Yeuoly
|
||||
/api/controllers/console/app/workflow_trigger.py @Mairuis @Yeuoly
|
||||
/api/controllers/console/workspace/trigger_providers.py @Mairuis @Yeuoly
|
||||
/api/core/trigger/ @Mairuis @Yeuoly
|
||||
/api/core/app/layers/trigger_post_layer.py @Mairuis @Yeuoly
|
||||
/api/services/trigger/ @Mairuis @Yeuoly
|
||||
/api/models/trigger.py @Mairuis @Yeuoly
|
||||
/api/fields/workflow_trigger_fields.py @Mairuis @Yeuoly
|
||||
/api/repositories/workflow_trigger_log_repository.py @Mairuis @Yeuoly
|
||||
/api/repositories/sqlalchemy_workflow_trigger_log_repository.py @Mairuis @Yeuoly
|
||||
/api/libs/schedule_utils.py @Mairuis @Yeuoly
|
||||
/api/services/workflow/scheduler.py @Mairuis @Yeuoly
|
||||
/api/schedule/trigger_provider_refresh_task.py @Mairuis @Yeuoly
|
||||
/api/schedule/workflow_schedule_task.py @Mairuis @Yeuoly
|
||||
/api/tasks/trigger_processing_tasks.py @Mairuis @Yeuoly
|
||||
/api/tasks/trigger_subscription_refresh_tasks.py @Mairuis @Yeuoly
|
||||
/api/tasks/workflow_schedule_tasks.py @Mairuis @Yeuoly
|
||||
/api/tasks/workflow_cfs_scheduler/ @Mairuis @Yeuoly
|
||||
/api/events/event_handlers/sync_plugin_trigger_when_app_created.py @Mairuis @Yeuoly
|
||||
/api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py @Mairuis @Yeuoly
|
||||
/api/events/event_handlers/sync_workflow_schedule_when_app_published.py @Mairuis @Yeuoly
|
||||
/api/events/event_handlers/sync_webhook_when_app_created.py @Mairuis @Yeuoly
|
||||
|
||||
# Backend - Async Workflow
|
||||
api/services/async_workflow_service.py @Mairuis @Yeuoly
|
||||
api/tasks/async_workflow_tasks.py @Mairuis @Yeuoly
|
||||
/api/services/async_workflow_service.py @Mairuis @Yeuoly
|
||||
/api/tasks/async_workflow_tasks.py @Mairuis @Yeuoly
|
||||
|
||||
# Backend - Billing
|
||||
api/services/billing_service.py @hj24 @zyssyz123
|
||||
api/controllers/console/billing/ @hj24 @zyssyz123
|
||||
/api/services/billing_service.py @hj24 @zyssyz123
|
||||
/api/controllers/console/billing/ @hj24 @zyssyz123
|
||||
|
||||
# Backend - Enterprise
|
||||
api/configs/enterprise/ @GarfieldDai @GareArc
|
||||
api/services/enterprise/ @GarfieldDai @GareArc
|
||||
api/services/feature_service.py @GarfieldDai @GareArc
|
||||
api/controllers/console/feature.py @GarfieldDai @GareArc
|
||||
api/controllers/web/feature.py @GarfieldDai @GareArc
|
||||
/api/configs/enterprise/ @GarfieldDai @GareArc
|
||||
/api/services/enterprise/ @GarfieldDai @GareArc
|
||||
/api/services/feature_service.py @GarfieldDai @GareArc
|
||||
/api/controllers/console/feature.py @GarfieldDai @GareArc
|
||||
/api/controllers/web/feature.py @GarfieldDai @GareArc
|
||||
|
||||
# Backend - Database Migrations
|
||||
api/migrations/ @snakevash @laipz8200
|
||||
/api/migrations/ @snakevash @laipz8200 @MRZHUH
|
||||
|
||||
# Backend - Vector DB Middleware
|
||||
/api/configs/middleware/vdb/* @JohnJyong
|
||||
|
||||
# Frontend
|
||||
web/ @iamjoel
|
||||
/web/ @iamjoel
|
||||
|
||||
# Frontend - Web Tests
|
||||
/.github/workflows/web-tests.yml @iamjoel
|
||||
|
||||
# Frontend - App - Orchestration
|
||||
web/app/components/workflow/ @iamjoel @zxhlyh
|
||||
web/app/components/workflow-app/ @iamjoel @zxhlyh
|
||||
web/app/components/app/configuration/ @iamjoel @zxhlyh
|
||||
web/app/components/app/app-publisher/ @iamjoel @zxhlyh
|
||||
/web/app/components/workflow/ @iamjoel @zxhlyh
|
||||
/web/app/components/workflow-app/ @iamjoel @zxhlyh
|
||||
/web/app/components/app/configuration/ @iamjoel @zxhlyh
|
||||
/web/app/components/app/app-publisher/ @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - WebApp - Chat
|
||||
web/app/components/base/chat/ @iamjoel @zxhlyh
|
||||
/web/app/components/base/chat/ @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - WebApp - Completion
|
||||
web/app/components/share/text-generation/ @iamjoel @zxhlyh
|
||||
/web/app/components/share/text-generation/ @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - App - List and Creation
|
||||
web/app/components/apps/ @JzoNgKVO @iamjoel
|
||||
web/app/components/app/create-app-dialog/ @JzoNgKVO @iamjoel
|
||||
web/app/components/app/create-app-modal/ @JzoNgKVO @iamjoel
|
||||
web/app/components/app/create-from-dsl-modal/ @JzoNgKVO @iamjoel
|
||||
/web/app/components/apps/ @JzoNgKVO @iamjoel
|
||||
/web/app/components/app/create-app-dialog/ @JzoNgKVO @iamjoel
|
||||
/web/app/components/app/create-app-modal/ @JzoNgKVO @iamjoel
|
||||
/web/app/components/app/create-from-dsl-modal/ @JzoNgKVO @iamjoel
|
||||
|
||||
# Frontend - App - API Documentation
|
||||
web/app/components/develop/ @JzoNgKVO @iamjoel
|
||||
/web/app/components/develop/ @JzoNgKVO @iamjoel
|
||||
|
||||
# Frontend - App - Logs and Annotations
|
||||
web/app/components/app/workflow-log/ @JzoNgKVO @iamjoel
|
||||
web/app/components/app/log/ @JzoNgKVO @iamjoel
|
||||
web/app/components/app/log-annotation/ @JzoNgKVO @iamjoel
|
||||
web/app/components/app/annotation/ @JzoNgKVO @iamjoel
|
||||
/web/app/components/app/workflow-log/ @JzoNgKVO @iamjoel
|
||||
/web/app/components/app/log/ @JzoNgKVO @iamjoel
|
||||
/web/app/components/app/log-annotation/ @JzoNgKVO @iamjoel
|
||||
/web/app/components/app/annotation/ @JzoNgKVO @iamjoel
|
||||
|
||||
# Frontend - App - Monitoring
|
||||
web/app/(commonLayout)/app/(appDetailLayout)/\[appId\]/overview/ @JzoNgKVO @iamjoel
|
||||
web/app/components/app/overview/ @JzoNgKVO @iamjoel
|
||||
/web/app/(commonLayout)/app/(appDetailLayout)/\[appId\]/overview/ @JzoNgKVO @iamjoel
|
||||
/web/app/components/app/overview/ @JzoNgKVO @iamjoel
|
||||
|
||||
# Frontend - App - Settings
|
||||
web/app/components/app-sidebar/ @JzoNgKVO @iamjoel
|
||||
/web/app/components/app-sidebar/ @JzoNgKVO @iamjoel
|
||||
|
||||
# Frontend - RAG - Hit Testing
|
||||
web/app/components/datasets/hit-testing/ @JzoNgKVO @iamjoel
|
||||
/web/app/components/datasets/hit-testing/ @JzoNgKVO @iamjoel
|
||||
|
||||
# Frontend - RAG - List and Creation
|
||||
web/app/components/datasets/list/ @iamjoel @WTW0313
|
||||
web/app/components/datasets/create/ @iamjoel @WTW0313
|
||||
web/app/components/datasets/create-from-pipeline/ @iamjoel @WTW0313
|
||||
web/app/components/datasets/external-knowledge-base/ @iamjoel @WTW0313
|
||||
/web/app/components/datasets/list/ @iamjoel @WTW0313
|
||||
/web/app/components/datasets/create/ @iamjoel @WTW0313
|
||||
/web/app/components/datasets/create-from-pipeline/ @iamjoel @WTW0313
|
||||
/web/app/components/datasets/external-knowledge-base/ @iamjoel @WTW0313
|
||||
|
||||
# Frontend - RAG - Orchestration (general rule first, specific rules below override)
|
||||
web/app/components/rag-pipeline/ @iamjoel @WTW0313
|
||||
web/app/components/rag-pipeline/components/rag-pipeline-main.tsx @iamjoel @zxhlyh
|
||||
web/app/components/rag-pipeline/store/ @iamjoel @zxhlyh
|
||||
/web/app/components/rag-pipeline/ @iamjoel @WTW0313
|
||||
/web/app/components/rag-pipeline/components/rag-pipeline-main.tsx @iamjoel @zxhlyh
|
||||
/web/app/components/rag-pipeline/store/ @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - RAG - Documents List
|
||||
web/app/components/datasets/documents/list.tsx @iamjoel @WTW0313
|
||||
web/app/components/datasets/documents/create-from-pipeline/ @iamjoel @WTW0313
|
||||
/web/app/components/datasets/documents/list.tsx @iamjoel @WTW0313
|
||||
/web/app/components/datasets/documents/create-from-pipeline/ @iamjoel @WTW0313
|
||||
|
||||
# Frontend - RAG - Segments List
|
||||
web/app/components/datasets/documents/detail/ @iamjoel @WTW0313
|
||||
/web/app/components/datasets/documents/detail/ @iamjoel @WTW0313
|
||||
|
||||
# Frontend - RAG - Settings
|
||||
web/app/components/datasets/settings/ @iamjoel @WTW0313
|
||||
/web/app/components/datasets/settings/ @iamjoel @WTW0313
|
||||
|
||||
# Frontend - Ecosystem - Plugins
|
||||
web/app/components/plugins/ @iamjoel @zhsama
|
||||
/web/app/components/plugins/ @iamjoel @zhsama
|
||||
|
||||
# Frontend - Ecosystem - Tools
|
||||
web/app/components/tools/ @iamjoel @Yessenia-d
|
||||
/web/app/components/tools/ @iamjoel @Yessenia-d
|
||||
|
||||
# Frontend - Ecosystem - MarketPlace
|
||||
web/app/components/plugins/marketplace/ @iamjoel @Yessenia-d
|
||||
/web/app/components/plugins/marketplace/ @iamjoel @Yessenia-d
|
||||
|
||||
# Frontend - Login and Registration
|
||||
web/app/signin/ @douxc @iamjoel
|
||||
web/app/signup/ @douxc @iamjoel
|
||||
web/app/reset-password/ @douxc @iamjoel
|
||||
web/app/install/ @douxc @iamjoel
|
||||
web/app/init/ @douxc @iamjoel
|
||||
web/app/forgot-password/ @douxc @iamjoel
|
||||
web/app/account/ @douxc @iamjoel
|
||||
/web/app/signin/ @douxc @iamjoel
|
||||
/web/app/signup/ @douxc @iamjoel
|
||||
/web/app/reset-password/ @douxc @iamjoel
|
||||
/web/app/install/ @douxc @iamjoel
|
||||
/web/app/init/ @douxc @iamjoel
|
||||
/web/app/forgot-password/ @douxc @iamjoel
|
||||
/web/app/account/ @douxc @iamjoel
|
||||
|
||||
# Frontend - Service Authentication
|
||||
web/service/base.ts @douxc @iamjoel
|
||||
/web/service/base.ts @douxc @iamjoel
|
||||
|
||||
# Frontend - WebApp Authentication and Access Control
|
||||
web/app/(shareLayout)/components/ @douxc @iamjoel
|
||||
web/app/(shareLayout)/webapp-signin/ @douxc @iamjoel
|
||||
web/app/(shareLayout)/webapp-reset-password/ @douxc @iamjoel
|
||||
web/app/components/app/app-access-control/ @douxc @iamjoel
|
||||
/web/app/(shareLayout)/components/ @douxc @iamjoel
|
||||
/web/app/(shareLayout)/webapp-signin/ @douxc @iamjoel
|
||||
/web/app/(shareLayout)/webapp-reset-password/ @douxc @iamjoel
|
||||
/web/app/components/app/app-access-control/ @douxc @iamjoel
|
||||
|
||||
# Frontend - Explore Page
|
||||
web/app/components/explore/ @CodingOnStar @iamjoel
|
||||
/web/app/components/explore/ @CodingOnStar @iamjoel
|
||||
|
||||
# Frontend - Personal Settings
|
||||
web/app/components/header/account-setting/ @CodingOnStar @iamjoel
|
||||
web/app/components/header/account-dropdown/ @CodingOnStar @iamjoel
|
||||
/web/app/components/header/account-setting/ @CodingOnStar @iamjoel
|
||||
/web/app/components/header/account-dropdown/ @CodingOnStar @iamjoel
|
||||
|
||||
# Frontend - Analytics
|
||||
web/app/components/base/ga/ @CodingOnStar @iamjoel
|
||||
/web/app/components/base/ga/ @CodingOnStar @iamjoel
|
||||
|
||||
# Frontend - Base Components
|
||||
web/app/components/base/ @iamjoel @zxhlyh
|
||||
/web/app/components/base/ @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - Utils and Hooks
|
||||
web/utils/classnames.ts @iamjoel @zxhlyh
|
||||
web/utils/time.ts @iamjoel @zxhlyh
|
||||
web/utils/format.ts @iamjoel @zxhlyh
|
||||
web/utils/clipboard.ts @iamjoel @zxhlyh
|
||||
web/hooks/use-document-title.ts @iamjoel @zxhlyh
|
||||
/web/utils/classnames.ts @iamjoel @zxhlyh
|
||||
/web/utils/time.ts @iamjoel @zxhlyh
|
||||
/web/utils/format.ts @iamjoel @zxhlyh
|
||||
/web/utils/clipboard.ts @iamjoel @zxhlyh
|
||||
/web/hooks/use-document-title.ts @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - Billing and Education
|
||||
web/app/components/billing/ @iamjoel @zxhlyh
|
||||
web/app/education-apply/ @iamjoel @zxhlyh
|
||||
/web/app/components/billing/ @iamjoel @zxhlyh
|
||||
/web/app/education-apply/ @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - Workspace
|
||||
web/app/components/header/account-dropdown/workplace-selector/ @iamjoel @zxhlyh
|
||||
/web/app/components/header/account-dropdown/workplace-selector/ @iamjoel @zxhlyh
|
||||
|
||||
# Docker
|
||||
/docker/* @laipz8200
|
||||
|
|
|
|||
|
|
@ -1,12 +0,0 @@
|
|||
# Copilot Instructions
|
||||
|
||||
GitHub Copilot must follow the unified frontend testing requirements documented in `web/testing/testing.md`.
|
||||
|
||||
Key reminders:
|
||||
|
||||
- Generate tests using the mandated tech stack, naming, and code style (AAA pattern, `fireEvent`, descriptive test names, cleans up mocks).
|
||||
- Cover rendering, prop combinations, and edge cases by default; extend coverage for hooks, routing, async flows, and domain-specific components when applicable.
|
||||
- Target >95% line and branch coverage and 100% function/statement coverage.
|
||||
- Apply the project's mocking conventions for i18n, toast notifications, and Next.js utilities.
|
||||
|
||||
Any suggestions from Copilot that conflict with `web/testing/testing.md` should be revised before acceptance.
|
||||
|
|
@ -71,18 +71,18 @@ jobs:
|
|||
run: |
|
||||
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
|
||||
|
||||
- name: Run Workflow
|
||||
run: uv run --project api bash dev/pytest/pytest_workflow.sh
|
||||
|
||||
- name: Run Tool
|
||||
run: uv run --project api bash dev/pytest/pytest_tools.sh
|
||||
|
||||
- name: Run TestContainers
|
||||
run: uv run --project api bash dev/pytest/pytest_testcontainers.sh
|
||||
|
||||
- name: Run Unit tests
|
||||
- name: Run API Tests
|
||||
env:
|
||||
STORAGE_TYPE: opendal
|
||||
OPENDAL_SCHEME: fs
|
||||
OPENDAL_FS_ROOT: /tmp/dify-storage
|
||||
run: |
|
||||
uv run --project api bash dev/pytest/pytest_unit_tests.sh
|
||||
uv run --project api pytest \
|
||||
--timeout "${PYTEST_TIMEOUT:-180}" \
|
||||
api/tests/integration_tests/workflow \
|
||||
api/tests/integration_tests/tools \
|
||||
api/tests/test_containers_integration_tests \
|
||||
api/tests/unit_tests
|
||||
|
||||
- name: Coverage Summary
|
||||
run: |
|
||||
|
|
@ -93,5 +93,12 @@ jobs:
|
|||
# Create a detailed coverage summary
|
||||
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
|
||||
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
|
||||
uv run --project api coverage report --format=markdown >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
{
|
||||
echo ""
|
||||
echo "<details><summary>File-level coverage (click to expand)</summary>"
|
||||
echo ""
|
||||
echo '```'
|
||||
uv run --project api coverage report -m
|
||||
echo '```'
|
||||
echo "</details>"
|
||||
} >> $GITHUB_STEP_SUMMARY
|
||||
|
|
|
|||
|
|
@ -13,11 +13,12 @@ jobs:
|
|||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
# Use uv to ensure we have the same ruff version in CI and locally.
|
||||
- uses: astral-sh/setup-uv@v6
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
|
||||
- uses: astral-sh/setup-uv@v6
|
||||
|
||||
- run: |
|
||||
cd api
|
||||
uv sync --dev
|
||||
|
|
@ -35,10 +36,11 @@ jobs:
|
|||
|
||||
- name: ast-grep
|
||||
run: |
|
||||
uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all
|
||||
uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all
|
||||
uvx --from ast-grep-cli sg -p '$A = db.Column($$$B)' -r '$A = mapped_column($$$B)' -l py --update-all
|
||||
uvx --from ast-grep-cli sg -p '$A : $T = db.Column($$$B)' -r '$A : $T = mapped_column($$$B)' -l py --update-all
|
||||
# ast-grep exits 1 if no matches are found; allow idempotent runs.
|
||||
uvx --from ast-grep-cli ast-grep --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all || true
|
||||
uvx --from ast-grep-cli ast-grep --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all || true
|
||||
uvx --from ast-grep-cli ast-grep -p '$A = db.Column($$$B)' -r '$A = mapped_column($$$B)' -l py --update-all || true
|
||||
uvx --from ast-grep-cli ast-grep -p '$A : $T = db.Column($$$B)' -r '$A : $T = mapped_column($$$B)' -l py --update-all || true
|
||||
# Convert Optional[T] to T | None (ignoring quoted types)
|
||||
cat > /tmp/optional-rule.yml << 'EOF'
|
||||
id: convert-optional-to-union
|
||||
|
|
@ -56,14 +58,15 @@ jobs:
|
|||
pattern: $T
|
||||
fix: $T | None
|
||||
EOF
|
||||
uvx --from ast-grep-cli sg scan --inline-rules "$(cat /tmp/optional-rule.yml)" --update-all
|
||||
uvx --from ast-grep-cli ast-grep scan . --inline-rules "$(cat /tmp/optional-rule.yml)" --update-all
|
||||
# Fix forward references that were incorrectly converted (Python doesn't support "Type" | None syntax)
|
||||
find . -name "*.py" -type f -exec sed -i.bak -E 's/"([^"]+)" \| None/Optional["\1"]/g; s/'"'"'([^'"'"']+)'"'"' \| None/Optional['"'"'\1'"'"']/g' {} \;
|
||||
find . -name "*.py.bak" -type f -delete
|
||||
|
||||
# mdformat breaks YAML front matter in markdown files. Add --exclude for directories containing YAML front matter.
|
||||
- name: mdformat
|
||||
run: |
|
||||
uvx mdformat .
|
||||
uvx --python 3.13 mdformat . --exclude ".claude/skills/**/SKILL.md"
|
||||
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
|
|
@ -76,7 +79,7 @@ jobs:
|
|||
with:
|
||||
node-version: 22
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/package.json
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
- name: Web dependencies
|
||||
working-directory: ./web
|
||||
|
|
@ -84,7 +87,6 @@ jobs:
|
|||
|
||||
- name: oxlint
|
||||
working-directory: ./web
|
||||
run: |
|
||||
pnpx oxlint --fix
|
||||
run: pnpm exec oxlint --config .oxlintrc.json --fix .
|
||||
|
||||
- uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27
|
||||
|
|
|
|||
|
|
@ -90,7 +90,7 @@ jobs:
|
|||
with:
|
||||
node-version: 22
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/package.json
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
- name: Web dependencies
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ jobs:
|
|||
with:
|
||||
node-version: 'lts/*'
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/package.json
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
- name: Install dependencies
|
||||
if: env.FILES_CHANGED == 'true'
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ jobs:
|
|||
runs-on: ubuntu-latest
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: ./web
|
||||
|
||||
steps:
|
||||
|
|
@ -21,14 +22,7 @@ jobs:
|
|||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v46
|
||||
with:
|
||||
files: web/**
|
||||
|
||||
- name: Install pnpm
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
package_json_file: web/package.json
|
||||
|
|
@ -36,23 +30,342 @@ jobs:
|
|||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v4
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
with:
|
||||
node-version: 22
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/package.json
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
- name: Install dependencies
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Check i18n types synchronization
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: pnpm run check:i18n-types
|
||||
|
||||
- name: Run tests
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: pnpm test
|
||||
run: pnpm test --coverage
|
||||
|
||||
- name: Coverage Summary
|
||||
if: always()
|
||||
id: coverage-summary
|
||||
run: |
|
||||
set -eo pipefail
|
||||
|
||||
COVERAGE_FILE="coverage/coverage-final.json"
|
||||
COVERAGE_SUMMARY_FILE="coverage/coverage-summary.json"
|
||||
|
||||
if [ ! -f "$COVERAGE_FILE" ] && [ ! -f "$COVERAGE_SUMMARY_FILE" ]; then
|
||||
echo "has_coverage=false" >> "$GITHUB_OUTPUT"
|
||||
echo "### 🚨 Test Coverage Report :test_tube:" >> "$GITHUB_STEP_SUMMARY"
|
||||
echo "Coverage data not found. Ensure Vitest runs with coverage enabled." >> "$GITHUB_STEP_SUMMARY"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "has_coverage=true" >> "$GITHUB_OUTPUT"
|
||||
|
||||
node <<'NODE' >> "$GITHUB_STEP_SUMMARY"
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
let libCoverage = null;
|
||||
|
||||
try {
|
||||
libCoverage = require('istanbul-lib-coverage');
|
||||
} catch (error) {
|
||||
libCoverage = null;
|
||||
}
|
||||
|
||||
const summaryPath = path.join('coverage', 'coverage-summary.json');
|
||||
const finalPath = path.join('coverage', 'coverage-final.json');
|
||||
|
||||
const hasSummary = fs.existsSync(summaryPath);
|
||||
const hasFinal = fs.existsSync(finalPath);
|
||||
|
||||
if (!hasSummary && !hasFinal) {
|
||||
console.log('### Test Coverage Summary :test_tube:');
|
||||
console.log('');
|
||||
console.log('No coverage data found.');
|
||||
process.exit(0);
|
||||
}
|
||||
|
||||
const summary = hasSummary
|
||||
? JSON.parse(fs.readFileSync(summaryPath, 'utf8'))
|
||||
: null;
|
||||
const coverage = hasFinal
|
||||
? JSON.parse(fs.readFileSync(finalPath, 'utf8'))
|
||||
: null;
|
||||
|
||||
const getLineCoverageFromStatements = (statementMap, statementHits) => {
|
||||
const lineHits = {};
|
||||
|
||||
if (!statementMap || !statementHits) {
|
||||
return lineHits;
|
||||
}
|
||||
|
||||
Object.entries(statementMap).forEach(([key, statement]) => {
|
||||
const line = statement?.start?.line;
|
||||
if (!line) {
|
||||
return;
|
||||
}
|
||||
const hits = statementHits[key] ?? 0;
|
||||
const previous = lineHits[line];
|
||||
lineHits[line] = previous === undefined ? hits : Math.max(previous, hits);
|
||||
});
|
||||
|
||||
return lineHits;
|
||||
};
|
||||
|
||||
const getFileCoverage = (entry) => (
|
||||
libCoverage ? libCoverage.createFileCoverage(entry) : null
|
||||
);
|
||||
|
||||
const getLineHits = (entry, fileCoverage) => {
|
||||
const lineHits = entry.l ?? {};
|
||||
if (Object.keys(lineHits).length > 0) {
|
||||
return lineHits;
|
||||
}
|
||||
if (fileCoverage) {
|
||||
return fileCoverage.getLineCoverage();
|
||||
}
|
||||
return getLineCoverageFromStatements(entry.statementMap ?? {}, entry.s ?? {});
|
||||
};
|
||||
|
||||
const getUncoveredLines = (entry, fileCoverage, lineHits) => {
|
||||
if (lineHits && Object.keys(lineHits).length > 0) {
|
||||
return Object.entries(lineHits)
|
||||
.filter(([, count]) => count === 0)
|
||||
.map(([line]) => Number(line))
|
||||
.sort((a, b) => a - b);
|
||||
}
|
||||
if (fileCoverage) {
|
||||
return fileCoverage.getUncoveredLines();
|
||||
}
|
||||
return [];
|
||||
};
|
||||
|
||||
const totals = {
|
||||
lines: { covered: 0, total: 0 },
|
||||
statements: { covered: 0, total: 0 },
|
||||
branches: { covered: 0, total: 0 },
|
||||
functions: { covered: 0, total: 0 },
|
||||
};
|
||||
const fileSummaries = [];
|
||||
|
||||
if (summary) {
|
||||
const totalEntry = summary.total ?? {};
|
||||
['lines', 'statements', 'branches', 'functions'].forEach((key) => {
|
||||
if (totalEntry[key]) {
|
||||
totals[key].covered = totalEntry[key].covered ?? 0;
|
||||
totals[key].total = totalEntry[key].total ?? 0;
|
||||
}
|
||||
});
|
||||
|
||||
Object.entries(summary)
|
||||
.filter(([file]) => file !== 'total')
|
||||
.forEach(([file, data]) => {
|
||||
fileSummaries.push({
|
||||
file,
|
||||
pct: data.lines?.pct ?? data.statements?.pct ?? 0,
|
||||
lines: {
|
||||
covered: data.lines?.covered ?? 0,
|
||||
total: data.lines?.total ?? 0,
|
||||
},
|
||||
});
|
||||
});
|
||||
} else if (coverage) {
|
||||
Object.entries(coverage).forEach(([file, entry]) => {
|
||||
const fileCoverage = getFileCoverage(entry);
|
||||
const lineHits = getLineHits(entry, fileCoverage);
|
||||
const statementHits = entry.s ?? {};
|
||||
const branchHits = entry.b ?? {};
|
||||
const functionHits = entry.f ?? {};
|
||||
|
||||
const lineTotal = Object.keys(lineHits).length;
|
||||
const lineCovered = Object.values(lineHits).filter((n) => n > 0).length;
|
||||
|
||||
const statementTotal = Object.keys(statementHits).length;
|
||||
const statementCovered = Object.values(statementHits).filter((n) => n > 0).length;
|
||||
|
||||
const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0);
|
||||
const branchCovered = Object.values(branchHits).reduce(
|
||||
(acc, branches) => acc + branches.filter((n) => n > 0).length,
|
||||
0,
|
||||
);
|
||||
|
||||
const functionTotal = Object.keys(functionHits).length;
|
||||
const functionCovered = Object.values(functionHits).filter((n) => n > 0).length;
|
||||
|
||||
totals.lines.total += lineTotal;
|
||||
totals.lines.covered += lineCovered;
|
||||
totals.statements.total += statementTotal;
|
||||
totals.statements.covered += statementCovered;
|
||||
totals.branches.total += branchTotal;
|
||||
totals.branches.covered += branchCovered;
|
||||
totals.functions.total += functionTotal;
|
||||
totals.functions.covered += functionCovered;
|
||||
|
||||
const pct = (covered, tot) => (tot > 0 ? (covered / tot) * 100 : 0);
|
||||
|
||||
fileSummaries.push({
|
||||
file,
|
||||
pct: pct(lineCovered || statementCovered, lineTotal || statementTotal),
|
||||
lines: {
|
||||
covered: lineCovered || statementCovered,
|
||||
total: lineTotal || statementTotal,
|
||||
},
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
const pct = (covered, tot) => (tot > 0 ? ((covered / tot) * 100).toFixed(2) : '0.00');
|
||||
|
||||
console.log('### Test Coverage Summary :test_tube:');
|
||||
console.log('');
|
||||
console.log('| Metric | Coverage | Covered / Total |');
|
||||
console.log('|--------|----------|-----------------|');
|
||||
console.log(`| Lines | ${pct(totals.lines.covered, totals.lines.total)}% | ${totals.lines.covered} / ${totals.lines.total} |`);
|
||||
console.log(`| Statements | ${pct(totals.statements.covered, totals.statements.total)}% | ${totals.statements.covered} / ${totals.statements.total} |`);
|
||||
console.log(`| Branches | ${pct(totals.branches.covered, totals.branches.total)}% | ${totals.branches.covered} / ${totals.branches.total} |`);
|
||||
console.log(`| Functions | ${pct(totals.functions.covered, totals.functions.total)}% | ${totals.functions.covered} / ${totals.functions.total} |`);
|
||||
|
||||
console.log('');
|
||||
console.log('<details><summary>File coverage (lowest lines first)</summary>');
|
||||
console.log('');
|
||||
console.log('```');
|
||||
fileSummaries
|
||||
.sort((a, b) => (a.pct - b.pct) || (b.lines.total - a.lines.total))
|
||||
.slice(0, 25)
|
||||
.forEach(({ file, pct, lines }) => {
|
||||
console.log(`${pct.toFixed(2)}%\t${lines.covered}/${lines.total}\t${file}`);
|
||||
});
|
||||
console.log('```');
|
||||
console.log('</details>');
|
||||
|
||||
if (coverage) {
|
||||
const pctValue = (covered, tot) => {
|
||||
if (tot === 0) {
|
||||
return '0';
|
||||
}
|
||||
return ((covered / tot) * 100)
|
||||
.toFixed(2)
|
||||
.replace(/\.?0+$/, '');
|
||||
};
|
||||
|
||||
const formatLineRanges = (lines) => {
|
||||
if (lines.length === 0) {
|
||||
return '';
|
||||
}
|
||||
const ranges = [];
|
||||
let start = lines[0];
|
||||
let end = lines[0];
|
||||
|
||||
for (let i = 1; i < lines.length; i += 1) {
|
||||
const current = lines[i];
|
||||
if (current === end + 1) {
|
||||
end = current;
|
||||
continue;
|
||||
}
|
||||
ranges.push(start === end ? `${start}` : `${start}-${end}`);
|
||||
start = current;
|
||||
end = current;
|
||||
}
|
||||
ranges.push(start === end ? `${start}` : `${start}-${end}`);
|
||||
return ranges.join(',');
|
||||
};
|
||||
|
||||
const tableTotals = {
|
||||
statements: { covered: 0, total: 0 },
|
||||
branches: { covered: 0, total: 0 },
|
||||
functions: { covered: 0, total: 0 },
|
||||
lines: { covered: 0, total: 0 },
|
||||
};
|
||||
const tableRows = Object.entries(coverage)
|
||||
.map(([file, entry]) => {
|
||||
const fileCoverage = getFileCoverage(entry);
|
||||
const lineHits = getLineHits(entry, fileCoverage);
|
||||
const statementHits = entry.s ?? {};
|
||||
const branchHits = entry.b ?? {};
|
||||
const functionHits = entry.f ?? {};
|
||||
|
||||
const lineTotal = Object.keys(lineHits).length;
|
||||
const lineCovered = Object.values(lineHits).filter((n) => n > 0).length;
|
||||
const statementTotal = Object.keys(statementHits).length;
|
||||
const statementCovered = Object.values(statementHits).filter((n) => n > 0).length;
|
||||
const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0);
|
||||
const branchCovered = Object.values(branchHits).reduce(
|
||||
(acc, branches) => acc + branches.filter((n) => n > 0).length,
|
||||
0,
|
||||
);
|
||||
const functionTotal = Object.keys(functionHits).length;
|
||||
const functionCovered = Object.values(functionHits).filter((n) => n > 0).length;
|
||||
|
||||
tableTotals.lines.total += lineTotal;
|
||||
tableTotals.lines.covered += lineCovered;
|
||||
tableTotals.statements.total += statementTotal;
|
||||
tableTotals.statements.covered += statementCovered;
|
||||
tableTotals.branches.total += branchTotal;
|
||||
tableTotals.branches.covered += branchCovered;
|
||||
tableTotals.functions.total += functionTotal;
|
||||
tableTotals.functions.covered += functionCovered;
|
||||
|
||||
const uncoveredLines = getUncoveredLines(entry, fileCoverage, lineHits);
|
||||
|
||||
const filePath = entry.path ?? file;
|
||||
const relativePath = path.isAbsolute(filePath)
|
||||
? path.relative(process.cwd(), filePath)
|
||||
: filePath;
|
||||
|
||||
return {
|
||||
file: relativePath || file,
|
||||
statements: pctValue(statementCovered, statementTotal),
|
||||
branches: pctValue(branchCovered, branchTotal),
|
||||
functions: pctValue(functionCovered, functionTotal),
|
||||
lines: pctValue(lineCovered, lineTotal),
|
||||
uncovered: formatLineRanges(uncoveredLines),
|
||||
};
|
||||
})
|
||||
.sort((a, b) => a.file.localeCompare(b.file));
|
||||
|
||||
const columns = [
|
||||
{ key: 'file', header: 'File', align: 'left' },
|
||||
{ key: 'statements', header: '% Stmts', align: 'right' },
|
||||
{ key: 'branches', header: '% Branch', align: 'right' },
|
||||
{ key: 'functions', header: '% Funcs', align: 'right' },
|
||||
{ key: 'lines', header: '% Lines', align: 'right' },
|
||||
{ key: 'uncovered', header: 'Uncovered Line #s', align: 'left' },
|
||||
];
|
||||
|
||||
const allFilesRow = {
|
||||
file: 'All files',
|
||||
statements: pctValue(tableTotals.statements.covered, tableTotals.statements.total),
|
||||
branches: pctValue(tableTotals.branches.covered, tableTotals.branches.total),
|
||||
functions: pctValue(tableTotals.functions.covered, tableTotals.functions.total),
|
||||
lines: pctValue(tableTotals.lines.covered, tableTotals.lines.total),
|
||||
uncovered: '',
|
||||
};
|
||||
|
||||
const rowsForOutput = [allFilesRow, ...tableRows];
|
||||
const formatRow = (row) => `| ${columns
|
||||
.map(({ key }) => String(row[key] ?? ''))
|
||||
.join(' | ')} |`;
|
||||
const headerRow = `| ${columns.map(({ header }) => header).join(' | ')} |`;
|
||||
const dividerRow = `| ${columns
|
||||
.map(({ align }) => (align === 'right' ? '---:' : ':---'))
|
||||
.join(' | ')} |`;
|
||||
|
||||
console.log('');
|
||||
console.log('<details><summary>Vitest coverage table</summary>');
|
||||
console.log('');
|
||||
console.log(headerRow);
|
||||
console.log(dividerRow);
|
||||
rowsForOutput.forEach((row) => console.log(formatRow(row)));
|
||||
console.log('</details>');
|
||||
}
|
||||
NODE
|
||||
|
||||
- name: Upload Coverage Artifact
|
||||
if: steps.coverage-summary.outputs.has_coverage == 'true'
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: web-coverage-report
|
||||
path: web/coverage
|
||||
retention-days: 30
|
||||
if-no-files-found: error
|
||||
|
|
|
|||
|
|
@ -189,6 +189,7 @@ docker/volumes/matrixone/*
|
|||
docker/volumes/mysql/*
|
||||
docker/volumes/seekdb/*
|
||||
!docker/volumes/oceanbase/init.d
|
||||
docker/volumes/iris/*
|
||||
|
||||
docker/nginx/conf.d/default.conf
|
||||
docker/nginx/ssl/*
|
||||
|
|
|
|||
|
|
@ -1,5 +0,0 @@
|
|||
# Windsurf Testing Rules
|
||||
|
||||
- Use `web/testing/testing.md` as the single source of truth for frontend automated testing.
|
||||
- Honor every requirement in that document when generating or accepting tests.
|
||||
- When proposing or saving tests, re-read that document and follow every requirement.
|
||||
|
|
@ -116,6 +116,7 @@ ALIYUN_OSS_AUTH_VERSION=v1
|
|||
ALIYUN_OSS_REGION=your-region
|
||||
# Don't start with '/'. OSS doesn't support leading slash in object names.
|
||||
ALIYUN_OSS_PATH=your-path
|
||||
ALIYUN_CLOUDBOX_ID=your-cloudbox-id
|
||||
|
||||
# Google Storage configuration
|
||||
GOOGLE_STORAGE_BUCKET_NAME=your-bucket-name
|
||||
|
|
@ -133,6 +134,7 @@ HUAWEI_OBS_BUCKET_NAME=your-bucket-name
|
|||
HUAWEI_OBS_SECRET_KEY=your-secret-key
|
||||
HUAWEI_OBS_ACCESS_KEY=your-access-key
|
||||
HUAWEI_OBS_SERVER=your-server-url
|
||||
HUAWEI_OBS_PATH_STYLE=false
|
||||
|
||||
# Baidu OBS Storage Configuration
|
||||
BAIDU_OBS_BUCKET_NAME=your-bucket-name
|
||||
|
|
@ -543,6 +545,25 @@ APP_MAX_EXECUTION_TIME=1200
|
|||
APP_DEFAULT_ACTIVE_REQUESTS=0
|
||||
APP_MAX_ACTIVE_REQUESTS=0
|
||||
|
||||
# Aliyun SLS Logstore Configuration
|
||||
# Aliyun Access Key ID
|
||||
ALIYUN_SLS_ACCESS_KEY_ID=
|
||||
# Aliyun Access Key Secret
|
||||
ALIYUN_SLS_ACCESS_KEY_SECRET=
|
||||
# Aliyun SLS Endpoint (e.g., cn-hangzhou.log.aliyuncs.com)
|
||||
ALIYUN_SLS_ENDPOINT=
|
||||
# Aliyun SLS Region (e.g., cn-hangzhou)
|
||||
ALIYUN_SLS_REGION=
|
||||
# Aliyun SLS Project Name
|
||||
ALIYUN_SLS_PROJECT_NAME=
|
||||
# Number of days to retain workflow run logs (default: 365 days, 3650 for permanent storage)
|
||||
ALIYUN_SLS_LOGSTORE_TTL=365
|
||||
# Enable dual-write to both SLS LogStore and SQL database (default: false)
|
||||
LOGSTORE_DUAL_WRITE_ENABLED=false
|
||||
# Enable dual-read fallback to SQL database when LogStore returns no results (default: true)
|
||||
# Useful for migration scenarios where historical data exists only in SQL database
|
||||
LOGSTORE_DUAL_READ_ENABLED=true
|
||||
|
||||
# Celery beat configuration
|
||||
CELERY_BEAT_SCHEDULER_TIME=1
|
||||
|
||||
|
|
@ -672,7 +693,6 @@ ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE=5
|
|||
ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR=20
|
||||
# Maximum number of concurrent annotation import tasks per tenant
|
||||
ANNOTATION_IMPORT_MAX_CONCURRENT=5
|
||||
|
||||
# Sandbox expired records clean configuration
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000
|
||||
|
|
|
|||
|
|
@ -75,6 +75,7 @@ def initialize_extensions(app: DifyApp):
|
|||
ext_import_modules,
|
||||
ext_logging,
|
||||
ext_login,
|
||||
ext_logstore,
|
||||
ext_mail,
|
||||
ext_migrate,
|
||||
ext_orjson,
|
||||
|
|
@ -83,6 +84,7 @@ def initialize_extensions(app: DifyApp):
|
|||
ext_redis,
|
||||
ext_request_logging,
|
||||
ext_sentry,
|
||||
ext_session_factory,
|
||||
ext_set_secretkey,
|
||||
ext_storage,
|
||||
ext_timezone,
|
||||
|
|
@ -104,6 +106,7 @@ def initialize_extensions(app: DifyApp):
|
|||
ext_migrate,
|
||||
ext_redis,
|
||||
ext_storage,
|
||||
ext_logstore, # Initialize logstore after storage, before celery
|
||||
ext_celery,
|
||||
ext_login,
|
||||
ext_mail,
|
||||
|
|
@ -114,6 +117,7 @@ def initialize_extensions(app: DifyApp):
|
|||
ext_commands,
|
||||
ext_otel,
|
||||
ext_request_logging,
|
||||
ext_session_factory,
|
||||
]
|
||||
for ext in extensions:
|
||||
short_name = ext.__name__.split(".")[-1]
|
||||
|
|
|
|||
|
|
@ -37,8 +37,8 @@ from libs.rsa import generate_key_pair
|
|||
from models import Tenant
|
||||
from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation, UploadFile
|
||||
from models.enums import CreatorUserRole, ExecutionOffLoadType
|
||||
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation, UploadFile
|
||||
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
|
||||
from models.provider import Provider, ProviderModel
|
||||
from models.provider_ids import DatasourceProviderID, ToolProviderID
|
||||
|
|
|
|||
|
|
@ -218,7 +218,7 @@ class PluginConfig(BaseSettings):
|
|||
|
||||
PLUGIN_DAEMON_TIMEOUT: PositiveFloat | None = Field(
|
||||
description="Timeout in seconds for requests to the plugin daemon (set to None to disable)",
|
||||
default=300.0,
|
||||
default=600.0,
|
||||
)
|
||||
|
||||
INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key")
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ from .vdb.clickzetta_config import ClickzettaConfig
|
|||
from .vdb.couchbase_config import CouchbaseConfig
|
||||
from .vdb.elasticsearch_config import ElasticsearchConfig
|
||||
from .vdb.huawei_cloud_config import HuaweiCloudConfig
|
||||
from .vdb.iris_config import IrisVectorConfig
|
||||
from .vdb.lindorm_config import LindormConfig
|
||||
from .vdb.matrixone_config import MatrixoneConfig
|
||||
from .vdb.milvus_config import MilvusConfig
|
||||
|
|
@ -106,7 +107,7 @@ class KeywordStoreConfig(BaseSettings):
|
|||
|
||||
class DatabaseConfig(BaseSettings):
|
||||
# Database type selector
|
||||
DB_TYPE: Literal["postgresql", "mysql", "oceanbase"] = Field(
|
||||
DB_TYPE: Literal["postgresql", "mysql", "oceanbase", "seekdb"] = Field(
|
||||
description="Database type to use. OceanBase is MySQL-compatible.",
|
||||
default="postgresql",
|
||||
)
|
||||
|
|
@ -336,6 +337,7 @@ class MiddlewareConfig(
|
|||
ChromaConfig,
|
||||
ClickzettaConfig,
|
||||
HuaweiCloudConfig,
|
||||
IrisVectorConfig,
|
||||
MilvusConfig,
|
||||
AlibabaCloudMySQLConfig,
|
||||
MyScaleConfig,
|
||||
|
|
|
|||
|
|
@ -41,3 +41,8 @@ class AliyunOSSStorageConfig(BaseSettings):
|
|||
description="Base path within the bucket to store objects (e.g., 'my-app-data/')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
ALIYUN_CLOUDBOX_ID: str | None = Field(
|
||||
description="Cloudbox id for aliyun cloudbox service",
|
||||
default=None,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -26,3 +26,8 @@ class HuaweiCloudOBSStorageConfig(BaseSettings):
|
|||
description="Endpoint URL for Huawei Cloud OBS (e.g., 'https://obs.cn-north-4.myhuaweicloud.com')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HUAWEI_OBS_PATH_STYLE: bool = Field(
|
||||
description="Flag to indicate whether to use path-style URLs for OBS requests",
|
||||
default=False,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,91 @@
|
|||
"""Configuration for InterSystems IRIS vector database."""
|
||||
|
||||
from pydantic import Field, PositiveInt, model_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class IrisVectorConfig(BaseSettings):
|
||||
"""Configuration settings for IRIS vector database connection and pooling."""
|
||||
|
||||
IRIS_HOST: str | None = Field(
|
||||
description="Hostname or IP address of the IRIS server.",
|
||||
default="localhost",
|
||||
)
|
||||
|
||||
IRIS_SUPER_SERVER_PORT: PositiveInt | None = Field(
|
||||
description="Port number for IRIS connection.",
|
||||
default=1972,
|
||||
)
|
||||
|
||||
IRIS_USER: str | None = Field(
|
||||
description="Username for IRIS authentication.",
|
||||
default="_SYSTEM",
|
||||
)
|
||||
|
||||
IRIS_PASSWORD: str | None = Field(
|
||||
description="Password for IRIS authentication.",
|
||||
default="Dify@1234",
|
||||
)
|
||||
|
||||
IRIS_SCHEMA: str | None = Field(
|
||||
description="Schema name for IRIS tables.",
|
||||
default="dify",
|
||||
)
|
||||
|
||||
IRIS_DATABASE: str | None = Field(
|
||||
description="Database namespace for IRIS connection.",
|
||||
default="USER",
|
||||
)
|
||||
|
||||
IRIS_CONNECTION_URL: str | None = Field(
|
||||
description="Full connection URL for IRIS (overrides individual fields if provided).",
|
||||
default=None,
|
||||
)
|
||||
|
||||
IRIS_MIN_CONNECTION: PositiveInt = Field(
|
||||
description="Minimum number of connections in the pool.",
|
||||
default=1,
|
||||
)
|
||||
|
||||
IRIS_MAX_CONNECTION: PositiveInt = Field(
|
||||
description="Maximum number of connections in the pool.",
|
||||
default=3,
|
||||
)
|
||||
|
||||
IRIS_TEXT_INDEX: bool = Field(
|
||||
description="Enable full-text search index using %iFind.Index.Basic.",
|
||||
default=True,
|
||||
)
|
||||
|
||||
IRIS_TEXT_INDEX_LANGUAGE: str = Field(
|
||||
description="Language for full-text search index (e.g., 'en', 'ja', 'zh', 'de').",
|
||||
default="en",
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
"""Validate IRIS configuration values.
|
||||
|
||||
Args:
|
||||
values: Configuration dictionary
|
||||
|
||||
Returns:
|
||||
Validated configuration dictionary
|
||||
|
||||
Raises:
|
||||
ValueError: If required fields are missing or pool settings are invalid
|
||||
"""
|
||||
# Only validate required fields if IRIS is being used as the vector store
|
||||
# This allows the config to be loaded even when IRIS is not in use
|
||||
|
||||
# vector_store = os.environ.get("VECTOR_STORE", "")
|
||||
# We rely on Pydantic defaults for required fields if they are missing from env.
|
||||
# Strict existence check is removed to allow defaults to work.
|
||||
|
||||
min_conn = values.get("IRIS_MIN_CONNECTION", 1)
|
||||
max_conn = values.get("IRIS_MAX_CONNECTION", 3)
|
||||
if min_conn > max_conn:
|
||||
raise ValueError("IRIS_MIN_CONNECTION must be less than or equal to IRIS_MAX_CONNECTION")
|
||||
|
||||
return values
|
||||
|
|
@ -20,6 +20,7 @@ language_timezone_mapping = {
|
|||
"sl-SI": "Europe/Ljubljana",
|
||||
"th-TH": "Asia/Bangkok",
|
||||
"id-ID": "Asia/Jakarta",
|
||||
"ar-TN": "Africa/Tunis",
|
||||
}
|
||||
|
||||
languages = list(language_timezone_mapping.keys())
|
||||
|
|
|
|||
|
|
@ -6,19 +6,20 @@ from flask import request
|
|||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
from configs import dify_config
|
||||
from constants.languages import supported_language
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import only_edition_cloud
|
||||
from core.db.session_factory import session_factory
|
||||
from extensions.ext_database import db
|
||||
from libs.token import extract_access_token
|
||||
from models.model import App, InstalledApp, RecommendedApp
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
|
|
@ -90,7 +91,7 @@ class InsertExploreAppListApi(Resource):
|
|||
privacy_policy = site.privacy_policy or payload.privacy_policy or ""
|
||||
custom_disclaimer = site.custom_disclaimer or payload.custom_disclaimer or ""
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with session_factory.create_session() as session:
|
||||
recommended_app = session.execute(
|
||||
select(RecommendedApp).where(RecommendedApp.app_id == payload.app_id)
|
||||
).scalar_one_or_none()
|
||||
|
|
@ -138,7 +139,7 @@ class InsertExploreAppApi(Resource):
|
|||
@only_edition_cloud
|
||||
@admin_required
|
||||
def delete(self, app_id):
|
||||
with Session(db.engine) as session:
|
||||
with session_factory.create_session() as session:
|
||||
recommended_app = session.execute(
|
||||
select(RecommendedApp).where(RecommendedApp.app_id == str(app_id))
|
||||
).scalar_one_or_none()
|
||||
|
|
@ -146,13 +147,13 @@ class InsertExploreAppApi(Resource):
|
|||
if not recommended_app:
|
||||
return {"result": "success"}, 204
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with session_factory.create_session() as session:
|
||||
app = session.execute(select(App).where(App.id == recommended_app.app_id)).scalar_one_or_none()
|
||||
|
||||
if app:
|
||||
app.is_public = False
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with session_factory.create_session() as session:
|
||||
installed_apps = (
|
||||
session.execute(
|
||||
select(InstalledApp).where(
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from typing import Any, Literal
|
||||
|
||||
from flask import abort, request
|
||||
from flask import abort, make_response, request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
|
@ -272,7 +272,6 @@ class AnnotationExportApi(Resource):
|
|||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, app_id):
|
||||
from flask import make_response
|
||||
|
||||
app_id = str(app_id)
|
||||
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
|
||||
|
|
@ -340,9 +339,9 @@ class AnnotationBatchImportApi(Resource):
|
|||
@edit_permission_required
|
||||
def post(self, app_id):
|
||||
from configs import dify_config
|
||||
|
||||
|
||||
app_id = str(app_id)
|
||||
|
||||
|
||||
# check file
|
||||
if "file" not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
|
|
@ -352,27 +351,26 @@ class AnnotationBatchImportApi(Resource):
|
|||
|
||||
# get file from request
|
||||
file = request.files["file"]
|
||||
|
||||
|
||||
# check file type
|
||||
if not file.filename or not file.filename.lower().endswith(".csv"):
|
||||
raise ValueError("Invalid file type. Only CSV files are allowed")
|
||||
|
||||
|
||||
# Check file size before processing
|
||||
file.seek(0, 2) # Seek to end of file
|
||||
file_size = file.tell()
|
||||
file.seek(0) # Reset to beginning
|
||||
|
||||
max_size_bytes = dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT * 1024 * 1024
|
||||
if file_size > max_size_bytes:
|
||||
abort(
|
||||
413,
|
||||
f"File size exceeds maximum limit of {dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT}MB. "
|
||||
f"Please reduce the file size and try again."
|
||||
f"Please reduce the file size and try again.",
|
||||
)
|
||||
|
||||
|
||||
if file_size == 0:
|
||||
raise ValueError("The uploaded file is empty")
|
||||
|
||||
|
||||
return AppAnnotationService.batch_import_app_annotations(app_id, file)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -114,7 +114,7 @@ class AppTriggersApi(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/trigger-enable")
|
||||
class AppTriggerEnableApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ParserEnable.__name__], validate=True)
|
||||
@console_ns.expect(console_ns.models[ParserEnable.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
|
|||
|
|
@ -7,9 +7,9 @@ from controllers.console import console_ns
|
|||
from controllers.console.error import AlreadyActivateError
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import EmailStr, extract_remote_ip, timezone
|
||||
from libs.helper import EmailStr, timezone
|
||||
from models import AccountStatus
|
||||
from services.account_service import AccountService, RegisterService
|
||||
from services.account_service import RegisterService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
|
@ -93,7 +93,6 @@ class ActivateApi(Resource):
|
|||
"ActivationResponse",
|
||||
{
|
||||
"result": fields.String(description="Operation result"),
|
||||
"data": fields.Raw(description="Login token data"),
|
||||
},
|
||||
),
|
||||
)
|
||||
|
|
@ -117,6 +116,4 @@ class ActivateApi(Resource):
|
|||
account.initialized_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
|
||||
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
|
||||
|
||||
return {"result": "success", "data": token_pair.model_dump()}
|
||||
return {"result": "success"}
|
||||
|
|
|
|||
|
|
@ -22,7 +22,12 @@ from controllers.console.error import (
|
|||
NotAllowedCreateWorkspace,
|
||||
WorkspacesLimitExceeded,
|
||||
)
|
||||
from controllers.console.wraps import email_password_login_enabled, setup_required
|
||||
from controllers.console.wraps import (
|
||||
decrypt_code_field,
|
||||
decrypt_password_field,
|
||||
email_password_login_enabled,
|
||||
setup_required,
|
||||
)
|
||||
from events.tenant_event import tenant_was_created
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.login import current_account_with_tenant
|
||||
|
|
@ -79,6 +84,7 @@ class LoginApi(Resource):
|
|||
@setup_required
|
||||
@email_password_login_enabled
|
||||
@console_ns.expect(console_ns.models[LoginPayload.__name__])
|
||||
@decrypt_password_field
|
||||
def post(self):
|
||||
"""Authenticate user and login."""
|
||||
args = LoginPayload.model_validate(console_ns.payload)
|
||||
|
|
@ -218,6 +224,7 @@ class EmailCodeLoginSendEmailApi(Resource):
|
|||
class EmailCodeLoginApi(Resource):
|
||||
@setup_required
|
||||
@console_ns.expect(console_ns.models[EmailCodeLoginPayload.__name__])
|
||||
@decrypt_code_field
|
||||
def post(self):
|
||||
args = EmailCodeLoginPayload.model_validate(console_ns.payload)
|
||||
|
||||
|
|
|
|||
|
|
@ -140,6 +140,18 @@ class DataSourceNotionListApi(Resource):
|
|||
credential_id = request.args.get("credential_id", default=None, type=str)
|
||||
if not credential_id:
|
||||
raise ValueError("Credential id is required.")
|
||||
|
||||
# Get datasource_parameters from query string (optional, for GitHub and other datasources)
|
||||
datasource_parameters_str = request.args.get("datasource_parameters", default=None, type=str)
|
||||
datasource_parameters = {}
|
||||
if datasource_parameters_str:
|
||||
try:
|
||||
datasource_parameters = json.loads(datasource_parameters_str)
|
||||
if not isinstance(datasource_parameters, dict):
|
||||
raise ValueError("datasource_parameters must be a JSON object.")
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("Invalid datasource_parameters JSON format.")
|
||||
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credential = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=current_tenant_id,
|
||||
|
|
@ -187,7 +199,7 @@ class DataSourceNotionListApi(Resource):
|
|||
online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = (
|
||||
datasource_runtime.get_online_document_pages(
|
||||
user_id=current_user.id,
|
||||
datasource_parameters={},
|
||||
datasource_parameters=datasource_parameters,
|
||||
provider_type=datasource_runtime.datasource_provider_type(),
|
||||
)
|
||||
)
|
||||
|
|
@ -218,14 +230,14 @@ class DataSourceNotionListApi(Resource):
|
|||
|
||||
|
||||
@console_ns.route(
|
||||
"/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview",
|
||||
"/notion/pages/<uuid:page_id>/<string:page_type>/preview",
|
||||
"/datasets/notion-indexing-estimate",
|
||||
)
|
||||
class DataSourceNotionApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, workspace_id, page_id, page_type):
|
||||
def get(self, page_id, page_type):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
credential_id = request.args.get("credential_id", default=None, type=str)
|
||||
|
|
@ -239,11 +251,10 @@ class DataSourceNotionApi(Resource):
|
|||
plugin_id="langgenius/notion_datasource",
|
||||
)
|
||||
|
||||
workspace_id = str(workspace_id)
|
||||
page_id = str(page_id)
|
||||
|
||||
extractor = NotionExtractor(
|
||||
notion_workspace_id=workspace_id,
|
||||
notion_workspace_id="",
|
||||
notion_obj_id=page_id,
|
||||
notion_page_type=page_type,
|
||||
notion_access_token=credential.get("integration_secret"),
|
||||
|
|
|
|||
|
|
@ -146,7 +146,7 @@ class DatasetUpdatePayload(BaseModel):
|
|||
embedding_model: str | None = None
|
||||
embedding_model_provider: str | None = None
|
||||
retrieval_model: dict[str, Any] | None = None
|
||||
partial_member_list: list[str] | None = None
|
||||
partial_member_list: list[dict[str, str]] | None = None
|
||||
external_retrieval_model: dict[str, Any] | None = None
|
||||
external_knowledge_id: str | None = None
|
||||
external_knowledge_api_id: str | None = None
|
||||
|
|
@ -223,6 +223,7 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool
|
|||
VectorType.COUCHBASE,
|
||||
VectorType.OPENGAUSS,
|
||||
VectorType.OCEANBASE,
|
||||
VectorType.SEEKDB,
|
||||
VectorType.TABLESTORE,
|
||||
VectorType.HUAWEI_CLOUD,
|
||||
VectorType.TENCENT,
|
||||
|
|
@ -230,6 +231,7 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool
|
|||
VectorType.CLICKZETTA,
|
||||
VectorType.BAIDU,
|
||||
VectorType.ALIBABACLOUD_MYSQL,
|
||||
VectorType.IRIS,
|
||||
}
|
||||
|
||||
semantic_methods = {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=D
|
|||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview")
|
||||
class DataSourceContentPreviewApi(Resource):
|
||||
@console_ns.expect(console_ns.models[Parser.__name__], validate=True)
|
||||
@console_ns.expect(console_ns.models[Parser.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from typing import Any, Literal, cast
|
|||
from uuid import UUID
|
||||
|
||||
from flask import abort, request
|
||||
from flask_restx import Resource, marshal_with # type: ignore
|
||||
from flask_restx import Resource, marshal_with, reqparse # type: ignore
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
|
@ -968,7 +968,7 @@ class RagPipelineDatasourceVariableApi(Resource):
|
|||
)
|
||||
return workflow_node_execution
|
||||
|
||||
from flask_restx import reqparse
|
||||
|
||||
@console_ns.route("/rag/pipelines/recommended-plugins")
|
||||
class RagPipelineRecommendedPluginApi(Resource):
|
||||
@setup_required
|
||||
|
|
@ -976,7 +976,7 @@ class RagPipelineRecommendedPluginApi(Resource):
|
|||
@account_initialization_required
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('type', type=str, location='args', required=False, default='all')
|
||||
parser.add_argument("type", type=str, location="args", required=False, default="all")
|
||||
args = parser.parse_args()
|
||||
type = args["type"]
|
||||
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ from .. import console_ns
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CompletionMessagePayload(BaseModel):
|
||||
class CompletionMessageExplorePayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
query: str = ""
|
||||
files: list[dict[str, Any]] | None = None
|
||||
|
|
@ -71,7 +71,7 @@ class ChatMessagePayload(BaseModel):
|
|||
raise ValueError("must be a valid UUID") from exc
|
||||
|
||||
|
||||
register_schema_models(console_ns, CompletionMessagePayload, ChatMessagePayload)
|
||||
register_schema_models(console_ns, CompletionMessageExplorePayload, ChatMessagePayload)
|
||||
|
||||
|
||||
# define completion api for user
|
||||
|
|
@ -80,13 +80,13 @@ register_schema_models(console_ns, CompletionMessagePayload, ChatMessagePayload)
|
|||
endpoint="installed_app_completion",
|
||||
)
|
||||
class CompletionApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[CompletionMessagePayload.__name__])
|
||||
@console_ns.expect(console_ns.models[CompletionMessageExplorePayload.__name__])
|
||||
def post(self, installed_app):
|
||||
app_model = installed_app.app
|
||||
if app_model.mode != AppMode.COMPLETION:
|
||||
raise NotCompletionAppError()
|
||||
|
||||
payload = CompletionMessagePayload.model_validate(console_ns.payload or {})
|
||||
payload = CompletionMessageExplorePayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
streaming = payload.response_mode == "streaming"
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import marshal_with
|
||||
|
|
@ -13,6 +12,7 @@ from controllers.console.explore.wraps import InstalledAppResource
|
|||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
from models.model import AppMode
|
||||
|
|
@ -24,7 +24,7 @@ from .. import console_ns
|
|||
|
||||
|
||||
class ConversationListQuery(BaseModel):
|
||||
last_id: UUID | None = None
|
||||
last_id: UUIDStrOrEmpty | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
pinned: bool | None = None
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,8 @@ import logging
|
|||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, inputs, marshal_with, reqparse
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import and_, select
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
|
||||
|
||||
|
|
@ -18,6 +19,15 @@ from services.account_service import TenantService
|
|||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
class InstalledAppCreatePayload(BaseModel):
|
||||
app_id: str
|
||||
|
||||
|
||||
class InstalledAppUpdatePayload(BaseModel):
|
||||
is_pinned: bool | None = None
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@ -105,26 +115,25 @@ class InstalledAppsListApi(Resource):
|
|||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("apps")
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser().add_argument("app_id", type=str, required=True, help="Invalid app_id")
|
||||
args = parser.parse_args()
|
||||
payload = InstalledAppCreatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == args["app_id"]).first()
|
||||
recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == payload.app_id).first()
|
||||
if recommended_app is None:
|
||||
raise NotFound("App not found")
|
||||
raise NotFound("Recommended app not found")
|
||||
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
app = db.session.query(App).where(App.id == args["app_id"]).first()
|
||||
app = db.session.query(App).where(App.id == payload.app_id).first()
|
||||
|
||||
if app is None:
|
||||
raise NotFound("App not found")
|
||||
raise NotFound("App entity not found")
|
||||
|
||||
if not app.is_public:
|
||||
raise Forbidden("You can't install a non-public app")
|
||||
|
||||
installed_app = (
|
||||
db.session.query(InstalledApp)
|
||||
.where(and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id))
|
||||
.where(and_(InstalledApp.app_id == payload.app_id, InstalledApp.tenant_id == current_tenant_id))
|
||||
.first()
|
||||
)
|
||||
|
||||
|
|
@ -133,7 +142,7 @@ class InstalledAppsListApi(Resource):
|
|||
recommended_app.install_count += 1
|
||||
|
||||
new_installed_app = InstalledApp(
|
||||
app_id=args["app_id"],
|
||||
app_id=payload.app_id,
|
||||
tenant_id=current_tenant_id,
|
||||
app_owner_tenant_id=app.tenant_id,
|
||||
is_pinned=False,
|
||||
|
|
@ -163,12 +172,11 @@ class InstalledAppApi(InstalledAppResource):
|
|||
return {"result": "success", "message": "App uninstalled successfully"}, 204
|
||||
|
||||
def patch(self, installed_app):
|
||||
parser = reqparse.RequestParser().add_argument("is_pinned", type=inputs.boolean)
|
||||
args = parser.parse_args()
|
||||
payload = InstalledAppUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
commit_args = False
|
||||
if "is_pinned" in args:
|
||||
installed_app.is_pinned = args["is_pinned"]
|
||||
if payload.is_pinned is not None:
|
||||
installed_app.is_pinned = payload.is_pinned
|
||||
commit_args = True
|
||||
|
||||
if commit_args:
|
||||
|
|
|
|||
|
|
@ -1,31 +1,40 @@
|
|||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from fields.tag_fields import dataset_tag_fields
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import Tag
|
||||
from services.tag_service import TagService
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
if not name or len(name) < 1 or len(name) > 50:
|
||||
raise ValueError("Name must be between 1 to 50 characters.")
|
||||
return name
|
||||
class TagBasePayload(BaseModel):
|
||||
name: str = Field(description="Tag name", min_length=1, max_length=50)
|
||||
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
|
||||
|
||||
|
||||
parser_tags = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name must be between 1 to 50 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
.add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.")
|
||||
class TagBindingPayload(BaseModel):
|
||||
tag_ids: list[str] = Field(description="Tag IDs to bind")
|
||||
target_id: str = Field(description="Target ID to bind tags to")
|
||||
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
|
||||
|
||||
|
||||
class TagBindingRemovePayload(BaseModel):
|
||||
tag_id: str = Field(description="Tag ID to remove")
|
||||
target_id: str = Field(description="Target ID to unbind tag from")
|
||||
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
TagBasePayload,
|
||||
TagBindingPayload,
|
||||
TagBindingRemovePayload,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -43,7 +52,7 @@ class TagListApi(Resource):
|
|||
|
||||
return tags, 200
|
||||
|
||||
@console_ns.expect(parser_tags)
|
||||
@console_ns.expect(console_ns.models[TagBasePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -53,22 +62,17 @@ class TagListApi(Resource):
|
|||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
args = parser_tags.parse_args()
|
||||
tag = TagService.save_tags(args)
|
||||
payload = TagBasePayload.model_validate(console_ns.payload or {})
|
||||
tag = TagService.save_tags(payload.model_dump())
|
||||
|
||||
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
|
||||
|
||||
return response, 200
|
||||
|
||||
|
||||
parser_tag_id = reqparse.RequestParser().add_argument(
|
||||
"name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/tags/<uuid:tag_id>")
|
||||
class TagUpdateDeleteApi(Resource):
|
||||
@console_ns.expect(parser_tag_id)
|
||||
@console_ns.expect(console_ns.models[TagBasePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -79,8 +83,8 @@ class TagUpdateDeleteApi(Resource):
|
|||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
args = parser_tag_id.parse_args()
|
||||
tag = TagService.update_tags(args, tag_id)
|
||||
payload = TagBasePayload.model_validate(console_ns.payload or {})
|
||||
tag = TagService.update_tags(payload.model_dump(), tag_id)
|
||||
|
||||
binding_count = TagService.get_tag_binding_count(tag_id)
|
||||
|
||||
|
|
@ -100,17 +104,9 @@ class TagUpdateDeleteApi(Resource):
|
|||
return 204
|
||||
|
||||
|
||||
parser_create = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required.")
|
||||
.add_argument("target_id", type=str, nullable=False, required=True, location="json", help="Target ID is required.")
|
||||
.add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/tag-bindings/create")
|
||||
class TagBindingCreateApi(Resource):
|
||||
@console_ns.expect(parser_create)
|
||||
@console_ns.expect(console_ns.models[TagBindingPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -120,23 +116,15 @@ class TagBindingCreateApi(Resource):
|
|||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
args = parser_create.parse_args()
|
||||
TagService.save_tag_binding(args)
|
||||
payload = TagBindingPayload.model_validate(console_ns.payload or {})
|
||||
TagService.save_tag_binding(payload.model_dump())
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
parser_remove = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
|
||||
.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
|
||||
.add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/tag-bindings/remove")
|
||||
class TagBindingDeleteApi(Resource):
|
||||
@console_ns.expect(parser_remove)
|
||||
@console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -146,7 +134,7 @@ class TagBindingDeleteApi(Resource):
|
|||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
args = parser_remove.parse_args()
|
||||
TagService.delete_tag_binding(args)
|
||||
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
|
||||
TagService.delete_tag_binding(payload.model_dump())
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
|
|
|||
|
|
@ -46,8 +46,8 @@ class PluginDebuggingKeyApi(Resource):
|
|||
|
||||
|
||||
class ParserList(BaseModel):
|
||||
page: int = Field(default=1)
|
||||
page_size: int = Field(default=256)
|
||||
page: int = Field(default=1, ge=1, description="Page number")
|
||||
page_size: int = Field(default=256, ge=1, le=256, description="Page size (1-256)")
|
||||
|
||||
|
||||
reg(ParserList)
|
||||
|
|
@ -106,8 +106,8 @@ class ParserPluginIdentifierQuery(BaseModel):
|
|||
|
||||
|
||||
class ParserTasks(BaseModel):
|
||||
page: int
|
||||
page_size: int
|
||||
page: int = Field(default=1, ge=1, description="Page number")
|
||||
page_size: int = Field(default=256, ge=1, le=256, description="Page size (1-256)")
|
||||
|
||||
|
||||
class ParserMarketplaceUpgrade(BaseModel):
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from controllers.console.wraps import (
|
|||
setup_required,
|
||||
)
|
||||
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
|
||||
from core.helper.tool_provider_cache import ToolProviderListCache
|
||||
from core.mcp.auth.auth_flow import auth, handle_callback
|
||||
from core.mcp.error import MCPAuthError, MCPError, MCPRefreshTokenError
|
||||
from core.mcp.mcp_client import MCPClient
|
||||
|
|
@ -944,7 +945,7 @@ class ToolProviderMCPApi(Resource):
|
|||
configuration = MCPConfiguration.model_validate(args["configuration"])
|
||||
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
|
||||
|
||||
# Create provider
|
||||
# Create provider in transaction
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
result = service.create_provider(
|
||||
|
|
@ -960,7 +961,11 @@ class ToolProviderMCPApi(Resource):
|
|||
configuration=configuration,
|
||||
authentication=authentication,
|
||||
)
|
||||
return jsonable_encoder(result)
|
||||
|
||||
# Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
|
||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||
|
||||
return jsonable_encoder(result)
|
||||
|
||||
@console_ns.expect(parser_mcp_put)
|
||||
@setup_required
|
||||
|
|
@ -972,17 +977,23 @@ class ToolProviderMCPApi(Resource):
|
|||
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# Step 1: Validate server URL change if needed (includes URL format validation and network operation)
|
||||
validation_result = None
|
||||
# Step 1: Get provider data for URL validation (short-lived session, no network I/O)
|
||||
validation_data = None
|
||||
with Session(db.engine) as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
validation_result = service.validate_server_url_change(
|
||||
tenant_id=current_tenant_id, provider_id=args["provider_id"], new_server_url=args["server_url"]
|
||||
validation_data = service.get_provider_for_url_validation(
|
||||
tenant_id=current_tenant_id, provider_id=args["provider_id"]
|
||||
)
|
||||
|
||||
# No need to check for errors here, exceptions will be raised directly
|
||||
# Step 2: Perform URL validation with network I/O OUTSIDE of any database session
|
||||
# This prevents holding database locks during potentially slow network operations
|
||||
validation_result = MCPToolManageService.validate_server_url_standalone(
|
||||
tenant_id=current_tenant_id,
|
||||
new_server_url=args["server_url"],
|
||||
validation_data=validation_data,
|
||||
)
|
||||
|
||||
# Step 2: Perform database update in a transaction
|
||||
# Step 3: Perform database update in a transaction
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
service.update_provider(
|
||||
|
|
@ -999,7 +1010,11 @@ class ToolProviderMCPApi(Resource):
|
|||
authentication=authentication,
|
||||
validation_result=validation_result,
|
||||
)
|
||||
return {"result": "success"}
|
||||
|
||||
# Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
|
||||
ToolProviderListCache.invalidate_cache(current_tenant_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@console_ns.expect(parser_mcp_delete)
|
||||
@setup_required
|
||||
|
|
@ -1012,7 +1027,11 @@ class ToolProviderMCPApi(Resource):
|
|||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
service.delete_provider(tenant_id=current_tenant_id, provider_id=args["provider_id"])
|
||||
return {"result": "success"}
|
||||
|
||||
# Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
|
||||
ToolProviderListCache.invalidate_cache(current_tenant_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
parser_auth = (
|
||||
|
|
@ -1062,6 +1081,8 @@ class ToolMCPAuthApi(Resource):
|
|||
credentials=provider_entity.credentials,
|
||||
authed=True,
|
||||
)
|
||||
# Invalidate cache after updating credentials
|
||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||
return {"result": "success"}
|
||||
except MCPAuthError as e:
|
||||
try:
|
||||
|
|
@ -1075,16 +1096,22 @@ class ToolMCPAuthApi(Resource):
|
|||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
response = service.execute_auth_actions(auth_result)
|
||||
# Invalidate cache after auth actions may have updated provider state
|
||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||
return response
|
||||
except MCPRefreshTokenError as e:
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
|
||||
# Invalidate cache after clearing credentials
|
||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||
raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
|
||||
except (MCPError, ValueError) as e:
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
|
||||
# Invalidate cache after clearing credentials
|
||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||
raise ValueError(f"Failed to connect to MCP server: {e}") from e
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -9,10 +9,12 @@ from typing import ParamSpec, TypeVar
|
|||
from flask import abort, request
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console.auth.error import AuthenticationFailedError, EmailCodeError
|
||||
from controllers.console.workspace.error import AccountNotInitializedError
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.encryption import FieldEncryption
|
||||
from libs.login import current_account_with_tenant
|
||||
from models.account import AccountStatus
|
||||
from models.dataset import RateLimitLog
|
||||
|
|
@ -25,6 +27,14 @@ from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogo
|
|||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
# Field names for decryption
|
||||
FIELD_NAME_PASSWORD = "password"
|
||||
FIELD_NAME_CODE = "code"
|
||||
|
||||
# Error messages for decryption failures
|
||||
ERROR_MSG_INVALID_ENCRYPTED_DATA = "Invalid encrypted data"
|
||||
ERROR_MSG_INVALID_ENCRYPTED_CODE = "Invalid encrypted code"
|
||||
|
||||
|
||||
def account_initialization_required(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
|
|
@ -336,84 +346,155 @@ def is_admin_or_owner_required(f: Callable[P, R]):
|
|||
def annotation_import_rate_limit(view: Callable[P, R]):
|
||||
"""
|
||||
Rate limiting decorator for annotation import operations.
|
||||
|
||||
|
||||
Implements sliding window rate limiting with two tiers:
|
||||
- Short-term: Configurable requests per minute (default: 5)
|
||||
- Long-term: Configurable requests per hour (default: 20)
|
||||
|
||||
|
||||
Uses Redis ZSET for distributed rate limiting across multiple instances.
|
||||
"""
|
||||
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
current_time = int(time.time() * 1000)
|
||||
|
||||
# Check per-minute rate limit
|
||||
minute_key = f"annotation_import_rate_limit:{current_tenant_id}:1min"
|
||||
redis_client.zadd(minute_key, {current_time: current_time})
|
||||
redis_client.zremrangebyscore(minute_key, 0, current_time - 60000)
|
||||
minute_count = redis_client.zcard(minute_key)
|
||||
redis_client.expire(minute_key, 120) # 2 minutes TTL
|
||||
|
||||
if minute_count > dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE:
|
||||
abort(
|
||||
429,
|
||||
f"Too many annotation import requests. Maximum {dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE} "
|
||||
f"requests per minute allowed. Please try again later."
|
||||
f"requests per minute allowed. Please try again later.",
|
||||
)
|
||||
|
||||
|
||||
# Check per-hour rate limit
|
||||
hour_key = f"annotation_import_rate_limit:{current_tenant_id}:1hour"
|
||||
redis_client.zadd(hour_key, {current_time: current_time})
|
||||
redis_client.zremrangebyscore(hour_key, 0, current_time - 3600000)
|
||||
hour_count = redis_client.zcard(hour_key)
|
||||
redis_client.expire(hour_key, 7200) # 2 hours TTL
|
||||
|
||||
if hour_count > dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR:
|
||||
abort(
|
||||
429,
|
||||
f"Too many annotation import requests. Maximum {dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR} "
|
||||
f"requests per hour allowed. Please try again later."
|
||||
f"requests per hour allowed. Please try again later.",
|
||||
)
|
||||
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def annotation_import_concurrency_limit(view: Callable[P, R]):
|
||||
"""
|
||||
Concurrency control decorator for annotation import operations.
|
||||
|
||||
|
||||
Limits the number of concurrent import tasks per tenant to prevent
|
||||
resource exhaustion and ensure fair resource allocation.
|
||||
|
||||
|
||||
Uses Redis ZSET to track active import jobs with automatic cleanup
|
||||
of stale entries (jobs older than 2 minutes).
|
||||
"""
|
||||
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
current_time = int(time.time() * 1000)
|
||||
|
||||
|
||||
active_jobs_key = f"annotation_import_active:{current_tenant_id}"
|
||||
|
||||
|
||||
# Clean up stale entries (jobs that should have completed or timed out)
|
||||
stale_threshold = current_time - 120000 # 2 minutes ago
|
||||
redis_client.zremrangebyscore(active_jobs_key, 0, stale_threshold)
|
||||
|
||||
|
||||
# Check current active job count
|
||||
active_count = redis_client.zcard(active_jobs_key)
|
||||
|
||||
|
||||
if active_count >= dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT:
|
||||
abort(
|
||||
429,
|
||||
f"Too many concurrent import tasks. Maximum {dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT} "
|
||||
f"concurrent imports allowed per workspace. Please wait for existing imports to complete."
|
||||
f"concurrent imports allowed per workspace. Please wait for existing imports to complete.",
|
||||
)
|
||||
|
||||
|
||||
# Allow the request to proceed
|
||||
# The actual job registration will happen in the service layer
|
||||
return view(*args, **kwargs)
|
||||
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def _decrypt_field(field_name: str, error_class: type[Exception], error_message: str) -> None:
|
||||
"""
|
||||
Helper to decode a Base64 encoded field in the request payload.
|
||||
|
||||
Args:
|
||||
field_name: Name of the field to decode
|
||||
error_class: Exception class to raise on decoding failure
|
||||
error_message: Error message to include in the exception
|
||||
"""
|
||||
if not request or not request.is_json:
|
||||
return
|
||||
# Get the payload dict - it's cached and mutable
|
||||
payload = request.get_json()
|
||||
if not payload or field_name not in payload:
|
||||
return
|
||||
encoded_value = payload[field_name]
|
||||
decoded_value = FieldEncryption.decrypt_field(encoded_value)
|
||||
|
||||
# If decoding failed, raise error immediately
|
||||
if decoded_value is None:
|
||||
raise error_class(error_message)
|
||||
|
||||
# Update payload dict in-place with decoded value
|
||||
# Since payload is a mutable dict and get_json() returns the cached reference,
|
||||
# modifying it will affect all subsequent accesses including console_ns.payload
|
||||
payload[field_name] = decoded_value
|
||||
|
||||
|
||||
def decrypt_password_field(view: Callable[P, R]):
|
||||
"""
|
||||
Decorator to decrypt password field in request payload.
|
||||
|
||||
Automatically decrypts the 'password' field if encryption is enabled.
|
||||
If decryption fails, raises AuthenticationFailedError.
|
||||
|
||||
Usage:
|
||||
@decrypt_password_field
|
||||
def post(self):
|
||||
args = LoginPayload.model_validate(console_ns.payload)
|
||||
# args.password is now decrypted
|
||||
"""
|
||||
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
_decrypt_field(FIELD_NAME_PASSWORD, AuthenticationFailedError, ERROR_MSG_INVALID_ENCRYPTED_DATA)
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def decrypt_code_field(view: Callable[P, R]):
|
||||
"""
|
||||
Decorator to decrypt verification code field in request payload.
|
||||
|
||||
Automatically decrypts the 'code' field if encryption is enabled.
|
||||
If decryption fails, raises EmailCodeError.
|
||||
|
||||
Usage:
|
||||
@decrypt_code_field
|
||||
def post(self):
|
||||
args = EmailCodeLoginPayload.model_validate(console_ns.payload)
|
||||
# args.code is now decrypted
|
||||
"""
|
||||
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
_decrypt_field(FIELD_NAME_CODE, EmailCodeError, ERROR_MSG_INVALID_ENCRYPTED_CODE)
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
|
|
|||
|
|
@ -61,6 +61,9 @@ class ChatRequestPayload(BaseModel):
|
|||
@classmethod
|
||||
def normalize_conversation_id(cls, value: str | UUID | None) -> str | None:
|
||||
"""Allow missing or blank conversation IDs; enforce UUID format when provided."""
|
||||
if isinstance(value, str):
|
||||
value = value.strip()
|
||||
|
||||
if not value:
|
||||
return None
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from uuid import UUID
|
|||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from flask_restx._http import HTTPStatus
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
|
|
@ -51,6 +51,32 @@ class ConversationRenamePayload(BaseModel):
|
|||
class ConversationVariablesQuery(BaseModel):
|
||||
last_id: UUID | None = Field(default=None, description="Last variable ID for pagination")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return")
|
||||
variable_name: str | None = Field(
|
||||
default=None, description="Filter variables by name", min_length=1, max_length=255
|
||||
)
|
||||
|
||||
@field_validator("variable_name", mode="before")
|
||||
@classmethod
|
||||
def validate_variable_name(cls, v: str | None) -> str | None:
|
||||
"""
|
||||
Validate variable_name to prevent injection attacks.
|
||||
"""
|
||||
if v is None:
|
||||
return v
|
||||
|
||||
# Only allow safe characters: alphanumeric, underscore, hyphen, period
|
||||
if not v.replace("-", "").replace("_", "").replace(".", "").isalnum():
|
||||
raise ValueError(
|
||||
"Variable name can only contain letters, numbers, hyphens (-), underscores (_), and periods (.)"
|
||||
)
|
||||
|
||||
# Prevent SQL injection patterns
|
||||
dangerous_patterns = ["'", '"', ";", "--", "/*", "*/", "xp_", "sp_"]
|
||||
for pattern in dangerous_patterns:
|
||||
if pattern in v.lower():
|
||||
raise ValueError(f"Variable name contains invalid characters: {pattern}")
|
||||
|
||||
return v
|
||||
|
||||
|
||||
class ConversationVariableUpdatePayload(BaseModel):
|
||||
|
|
@ -199,7 +225,7 @@ class ConversationVariablesApi(Resource):
|
|||
|
||||
try:
|
||||
return ConversationService.get_conversational_variable(
|
||||
app_model, conversation_id, end_user, query_args.limit, last_id
|
||||
app_model, conversation_id, end_user, query_args.limit, last_id, query_args.variable_name
|
||||
)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ class DatasetUpdatePayload(BaseModel):
|
|||
embedding_model: str | None = None
|
||||
embedding_model_provider: str | None = None
|
||||
retrieval_model: RetrievalModel | None = None
|
||||
partial_member_list: list[str] | None = None
|
||||
partial_member_list: list[dict[str, str]] | None = None
|
||||
external_retrieval_model: dict[str, Any] | None = None
|
||||
external_knowledge_id: str | None = None
|
||||
external_knowledge_api_id: str | None = None
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ def trigger_endpoint(endpoint_id: str):
|
|||
if response:
|
||||
break
|
||||
if not response:
|
||||
logger.error("Endpoint not found for {endpoint_id}")
|
||||
logger.info("Endpoint not found for %s", endpoint_id)
|
||||
return jsonify({"error": "Endpoint not found"}), 404
|
||||
return response
|
||||
except ValueError as e:
|
||||
|
|
|
|||
|
|
@ -1,14 +1,13 @@
|
|||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from constants import HEADER_NAME_APP_CODE
|
||||
from controllers.common import fields
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import AppUnavailableError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from controllers.common.schema import register_schema_models
|
||||
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||
from libs.passport import PassportService
|
||||
from libs.token import extract_webapp_passport
|
||||
|
|
@ -18,9 +17,23 @@ from services.enterprise.enterprise_service import EnterpriseService
|
|||
from services.feature_service import FeatureService
|
||||
from services.webapp_auth_service import WebAppAuthService
|
||||
|
||||
from . import web_ns
|
||||
from .error import AppUnavailableError
|
||||
from .wraps import WebApiResource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AppAccessModeQuery(BaseModel):
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
app_id: str | None = Field(default=None, alias="appId", description="Application ID")
|
||||
app_code: str | None = Field(default=None, alias="appCode", description="Application code")
|
||||
|
||||
|
||||
register_schema_models(web_ns, AppAccessModeQuery)
|
||||
|
||||
|
||||
@web_ns.route("/parameters")
|
||||
class AppParameterApi(WebApiResource):
|
||||
"""Resource for app variables."""
|
||||
|
|
@ -96,21 +109,16 @@ class AppAccessMode(Resource):
|
|||
}
|
||||
)
|
||||
def get(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("appId", type=str, required=False, location="args")
|
||||
.add_argument("appCode", type=str, required=False, location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
raw_args = request.args.to_dict()
|
||||
args = AppAccessModeQuery.model_validate(raw_args)
|
||||
|
||||
features = FeatureService.get_system_features()
|
||||
if not features.webapp_auth.enabled:
|
||||
return {"accessMode": "public"}
|
||||
|
||||
app_id = args.get("appId")
|
||||
if args.get("appCode"):
|
||||
app_code = args["appCode"]
|
||||
app_id = AppService.get_app_id_by_code(app_code)
|
||||
app_id = args.app_id
|
||||
if args.app_code:
|
||||
app_id = AppService.get_app_id_by_code(args.app_code)
|
||||
|
||||
if not app_id:
|
||||
raise ValueError("appId or appCode must be provided")
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_restx import fields, marshal_with, reqparse
|
||||
from flask_restx import fields, marshal_with
|
||||
from pydantic import BaseModel, field_validator
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
|
|
@ -20,6 +21,7 @@ from controllers.web.error import (
|
|||
from controllers.web.wraps import WebApiResource
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from libs.helper import uuid_value
|
||||
from models.model import App
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import (
|
||||
|
|
@ -29,6 +31,25 @@ from services.errors.audio import (
|
|||
UnsupportedAudioTypeServiceError,
|
||||
)
|
||||
|
||||
from ..common.schema import register_schema_models
|
||||
|
||||
|
||||
class TextToAudioPayload(BaseModel):
|
||||
message_id: str | None = None
|
||||
voice: str | None = None
|
||||
text: str | None = None
|
||||
streaming: bool | None = None
|
||||
|
||||
@field_validator("message_id")
|
||||
@classmethod
|
||||
def validate_message_id(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
register_schema_models(web_ns, TextToAudioPayload)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@ -88,6 +109,7 @@ class AudioApi(WebApiResource):
|
|||
|
||||
@web_ns.route("/text-to-audio")
|
||||
class TextApi(WebApiResource):
|
||||
@web_ns.expect(web_ns.models[TextToAudioPayload.__name__])
|
||||
@web_ns.doc("Text to Audio")
|
||||
@web_ns.doc(description="Convert text to audio using text-to-speech service.")
|
||||
@web_ns.doc(
|
||||
|
|
@ -102,18 +124,11 @@ class TextApi(WebApiResource):
|
|||
def post(self, app_model: App, end_user):
|
||||
"""Convert text to audio"""
|
||||
try:
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("message_id", type=str, required=False, location="json")
|
||||
.add_argument("voice", type=str, location="json")
|
||||
.add_argument("text", type=str, location="json")
|
||||
.add_argument("streaming", type=bool, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = TextToAudioPayload.model_validate(web_ns.payload or {})
|
||||
|
||||
message_id = args.get("message_id", None)
|
||||
text = args.get("text", None)
|
||||
voice = args.get("voice", None)
|
||||
message_id = payload.message_id
|
||||
text = payload.text
|
||||
voice = payload.voice
|
||||
response = AudioService.transcript_tts(
|
||||
app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
import logging
|
||||
from typing import Any, Literal
|
||||
|
||||
from flask_restx import reqparse
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import (
|
||||
AppUnavailableError,
|
||||
|
|
@ -34,25 +36,44 @@ from services.errors.llm import InvokeRateLimitError
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CompletionMessagePayload(BaseModel):
|
||||
inputs: dict[str, Any] = Field(description="Input variables for the completion")
|
||||
query: str = Field(default="", description="Query text for completion")
|
||||
files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed")
|
||||
response_mode: Literal["blocking", "streaming"] | None = Field(
|
||||
default=None, description="Response mode: blocking or streaming"
|
||||
)
|
||||
retriever_from: str = Field(default="web_app", description="Source of retriever")
|
||||
|
||||
|
||||
class ChatMessagePayload(BaseModel):
|
||||
inputs: dict[str, Any] = Field(description="Input variables for the chat")
|
||||
query: str = Field(description="User query/message")
|
||||
files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed")
|
||||
response_mode: Literal["blocking", "streaming"] | None = Field(
|
||||
default=None, description="Response mode: blocking or streaming"
|
||||
)
|
||||
conversation_id: str | None = Field(default=None, description="Conversation ID")
|
||||
parent_message_id: str | None = Field(default=None, description="Parent message ID")
|
||||
retriever_from: str = Field(default="web_app", description="Source of retriever")
|
||||
|
||||
@field_validator("conversation_id", "parent_message_id")
|
||||
@classmethod
|
||||
def validate_uuid(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
register_schema_models(web_ns, CompletionMessagePayload, ChatMessagePayload)
|
||||
|
||||
|
||||
# define completion api for user
|
||||
@web_ns.route("/completion-messages")
|
||||
class CompletionApi(WebApiResource):
|
||||
@web_ns.doc("Create Completion Message")
|
||||
@web_ns.doc(description="Create a completion message for text generation applications.")
|
||||
@web_ns.doc(
|
||||
params={
|
||||
"inputs": {"description": "Input variables for the completion", "type": "object", "required": True},
|
||||
"query": {"description": "Query text for completion", "type": "string", "required": False},
|
||||
"files": {"description": "Files to be processed", "type": "array", "required": False},
|
||||
"response_mode": {
|
||||
"description": "Response mode: blocking or streaming",
|
||||
"type": "string",
|
||||
"enum": ["blocking", "streaming"],
|
||||
"required": False,
|
||||
},
|
||||
"retriever_from": {"description": "Source of retriever", "type": "string", "required": False},
|
||||
}
|
||||
)
|
||||
@web_ns.expect(web_ns.models[CompletionMessagePayload.__name__])
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
|
|
@ -67,18 +88,10 @@ class CompletionApi(WebApiResource):
|
|||
if app_model.mode != AppMode.COMPLETION:
|
||||
raise NotCompletionAppError()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, location="json")
|
||||
.add_argument("query", type=str, location="json", default="")
|
||||
.add_argument("files", type=list, required=False, location="json")
|
||||
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
.add_argument("retriever_from", type=str, required=False, default="web_app", location="json")
|
||||
)
|
||||
payload = CompletionMessagePayload.model_validate(web_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args["response_mode"] == "streaming"
|
||||
streaming = payload.response_mode == "streaming"
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
try:
|
||||
|
|
@ -142,22 +155,7 @@ class CompletionStopApi(WebApiResource):
|
|||
class ChatApi(WebApiResource):
|
||||
@web_ns.doc("Create Chat Message")
|
||||
@web_ns.doc(description="Create a chat message for conversational applications.")
|
||||
@web_ns.doc(
|
||||
params={
|
||||
"inputs": {"description": "Input variables for the chat", "type": "object", "required": True},
|
||||
"query": {"description": "User query/message", "type": "string", "required": True},
|
||||
"files": {"description": "Files to be processed", "type": "array", "required": False},
|
||||
"response_mode": {
|
||||
"description": "Response mode: blocking or streaming",
|
||||
"type": "string",
|
||||
"enum": ["blocking", "streaming"],
|
||||
"required": False,
|
||||
},
|
||||
"conversation_id": {"description": "Conversation UUID", "type": "string", "required": False},
|
||||
"parent_message_id": {"description": "Parent message UUID", "type": "string", "required": False},
|
||||
"retriever_from": {"description": "Source of retriever", "type": "string", "required": False},
|
||||
}
|
||||
)
|
||||
@web_ns.expect(web_ns.models[ChatMessagePayload.__name__])
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
|
|
@ -173,20 +171,10 @@ class ChatApi(WebApiResource):
|
|||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, location="json")
|
||||
.add_argument("query", type=str, required=True, location="json")
|
||||
.add_argument("files", type=list, required=False, location="json")
|
||||
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
.add_argument("conversation_id", type=uuid_value, location="json")
|
||||
.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
||||
.add_argument("retriever_from", type=str, required=False, default="web_app", location="json")
|
||||
)
|
||||
payload = ChatMessagePayload.model_validate(web_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args["response_mode"] == "streaming"
|
||||
streaming = payload.response_mode == "streaming"
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -1,9 +1,12 @@
|
|||
import logging
|
||||
from typing import Literal
|
||||
|
||||
from flask_restx import fields, marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from flask import request
|
||||
from flask_restx import fields, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import (
|
||||
AppMoreLikeThisDisabledError,
|
||||
|
|
@ -38,6 +41,33 @@ from services.message_service import MessageService
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageListQuery(BaseModel):
|
||||
conversation_id: str = Field(description="Conversation UUID")
|
||||
first_id: str | None = Field(default=None, description="First message ID for pagination")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)")
|
||||
|
||||
@field_validator("conversation_id", "first_id")
|
||||
@classmethod
|
||||
def validate_uuid(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class MessageFeedbackPayload(BaseModel):
|
||||
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
|
||||
content: str | None = Field(default=None, description="Feedback content")
|
||||
|
||||
|
||||
class MessageMoreLikeThisQuery(BaseModel):
|
||||
response_mode: Literal["blocking", "streaming"] = Field(
|
||||
description="Response mode",
|
||||
)
|
||||
|
||||
|
||||
register_schema_models(web_ns, MessageListQuery, MessageFeedbackPayload, MessageMoreLikeThisQuery)
|
||||
|
||||
|
||||
@web_ns.route("/messages")
|
||||
class MessageListApi(WebApiResource):
|
||||
message_fields = {
|
||||
|
|
@ -68,7 +98,11 @@ class MessageListApi(WebApiResource):
|
|||
@web_ns.doc(
|
||||
params={
|
||||
"conversation_id": {"description": "Conversation UUID", "type": "string", "required": True},
|
||||
"first_id": {"description": "First message ID for pagination", "type": "string", "required": False},
|
||||
"first_id": {
|
||||
"description": "First message ID for pagination",
|
||||
"type": "string",
|
||||
"required": False,
|
||||
},
|
||||
"limit": {
|
||||
"description": "Number of messages to return (1-100)",
|
||||
"type": "integer",
|
||||
|
|
@ -93,17 +127,12 @@ class MessageListApi(WebApiResource):
|
|||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("conversation_id", required=True, type=uuid_value, location="args")
|
||||
.add_argument("first_id", type=uuid_value, location="args")
|
||||
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
raw_args = request.args.to_dict()
|
||||
query = MessageListQuery.model_validate(raw_args)
|
||||
|
||||
try:
|
||||
return MessageService.pagination_by_first_id(
|
||||
app_model, end_user, args["conversation_id"], args["first_id"], args["limit"]
|
||||
app_model, end_user, query.conversation_id, query.first_id, query.limit
|
||||
)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
|
@ -128,7 +157,7 @@ class MessageFeedbackApi(WebApiResource):
|
|||
"enum": ["like", "dislike"],
|
||||
"required": False,
|
||||
},
|
||||
"content": {"description": "Feedback content/comment", "type": "string", "required": False},
|
||||
"content": {"description": "Feedback content", "type": "string", "required": False},
|
||||
}
|
||||
)
|
||||
@web_ns.doc(
|
||||
|
|
@ -145,20 +174,15 @@ class MessageFeedbackApi(WebApiResource):
|
|||
def post(self, app_model, end_user, message_id):
|
||||
message_id = str(message_id)
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
|
||||
.add_argument("content", type=str, location="json", default=None)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = MessageFeedbackPayload.model_validate(web_ns.payload or {})
|
||||
|
||||
try:
|
||||
MessageService.create_feedback(
|
||||
app_model=app_model,
|
||||
message_id=message_id,
|
||||
user=end_user,
|
||||
rating=args.get("rating"),
|
||||
content=args.get("content"),
|
||||
rating=payload.rating,
|
||||
content=payload.content,
|
||||
)
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
|
@ -170,17 +194,7 @@ class MessageFeedbackApi(WebApiResource):
|
|||
class MessageMoreLikeThisApi(WebApiResource):
|
||||
@web_ns.doc("Generate More Like This")
|
||||
@web_ns.doc(description="Generate a new completion similar to an existing message (completion apps only).")
|
||||
@web_ns.doc(
|
||||
params={
|
||||
"message_id": {"description": "Message UUID", "type": "string", "required": True},
|
||||
"response_mode": {
|
||||
"description": "Response mode",
|
||||
"type": "string",
|
||||
"enum": ["blocking", "streaming"],
|
||||
"required": True,
|
||||
},
|
||||
}
|
||||
)
|
||||
@web_ns.expect(web_ns.models[MessageMoreLikeThisQuery.__name__])
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
|
|
@ -197,12 +211,10 @@ class MessageMoreLikeThisApi(WebApiResource):
|
|||
|
||||
message_id = str(message_id)
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
raw_args = request.args.to_dict()
|
||||
query = MessageMoreLikeThisQuery.model_validate(raw_args)
|
||||
|
||||
streaming = args["response_mode"] == "streaming"
|
||||
streaming = query.response_mode == "streaming"
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate_more_like_this(
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import json
|
||||
from collections.abc import Sequence
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any, Literal
|
||||
|
|
@ -120,7 +121,7 @@ class VariableEntity(BaseModel):
|
|||
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
|
||||
allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
|
||||
allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
|
||||
json_schema: dict[str, Any] | None = Field(default=None)
|
||||
json_schema: str | None = Field(default=None)
|
||||
|
||||
@field_validator("description", mode="before")
|
||||
@classmethod
|
||||
|
|
@ -134,11 +135,17 @@ class VariableEntity(BaseModel):
|
|||
|
||||
@field_validator("json_schema")
|
||||
@classmethod
|
||||
def validate_json_schema(cls, schema: dict[str, Any] | None) -> dict[str, Any] | None:
|
||||
def validate_json_schema(cls, schema: str | None) -> str | None:
|
||||
if schema is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
Draft7Validator.check_schema(schema)
|
||||
json_schema = json.loads(schema)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"invalid json_schema value {schema}")
|
||||
|
||||
try:
|
||||
Draft7Validator.check_schema(json_schema)
|
||||
except SchemaError as e:
|
||||
raise ValueError(f"Invalid JSON schema: {e.message}")
|
||||
return schema
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import json
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Union, final
|
||||
|
||||
|
|
@ -104,8 +105,9 @@ class BaseAppGenerator:
|
|||
variable_entity.type in {VariableEntityType.FILE, VariableEntityType.FILE_LIST}
|
||||
and not variable_entity.required
|
||||
):
|
||||
# Treat empty string (frontend default) or empty list as unset
|
||||
if not value and isinstance(value, (str, list)):
|
||||
# Treat empty string (frontend default) as unset
|
||||
# For FILE_LIST, allow empty list [] to pass through
|
||||
if isinstance(value, str) and not value:
|
||||
return None
|
||||
|
||||
if variable_entity.type in {
|
||||
|
|
@ -175,6 +177,13 @@ class BaseAppGenerator:
|
|||
value = True
|
||||
elif value == 0:
|
||||
value = False
|
||||
case VariableEntityType.JSON_OBJECT:
|
||||
if not isinstance(value, str):
|
||||
raise ValueError(f"{variable_entity.variable} in input form must be a string")
|
||||
try:
|
||||
json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"{variable_entity.variable} in input form must be a valid JSON object")
|
||||
case _:
|
||||
raise AssertionError("this statement should be unreachable.")
|
||||
|
||||
|
|
|
|||
|
|
@ -342,9 +342,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
|||
self._task_state.llm_result.message.content = current_content
|
||||
|
||||
if isinstance(event, QueueLLMChunkEvent):
|
||||
event_type = self._message_cycle_manager.get_message_event_type(message_id=self._message_id)
|
||||
yield self._message_cycle_manager.message_to_stream_response(
|
||||
answer=cast(str, delta_text),
|
||||
message_id=self._message_id,
|
||||
event_type=event_type,
|
||||
)
|
||||
else:
|
||||
yield self._agent_message_to_stream_response(
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from threading import Thread
|
|||
from typing import Union
|
||||
|
||||
from flask import Flask, current_app
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import exists, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
|
|
@ -54,6 +54,20 @@ class MessageCycleManager:
|
|||
):
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._task_state = task_state
|
||||
self._message_has_file: set[str] = set()
|
||||
|
||||
def get_message_event_type(self, message_id: str) -> StreamEvent:
|
||||
if message_id in self._message_has_file:
|
||||
return StreamEvent.MESSAGE_FILE
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
has_file = session.query(exists().where(MessageFile.message_id == message_id)).scalar()
|
||||
|
||||
if has_file:
|
||||
self._message_has_file.add(message_id)
|
||||
return StreamEvent.MESSAGE_FILE
|
||||
|
||||
return StreamEvent.MESSAGE
|
||||
|
||||
def generate_conversation_name(self, *, conversation_id: str, query: str) -> Thread | None:
|
||||
"""
|
||||
|
|
@ -214,7 +228,11 @@ class MessageCycleManager:
|
|||
return None
|
||||
|
||||
def message_to_stream_response(
|
||||
self, answer: str, message_id: str, from_variable_selector: list[str] | None = None
|
||||
self,
|
||||
answer: str,
|
||||
message_id: str,
|
||||
from_variable_selector: list[str] | None = None,
|
||||
event_type: StreamEvent | None = None,
|
||||
) -> MessageStreamResponse:
|
||||
"""
|
||||
Message to stream response.
|
||||
|
|
@ -222,16 +240,12 @@ class MessageCycleManager:
|
|||
:param message_id: message id
|
||||
:return:
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
message_file = session.scalar(select(MessageFile).where(MessageFile.id == message_id))
|
||||
event_type = StreamEvent.MESSAGE_FILE if message_file else StreamEvent.MESSAGE
|
||||
|
||||
return MessageStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
id=message_id,
|
||||
answer=answer,
|
||||
from_variable_selector=from_variable_selector,
|
||||
event=event_type,
|
||||
event=event_type or StreamEvent.MESSAGE,
|
||||
)
|
||||
|
||||
def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,38 @@
|
|||
from sqlalchemy import Engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
_session_maker: sessionmaker | None = None
|
||||
|
||||
|
||||
def configure_session_factory(engine: Engine, expire_on_commit: bool = False):
|
||||
"""Configure the global session factory"""
|
||||
global _session_maker
|
||||
_session_maker = sessionmaker(bind=engine, expire_on_commit=expire_on_commit)
|
||||
|
||||
|
||||
def get_session_maker() -> sessionmaker:
|
||||
if _session_maker is None:
|
||||
raise RuntimeError("Session factory not configured. Call configure_session_factory() first.")
|
||||
return _session_maker
|
||||
|
||||
|
||||
def create_session() -> Session:
|
||||
return get_session_maker()()
|
||||
|
||||
|
||||
# Class wrapper for convenience
|
||||
class SessionFactory:
|
||||
@staticmethod
|
||||
def configure(engine: Engine, expire_on_commit: bool = False):
|
||||
configure_session_factory(engine, expire_on_commit)
|
||||
|
||||
@staticmethod
|
||||
def get_session_maker() -> sessionmaker:
|
||||
return get_session_maker()
|
||||
|
||||
@staticmethod
|
||||
def create_session() -> Session:
|
||||
return create_session()
|
||||
|
||||
|
||||
session_factory = SessionFactory()
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class PreviewDetail(BaseModel):
|
||||
|
|
@ -20,9 +20,17 @@ class IndexingEstimate(BaseModel):
|
|||
class PipelineDataset(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str | None = Field(default="", description="knowledge dataset description")
|
||||
description: str = Field(default="", description="knowledge dataset description")
|
||||
chunk_structure: str
|
||||
|
||||
@field_validator("description", mode="before")
|
||||
@classmethod
|
||||
def normalize_description(cls, value: str | None) -> str:
|
||||
"""Coerce None to empty string so description is always a string."""
|
||||
if value is None:
|
||||
return ""
|
||||
return value
|
||||
|
||||
|
||||
class PipelineDocument(BaseModel):
|
||||
id: str
|
||||
|
|
|
|||
|
|
@ -213,12 +213,23 @@ class MCPProviderEntity(BaseModel):
|
|||
return None
|
||||
|
||||
def retrieve_tokens(self) -> OAuthTokens | None:
|
||||
"""OAuth tokens if available"""
|
||||
"""Retrieve OAuth tokens if authentication is complete.
|
||||
|
||||
Returns:
|
||||
OAuthTokens if the provider has been authenticated, None otherwise.
|
||||
"""
|
||||
if not self.credentials:
|
||||
return None
|
||||
credentials = self.decrypt_credentials()
|
||||
access_token = credentials.get("access_token", "")
|
||||
# Return None if access_token is empty to avoid generating invalid "Authorization: Bearer " header.
|
||||
# Note: We don't check for whitespace-only strings here because:
|
||||
# 1. OAuth servers don't return whitespace-only access tokens in practice
|
||||
# 2. Even if they did, the server would return 401, triggering the OAuth flow correctly
|
||||
if not access_token:
|
||||
return None
|
||||
return OAuthTokens(
|
||||
access_token=credentials.get("access_token", ""),
|
||||
access_token=access_token,
|
||||
token_type=credentials.get("token_type", DEFAULT_TOKEN_TYPE),
|
||||
expires_in=int(credentials.get("expires_in", str(DEFAULT_EXPIRES_IN)) or DEFAULT_EXPIRES_IN),
|
||||
refresh_token=credentials.get("refresh_token", ""),
|
||||
|
|
|
|||
|
|
@ -72,15 +72,22 @@ class LLMGenerator:
|
|||
prompt_messages=list(prompts), model_parameters={"max_tokens": 500, "temperature": 1}, stream=False
|
||||
)
|
||||
answer = cast(str, response.message.content)
|
||||
cleaned_answer = re.sub(r"^.*(\{.*\}).*$", r"\1", answer, flags=re.DOTALL)
|
||||
if cleaned_answer is None:
|
||||
if answer is None:
|
||||
return ""
|
||||
try:
|
||||
result_dict = json.loads(cleaned_answer)
|
||||
answer = result_dict["Your Output"]
|
||||
result_dict = json.loads(answer)
|
||||
except json.JSONDecodeError:
|
||||
logger.exception("Failed to generate name after answer, use query instead")
|
||||
result_dict = json_repair.loads(answer)
|
||||
|
||||
if not isinstance(result_dict, dict):
|
||||
answer = query
|
||||
else:
|
||||
output = result_dict.get("Your Output")
|
||||
if isinstance(output, str) and output.strip():
|
||||
answer = output.strip()
|
||||
else:
|
||||
answer = query
|
||||
|
||||
name = answer.strip()
|
||||
|
||||
if len(name) > 75:
|
||||
|
|
|
|||
|
|
@ -47,7 +47,11 @@ def build_protected_resource_metadata_discovery_urls(
|
|||
"""
|
||||
Build a list of URLs to try for Protected Resource Metadata discovery.
|
||||
|
||||
Per SEP-985, supports fallback when discovery fails at one URL.
|
||||
Per RFC 9728 Section 5.1, supports fallback when discovery fails at one URL.
|
||||
Priority order:
|
||||
1. URL from WWW-Authenticate header (if provided)
|
||||
2. Well-known URI with path: https://example.com/.well-known/oauth-protected-resource/public/mcp
|
||||
3. Well-known URI at root: https://example.com/.well-known/oauth-protected-resource
|
||||
"""
|
||||
urls = []
|
||||
|
||||
|
|
@ -58,9 +62,18 @@ def build_protected_resource_metadata_discovery_urls(
|
|||
# Fallback: construct from server URL
|
||||
parsed = urlparse(server_url)
|
||||
base_url = f"{parsed.scheme}://{parsed.netloc}"
|
||||
fallback_url = urljoin(base_url, "/.well-known/oauth-protected-resource")
|
||||
if fallback_url not in urls:
|
||||
urls.append(fallback_url)
|
||||
path = parsed.path.rstrip("/")
|
||||
|
||||
# Priority 2: With path insertion (e.g., /.well-known/oauth-protected-resource/public/mcp)
|
||||
if path:
|
||||
path_url = f"{base_url}/.well-known/oauth-protected-resource{path}"
|
||||
if path_url not in urls:
|
||||
urls.append(path_url)
|
||||
|
||||
# Priority 3: At root (e.g., /.well-known/oauth-protected-resource)
|
||||
root_url = f"{base_url}/.well-known/oauth-protected-resource"
|
||||
if root_url not in urls:
|
||||
urls.append(root_url)
|
||||
|
||||
return urls
|
||||
|
||||
|
|
@ -71,30 +84,34 @@ def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: st
|
|||
|
||||
Supports both OAuth 2.0 (RFC 8414) and OpenID Connect discovery.
|
||||
|
||||
Per RFC 8414 section 3:
|
||||
- If issuer has no path: https://example.com/.well-known/oauth-authorization-server
|
||||
- If issuer has path: https://example.com/.well-known/oauth-authorization-server{path}
|
||||
|
||||
Example:
|
||||
- issuer: https://example.com/oauth
|
||||
- metadata: https://example.com/.well-known/oauth-authorization-server/oauth
|
||||
Per RFC 8414 section 3.1 and section 5, try all possible endpoints:
|
||||
- OAuth 2.0 with path insertion: https://example.com/.well-known/oauth-authorization-server/tenant1
|
||||
- OpenID Connect with path insertion: https://example.com/.well-known/openid-configuration/tenant1
|
||||
- OpenID Connect path appending: https://example.com/tenant1/.well-known/openid-configuration
|
||||
- OAuth 2.0 at root: https://example.com/.well-known/oauth-authorization-server
|
||||
- OpenID Connect at root: https://example.com/.well-known/openid-configuration
|
||||
"""
|
||||
urls = []
|
||||
base_url = auth_server_url or server_url
|
||||
|
||||
parsed = urlparse(base_url)
|
||||
base = f"{parsed.scheme}://{parsed.netloc}"
|
||||
path = parsed.path.rstrip("/") # Remove trailing slash
|
||||
path = parsed.path.rstrip("/")
|
||||
# OAuth 2.0 Authorization Server Metadata at root (MCP-03-26)
|
||||
urls.append(f"{base}/.well-known/oauth-authorization-server")
|
||||
|
||||
# Try OpenID Connect discovery first (more common)
|
||||
urls.append(urljoin(base + "/", ".well-known/openid-configuration"))
|
||||
# OpenID Connect Discovery at root
|
||||
urls.append(f"{base}/.well-known/openid-configuration")
|
||||
|
||||
# OAuth 2.0 Authorization Server Metadata (RFC 8414)
|
||||
# Include the path component if present in the issuer URL
|
||||
if path:
|
||||
urls.append(urljoin(base, f".well-known/oauth-authorization-server{path}"))
|
||||
else:
|
||||
urls.append(urljoin(base, ".well-known/oauth-authorization-server"))
|
||||
# OpenID Connect Discovery with path insertion
|
||||
urls.append(f"{base}/.well-known/openid-configuration{path}")
|
||||
|
||||
# OpenID Connect Discovery path appending
|
||||
urls.append(f"{base}{path}/.well-known/openid-configuration")
|
||||
|
||||
# OAuth 2.0 Authorization Server Metadata with path insertion
|
||||
urls.append(f"{base}/.well-known/oauth-authorization-server{path}")
|
||||
|
||||
return urls
|
||||
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ class MCPClient:
|
|||
try:
|
||||
logger.debug("Not supported method %s found in URL path, trying default 'mcp' method.", method_name)
|
||||
self.connect_server(sse_client, "sse")
|
||||
except MCPConnectionError:
|
||||
except (MCPConnectionError, ValueError):
|
||||
logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.")
|
||||
self.connect_server(streamablehttp_client, "mcp")
|
||||
|
||||
|
|
|
|||
|
|
@ -18,34 +18,20 @@ This module provides the interface for invoking and authenticating various model
|
|||
|
||||
- Model provider display
|
||||
|
||||

|
||||
|
||||
Displays a list of all supported providers, including provider names, icons, supported model types list, predefined model list, configuration method, and credentials form rules, etc. For detailed rule design, see: [Schema](./docs/en_US/schema.md).
|
||||
Displays a list of all supported providers, including provider names, icons, supported model types list, predefined model list, configuration method, and credentials form rules, etc.
|
||||
|
||||
- Selectable model list display
|
||||
|
||||

|
||||
|
||||
After configuring provider/model credentials, the dropdown (application orchestration interface/default model) allows viewing of the available LLM list. Greyed out items represent predefined model lists from providers without configured credentials, facilitating user review of supported models.
|
||||
|
||||
In addition, this list also returns configurable parameter information and rules for LLM, as shown below:
|
||||
|
||||

|
||||
|
||||
These parameters are all defined in the backend, allowing different settings for various parameters supported by different models, as detailed in: [Schema](./docs/en_US/schema.md#ParameterRule).
|
||||
In addition, this list also returns configurable parameter information and rules for LLM. These parameters are all defined in the backend, allowing different settings for various parameters supported by different models.
|
||||
|
||||
- Provider/model credential authentication
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
The provider list returns configuration information for the credentials form, which can be authenticated through Runtime's interface. The first image above is a provider credential DEMO, and the second is a model credential DEMO.
|
||||
The provider list returns configuration information for the credentials form, which can be authenticated through Runtime's interface.
|
||||
|
||||
## Structure
|
||||
|
||||

|
||||
|
||||
Model Runtime is divided into three layers:
|
||||
|
||||
- The outermost layer is the factory method
|
||||
|
|
@ -60,9 +46,6 @@ Model Runtime is divided into three layers:
|
|||
|
||||
It offers direct invocation of various model types, predefined model configuration information, getting predefined/remote model lists, model credential authentication methods. Different models provide additional special methods, like LLM's pre-computed tokens method, cost information obtaining method, etc., **allowing horizontal expansion** for different models under the same provider (within supported model types).
|
||||
|
||||
## Next Steps
|
||||
## Documentation
|
||||
|
||||
- Add new provider configuration: [Link](./docs/en_US/provider_scale_out.md)
|
||||
- Add new models for existing providers: [Link](./docs/en_US/provider_scale_out.md#AddModel)
|
||||
- View YAML configuration rules: [Link](./docs/en_US/schema.md)
|
||||
- Implement interface methods: [Link](./docs/en_US/interfaces.md)
|
||||
For detailed documentation on how to add new providers or models, please refer to the [Dify documentation](https://docs.dify.ai/).
|
||||
|
|
|
|||
|
|
@ -18,34 +18,20 @@
|
|||
|
||||
- 模型供应商展示
|
||||
|
||||

|
||||
|
||||
展示所有已支持的供应商列表,除了返回供应商名称、图标之外,还提供了支持的模型类型列表,预定义模型列表、配置方式以及配置凭据的表单规则等等,规则设计详见:[Schema](./docs/zh_Hans/schema.md)。
|
||||
展示所有已支持的供应商列表,除了返回供应商名称、图标之外,还提供了支持的模型类型列表,预定义模型列表、配置方式以及配置凭据的表单规则等等。
|
||||
|
||||
- 可选择的模型列表展示
|
||||
|
||||

|
||||
配置供应商/模型凭据后,可在此下拉(应用编排界面/默认模型)查看可用的 LLM 列表,其中灰色的为未配置凭据供应商的预定义模型列表,方便用户查看已支持的模型。
|
||||
|
||||
配置供应商/模型凭据后,可在此下拉(应用编排界面/默认模型)查看可用的 LLM 列表,其中灰色的为未配置凭据供应商的预定义模型列表,方便用户查看已支持的模型。
|
||||
|
||||
除此之外,该列表还返回了 LLM 可配置的参数信息和规则,如下图:
|
||||
|
||||

|
||||
|
||||
这里的参数均为后端定义,相比之前只有 5 种固定参数,这里可为不同模型设置所支持的各种参数,详见:[Schema](./docs/zh_Hans/schema.md#ParameterRule)。
|
||||
除此之外,该列表还返回了 LLM 可配置的参数信息和规则。这里的参数均为后端定义,相比之前只有 5 种固定参数,这里可为不同模型设置所支持的各种参数。
|
||||
|
||||
- 供应商/模型凭据鉴权
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
供应商列表返回了凭据表单的配置信息,可通过 Runtime 提供的接口对凭据进行鉴权,上图 1 为供应商凭据 DEMO,上图 2 为模型凭据 DEMO。
|
||||
供应商列表返回了凭据表单的配置信息,可通过 Runtime 提供的接口对凭据进行鉴权。
|
||||
|
||||
## 结构
|
||||
|
||||

|
||||
|
||||
Model Runtime 分三层:
|
||||
|
||||
- 最外层为工厂方法
|
||||
|
|
@ -59,8 +45,7 @@ Model Runtime 分三层:
|
|||
对于供应商/模型凭据,有两种情况
|
||||
|
||||
- 如 OpenAI 这类中心化供应商,需要定义如**api_key**这类的鉴权凭据
|
||||
- 如[**Xinference**](https://github.com/xorbitsai/inference)这类本地部署的供应商,需要定义如**server_url**这类的地址凭据,有时候还需要定义**model_uid**之类的模型类型凭据,就像下面这样,当在供应商层定义了这些凭据后,就可以在前端页面上直接展示,无需修改前端逻辑。
|
||||

|
||||
- 如[**Xinference**](https://github.com/xorbitsai/inference)这类本地部署的供应商,需要定义如**server_url**这类的地址凭据,有时候还需要定义**model_uid**之类的模型类型凭据。当在供应商层定义了这些凭据后,就可以在前端页面上直接展示,无需修改前端逻辑。
|
||||
|
||||
当配置好凭据后,就可以通过 DifyRuntime 的外部接口直接获取到对应供应商所需要的**Schema**(凭据表单规则),从而在可以在不修改前端逻辑的情况下,提供新的供应商/模型的支持。
|
||||
|
||||
|
|
@ -74,20 +59,6 @@ Model Runtime 分三层:
|
|||
|
||||
- 模型凭据 (**在供应商层定义**):这是一类不经常变动,一般在配置好后就不会再变动的参数,如 **api_key**、**server_url** 等。在 DifyRuntime 中,他们的参数名一般为**credentials: dict[str, any]**,Provider 层的 credentials 会直接被传递到这一层,不需要再单独定义。
|
||||
|
||||
## 下一步
|
||||
## 文档
|
||||
|
||||
### [增加新的供应商配置 👈🏻](./docs/zh_Hans/provider_scale_out.md)
|
||||
|
||||
当添加后,这里将会出现一个新的供应商
|
||||
|
||||

|
||||
|
||||
### [为已存在的供应商新增模型 👈🏻](./docs/zh_Hans/provider_scale_out.md#%E5%A2%9E%E5%8A%A0%E6%A8%A1%E5%9E%8B)
|
||||
|
||||
当添加后,对应供应商的模型列表中将会出现一个新的预定义模型供用户选择,如 GPT-3.5 GPT-4 ChatGLM3-6b 等,而对于支持自定义模型的供应商,则不需要新增模型。
|
||||
|
||||

|
||||
|
||||
### [接口的具体实现 👈🏻](./docs/zh_Hans/interfaces.md)
|
||||
|
||||
你可以在这里找到你想要查看的接口的具体实现,以及接口的参数和返回值的具体含义。
|
||||
有关如何添加新供应商或模型的详细文档,请参阅 [Dify 文档](https://docs.dify.ai/)。
|
||||
|
|
|
|||
|
|
@ -6,7 +6,13 @@ from datetime import datetime, timedelta
|
|||
from typing import Any, Union, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from openinference.semconv.trace import OpenInferenceMimeTypeValues, OpenInferenceSpanKindValues, SpanAttributes
|
||||
from openinference.semconv.trace import (
|
||||
MessageAttributes,
|
||||
OpenInferenceMimeTypeValues,
|
||||
OpenInferenceSpanKindValues,
|
||||
SpanAttributes,
|
||||
ToolCallAttributes,
|
||||
)
|
||||
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GrpcOTLPSpanExporter
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HttpOTLPSpanExporter
|
||||
from opentelemetry.sdk import trace as trace_sdk
|
||||
|
|
@ -95,14 +101,14 @@ def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[tra
|
|||
|
||||
|
||||
def datetime_to_nanos(dt: datetime | None) -> int:
|
||||
"""Convert datetime to nanoseconds since epoch. If None, use current time."""
|
||||
"""Convert datetime to nanoseconds since epoch for Arize/Phoenix."""
|
||||
if dt is None:
|
||||
dt = datetime.now()
|
||||
return int(dt.timestamp() * 1_000_000_000)
|
||||
|
||||
|
||||
def error_to_string(error: Exception | str | None) -> str:
|
||||
"""Convert an error to a string with traceback information."""
|
||||
"""Convert an error to a string with traceback information for Arize/Phoenix."""
|
||||
error_message = "Empty Stack Trace"
|
||||
if error:
|
||||
if isinstance(error, Exception):
|
||||
|
|
@ -114,7 +120,7 @@ def error_to_string(error: Exception | str | None) -> str:
|
|||
|
||||
|
||||
def set_span_status(current_span: Span, error: Exception | str | None = None):
|
||||
"""Set the status of the current span based on the presence of an error."""
|
||||
"""Set the status of the current span based on the presence of an error for Arize/Phoenix."""
|
||||
if error:
|
||||
error_string = error_to_string(error)
|
||||
current_span.set_status(Status(StatusCode.ERROR, error_string))
|
||||
|
|
@ -138,10 +144,17 @@ def set_span_status(current_span: Span, error: Exception | str | None = None):
|
|||
|
||||
|
||||
def safe_json_dumps(obj: Any) -> str:
|
||||
"""A convenience wrapper around `json.dumps` that ensures that any object can be safely encoded."""
|
||||
"""A convenience wrapper to ensure that any object can be safely encoded for Arize/Phoenix."""
|
||||
return json.dumps(obj, default=str, ensure_ascii=False)
|
||||
|
||||
|
||||
def wrap_span_metadata(metadata, **kwargs):
|
||||
"""Add common metatada to all trace entity types for Arize/Phoenix."""
|
||||
metadata["created_from"] = "Dify"
|
||||
metadata.update(kwargs)
|
||||
return metadata
|
||||
|
||||
|
||||
class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -183,16 +196,27 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
|||
raise
|
||||
|
||||
def workflow_trace(self, trace_info: WorkflowTraceInfo):
|
||||
workflow_metadata = {
|
||||
"workflow_run_id": trace_info.workflow_run_id or "",
|
||||
"message_id": trace_info.message_id or "",
|
||||
"workflow_app_log_id": trace_info.workflow_app_log_id or "",
|
||||
"status": trace_info.workflow_run_status or "",
|
||||
"status_message": trace_info.error or "",
|
||||
"level": "ERROR" if trace_info.error else "DEFAULT",
|
||||
"total_tokens": trace_info.total_tokens or 0,
|
||||
}
|
||||
workflow_metadata.update(trace_info.metadata)
|
||||
file_list = trace_info.file_list if isinstance(trace_info.file_list, list) else []
|
||||
|
||||
metadata = wrap_span_metadata(
|
||||
trace_info.metadata,
|
||||
trace_id=trace_info.trace_id or "",
|
||||
message_id=trace_info.message_id or "",
|
||||
status=trace_info.workflow_run_status or "",
|
||||
status_message=trace_info.error or "",
|
||||
level="ERROR" if trace_info.error else "DEFAULT",
|
||||
trace_entity_type="workflow",
|
||||
conversation_id=trace_info.conversation_id or "",
|
||||
workflow_app_log_id=trace_info.workflow_app_log_id or "",
|
||||
workflow_id=trace_info.workflow_id or "",
|
||||
tenant_id=trace_info.tenant_id or "",
|
||||
workflow_run_id=trace_info.workflow_run_id or "",
|
||||
workflow_run_elapsed_time=trace_info.workflow_run_elapsed_time or 0,
|
||||
workflow_run_version=trace_info.workflow_run_version or "",
|
||||
total_tokens=trace_info.total_tokens or 0,
|
||||
file_list=safe_json_dumps(file_list),
|
||||
query=trace_info.query or "",
|
||||
)
|
||||
|
||||
dify_trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id
|
||||
self.ensure_root_span(dify_trace_id)
|
||||
|
|
@ -201,10 +225,12 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
|||
workflow_span = self.tracer.start_span(
|
||||
name=TraceTaskName.WORKFLOW_TRACE.value,
|
||||
attributes={
|
||||
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.workflow_run_inputs, ensure_ascii=False),
|
||||
SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
|
||||
SpanAttributes.METADATA: json.dumps(workflow_metadata, ensure_ascii=False),
|
||||
SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.workflow_run_inputs),
|
||||
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
|
||||
SpanAttributes.OUTPUT_VALUE: safe_json_dumps(trace_info.workflow_run_outputs),
|
||||
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
|
||||
SpanAttributes.METADATA: safe_json_dumps(metadata),
|
||||
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
|
||||
},
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
|
|
@ -257,6 +283,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
|||
"app_id": app_id,
|
||||
"app_name": node_execution.title,
|
||||
"status": node_execution.status,
|
||||
"status_message": node_execution.error or "",
|
||||
"level": "ERROR" if node_execution.status == "failed" else "DEFAULT",
|
||||
}
|
||||
)
|
||||
|
|
@ -290,11 +317,11 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
|||
node_span = self.tracer.start_span(
|
||||
name=node_execution.node_type,
|
||||
attributes={
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind.value,
|
||||
SpanAttributes.INPUT_VALUE: safe_json_dumps(inputs_value),
|
||||
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
|
||||
SpanAttributes.OUTPUT_VALUE: safe_json_dumps(outputs_value),
|
||||
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind.value,
|
||||
SpanAttributes.METADATA: safe_json_dumps(node_metadata),
|
||||
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
|
||||
},
|
||||
|
|
@ -339,30 +366,37 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
|||
|
||||
def message_trace(self, trace_info: MessageTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
logger.warning("[Arize/Phoenix] Message data is None, skipping message trace.")
|
||||
return
|
||||
|
||||
file_list = cast(list[str], trace_info.file_list) or []
|
||||
file_list = trace_info.file_list if isinstance(trace_info.file_list, list) else []
|
||||
message_file_data: MessageFile | None = trace_info.message_file_data
|
||||
|
||||
if message_file_data is not None:
|
||||
file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
|
||||
file_list.append(file_url)
|
||||
|
||||
message_metadata = {
|
||||
"message_id": trace_info.message_id or "",
|
||||
"conversation_mode": str(trace_info.conversation_mode or ""),
|
||||
"user_id": trace_info.message_data.from_account_id or "",
|
||||
"file_list": json.dumps(file_list),
|
||||
"status": trace_info.message_data.status or "",
|
||||
"status_message": trace_info.error or "",
|
||||
"level": "ERROR" if trace_info.error else "DEFAULT",
|
||||
"total_tokens": trace_info.total_tokens or 0,
|
||||
"prompt_tokens": trace_info.message_tokens or 0,
|
||||
"completion_tokens": trace_info.answer_tokens or 0,
|
||||
"ls_provider": trace_info.message_data.model_provider or "",
|
||||
"ls_model_name": trace_info.message_data.model_id or "",
|
||||
}
|
||||
message_metadata.update(trace_info.metadata)
|
||||
metadata = wrap_span_metadata(
|
||||
trace_info.metadata,
|
||||
trace_id=trace_info.trace_id or "",
|
||||
message_id=trace_info.message_id or "",
|
||||
status=trace_info.message_data.status or "",
|
||||
status_message=trace_info.error or "",
|
||||
level="ERROR" if trace_info.error else "DEFAULT",
|
||||
trace_entity_type="message",
|
||||
conversation_model=trace_info.conversation_model or "",
|
||||
message_tokens=trace_info.message_tokens or 0,
|
||||
answer_tokens=trace_info.answer_tokens or 0,
|
||||
total_tokens=trace_info.total_tokens or 0,
|
||||
conversation_mode=trace_info.conversation_mode or "",
|
||||
gen_ai_server_time_to_first_token=trace_info.gen_ai_server_time_to_first_token or 0,
|
||||
llm_streaming_time_to_generate=trace_info.llm_streaming_time_to_generate or 0,
|
||||
is_streaming_request=trace_info.is_streaming_request or False,
|
||||
user_id=trace_info.message_data.from_account_id or "",
|
||||
file_list=safe_json_dumps(file_list),
|
||||
model_provider=trace_info.message_data.model_provider or "",
|
||||
model_id=trace_info.message_data.model_id or "",
|
||||
)
|
||||
|
||||
# Add end user data if available
|
||||
if trace_info.message_data.from_end_user_id:
|
||||
|
|
@ -370,14 +404,16 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
|||
db.session.query(EndUser).where(EndUser.id == trace_info.message_data.from_end_user_id).first()
|
||||
)
|
||||
if end_user_data is not None:
|
||||
message_metadata["end_user_id"] = end_user_data.session_id
|
||||
metadata["end_user_id"] = end_user_data.session_id
|
||||
|
||||
attributes = {
|
||||
SpanAttributes.INPUT_VALUE: trace_info.message_data.query,
|
||||
SpanAttributes.OUTPUT_VALUE: trace_info.message_data.answer,
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
|
||||
SpanAttributes.METADATA: json.dumps(message_metadata, ensure_ascii=False),
|
||||
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
|
||||
SpanAttributes.INPUT_VALUE: trace_info.message_data.query,
|
||||
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.TEXT.value,
|
||||
SpanAttributes.OUTPUT_VALUE: trace_info.message_data.answer,
|
||||
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.TEXT.value,
|
||||
SpanAttributes.METADATA: safe_json_dumps(metadata),
|
||||
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id or "",
|
||||
}
|
||||
|
||||
dify_trace_id = trace_info.trace_id or trace_info.message_id
|
||||
|
|
@ -393,8 +429,10 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
|||
|
||||
try:
|
||||
# Convert outputs to string based on type
|
||||
outputs_mime_type = OpenInferenceMimeTypeValues.TEXT.value
|
||||
if isinstance(trace_info.outputs, dict | list):
|
||||
outputs_str = json.dumps(trace_info.outputs, ensure_ascii=False)
|
||||
outputs_str = safe_json_dumps(trace_info.outputs)
|
||||
outputs_mime_type = OpenInferenceMimeTypeValues.JSON.value
|
||||
elif isinstance(trace_info.outputs, str):
|
||||
outputs_str = trace_info.outputs
|
||||
else:
|
||||
|
|
@ -402,10 +440,12 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
|||
|
||||
llm_attributes = {
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.LLM.value,
|
||||
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.inputs),
|
||||
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
|
||||
SpanAttributes.OUTPUT_VALUE: outputs_str,
|
||||
SpanAttributes.METADATA: json.dumps(message_metadata, ensure_ascii=False),
|
||||
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
|
||||
SpanAttributes.OUTPUT_MIME_TYPE: outputs_mime_type,
|
||||
SpanAttributes.METADATA: safe_json_dumps(metadata),
|
||||
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id or "",
|
||||
}
|
||||
llm_attributes.update(self._construct_llm_attributes(trace_info.inputs))
|
||||
if trace_info.total_tokens is not None and trace_info.total_tokens > 0:
|
||||
|
|
@ -449,16 +489,20 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
|||
|
||||
def moderation_trace(self, trace_info: ModerationTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
logger.warning("[Arize/Phoenix] Message data is None, skipping moderation trace.")
|
||||
return
|
||||
|
||||
metadata = {
|
||||
"message_id": trace_info.message_id,
|
||||
"tool_name": "moderation",
|
||||
"status": trace_info.message_data.status,
|
||||
"status_message": trace_info.message_data.error or "",
|
||||
"level": "ERROR" if trace_info.message_data.error else "DEFAULT",
|
||||
}
|
||||
metadata.update(trace_info.metadata)
|
||||
metadata = wrap_span_metadata(
|
||||
trace_info.metadata,
|
||||
trace_id=trace_info.trace_id or "",
|
||||
message_id=trace_info.message_id or "",
|
||||
status=trace_info.message_data.status or "",
|
||||
status_message=trace_info.message_data.error or "",
|
||||
level="ERROR" if trace_info.message_data.error else "DEFAULT",
|
||||
trace_entity_type="moderation",
|
||||
model_provider=trace_info.message_data.model_provider or "",
|
||||
model_id=trace_info.message_data.model_id or "",
|
||||
)
|
||||
|
||||
dify_trace_id = trace_info.trace_id or trace_info.message_id
|
||||
self.ensure_root_span(dify_trace_id)
|
||||
|
|
@ -467,18 +511,19 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
|||
span = self.tracer.start_span(
|
||||
name=TraceTaskName.MODERATION_TRACE.value,
|
||||
attributes={
|
||||
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
SpanAttributes.OUTPUT_VALUE: json.dumps(
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.TOOL.value,
|
||||
SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.inputs),
|
||||
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
|
||||
SpanAttributes.OUTPUT_VALUE: safe_json_dumps(
|
||||
{
|
||||
"action": trace_info.action,
|
||||
"flagged": trace_info.flagged,
|
||||
"action": trace_info.action,
|
||||
"preset_response": trace_info.preset_response,
|
||||
"inputs": trace_info.inputs,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
"query": trace_info.query,
|
||||
}
|
||||
),
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
|
||||
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
|
||||
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
|
||||
SpanAttributes.METADATA: safe_json_dumps(metadata),
|
||||
},
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=root_span_context,
|
||||
|
|
@ -494,22 +539,28 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
|||
|
||||
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
logger.warning("[Arize/Phoenix] Message data is None, skipping suggested question trace.")
|
||||
return
|
||||
|
||||
start_time = trace_info.start_time or trace_info.message_data.created_at
|
||||
end_time = trace_info.end_time or trace_info.message_data.updated_at
|
||||
|
||||
metadata = {
|
||||
"message_id": trace_info.message_id,
|
||||
"tool_name": "suggested_question",
|
||||
"status": trace_info.status,
|
||||
"status_message": trace_info.error or "",
|
||||
"level": "ERROR" if trace_info.error else "DEFAULT",
|
||||
"total_tokens": trace_info.total_tokens,
|
||||
"ls_provider": trace_info.model_provider or "",
|
||||
"ls_model_name": trace_info.model_id or "",
|
||||
}
|
||||
metadata.update(trace_info.metadata)
|
||||
metadata = wrap_span_metadata(
|
||||
trace_info.metadata,
|
||||
trace_id=trace_info.trace_id or "",
|
||||
message_id=trace_info.message_id or "",
|
||||
status=trace_info.status or "",
|
||||
status_message=trace_info.status_message or "",
|
||||
level=trace_info.level or "",
|
||||
trace_entity_type="suggested_question",
|
||||
total_tokens=trace_info.total_tokens or 0,
|
||||
from_account_id=trace_info.from_account_id or "",
|
||||
agent_based=trace_info.agent_based or False,
|
||||
from_source=trace_info.from_source or "",
|
||||
model_provider=trace_info.model_provider or "",
|
||||
model_id=trace_info.model_id or "",
|
||||
workflow_run_id=trace_info.workflow_run_id or "",
|
||||
)
|
||||
|
||||
dify_trace_id = trace_info.trace_id or trace_info.message_id
|
||||
self.ensure_root_span(dify_trace_id)
|
||||
|
|
@ -518,10 +569,12 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
|||
span = self.tracer.start_span(
|
||||
name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
|
||||
attributes={
|
||||
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.suggested_question, ensure_ascii=False),
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
|
||||
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.TOOL.value,
|
||||
SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.inputs),
|
||||
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
|
||||
SpanAttributes.OUTPUT_VALUE: safe_json_dumps(trace_info.suggested_question),
|
||||
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
|
||||
SpanAttributes.METADATA: safe_json_dumps(metadata),
|
||||
},
|
||||
start_time=datetime_to_nanos(start_time),
|
||||
context=root_span_context,
|
||||
|
|
@ -537,21 +590,23 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
|||
|
||||
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
logger.warning("[Arize/Phoenix] Message data is None, skipping dataset retrieval trace.")
|
||||
return
|
||||
|
||||
start_time = trace_info.start_time or trace_info.message_data.created_at
|
||||
end_time = trace_info.end_time or trace_info.message_data.updated_at
|
||||
|
||||
metadata = {
|
||||
"message_id": trace_info.message_id,
|
||||
"tool_name": "dataset_retrieval",
|
||||
"status": trace_info.message_data.status,
|
||||
"status_message": trace_info.message_data.error or "",
|
||||
"level": "ERROR" if trace_info.message_data.error else "DEFAULT",
|
||||
"ls_provider": trace_info.message_data.model_provider or "",
|
||||
"ls_model_name": trace_info.message_data.model_id or "",
|
||||
}
|
||||
metadata.update(trace_info.metadata)
|
||||
metadata = wrap_span_metadata(
|
||||
trace_info.metadata,
|
||||
trace_id=trace_info.trace_id or "",
|
||||
message_id=trace_info.message_id or "",
|
||||
status=trace_info.message_data.status or "",
|
||||
status_message=trace_info.error or "",
|
||||
level="ERROR" if trace_info.error else "DEFAULT",
|
||||
trace_entity_type="dataset_retrieval",
|
||||
model_provider=trace_info.message_data.model_provider or "",
|
||||
model_id=trace_info.message_data.model_id or "",
|
||||
)
|
||||
|
||||
dify_trace_id = trace_info.trace_id or trace_info.message_id
|
||||
self.ensure_root_span(dify_trace_id)
|
||||
|
|
@ -560,20 +615,20 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
|||
span = self.tracer.start_span(
|
||||
name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
|
||||
attributes={
|
||||
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
SpanAttributes.OUTPUT_VALUE: json.dumps({"documents": trace_info.documents}, ensure_ascii=False),
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.RETRIEVER.value,
|
||||
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
|
||||
"start_time": start_time.isoformat() if start_time else "",
|
||||
"end_time": end_time.isoformat() if end_time else "",
|
||||
SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.inputs),
|
||||
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
|
||||
SpanAttributes.OUTPUT_VALUE: safe_json_dumps({"documents": trace_info.documents}),
|
||||
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
|
||||
SpanAttributes.METADATA: safe_json_dumps(metadata),
|
||||
},
|
||||
start_time=datetime_to_nanos(start_time),
|
||||
context=root_span_context,
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.message_data.error:
|
||||
set_span_status(span, trace_info.message_data.error)
|
||||
if trace_info.error:
|
||||
set_span_status(span, trace_info.error)
|
||||
else:
|
||||
set_span_status(span)
|
||||
finally:
|
||||
|
|
@ -584,30 +639,34 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
|||
logger.warning("[Arize/Phoenix] Message data is None, skipping tool trace.")
|
||||
return
|
||||
|
||||
metadata = {
|
||||
"message_id": trace_info.message_id,
|
||||
"tool_config": json.dumps(trace_info.tool_config, ensure_ascii=False),
|
||||
}
|
||||
metadata = wrap_span_metadata(
|
||||
trace_info.metadata,
|
||||
trace_id=trace_info.trace_id or "",
|
||||
message_id=trace_info.message_id or "",
|
||||
status=trace_info.message_data.status or "",
|
||||
status_message=trace_info.error or "",
|
||||
level="ERROR" if trace_info.error else "DEFAULT",
|
||||
trace_entity_type="tool",
|
||||
tool_config=safe_json_dumps(trace_info.tool_config),
|
||||
time_cost=trace_info.time_cost or 0,
|
||||
file_url=trace_info.file_url or "",
|
||||
)
|
||||
|
||||
dify_trace_id = trace_info.trace_id or trace_info.message_id
|
||||
self.ensure_root_span(dify_trace_id)
|
||||
root_span_context = self.propagator.extract(carrier=self.carrier)
|
||||
|
||||
tool_params_str = (
|
||||
json.dumps(trace_info.tool_parameters, ensure_ascii=False)
|
||||
if isinstance(trace_info.tool_parameters, dict)
|
||||
else str(trace_info.tool_parameters)
|
||||
)
|
||||
|
||||
span = self.tracer.start_span(
|
||||
name=trace_info.tool_name,
|
||||
attributes={
|
||||
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.tool_inputs, ensure_ascii=False),
|
||||
SpanAttributes.OUTPUT_VALUE: trace_info.tool_outputs,
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.TOOL.value,
|
||||
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
|
||||
SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.tool_inputs),
|
||||
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
|
||||
SpanAttributes.OUTPUT_VALUE: trace_info.tool_outputs,
|
||||
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.TEXT.value,
|
||||
SpanAttributes.METADATA: safe_json_dumps(metadata),
|
||||
SpanAttributes.TOOL_NAME: trace_info.tool_name,
|
||||
SpanAttributes.TOOL_PARAMETERS: tool_params_str,
|
||||
SpanAttributes.TOOL_PARAMETERS: safe_json_dumps(trace_info.tool_parameters),
|
||||
},
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=root_span_context,
|
||||
|
|
@ -623,16 +682,22 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
|||
|
||||
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
logger.warning("[Arize/Phoenix] Message data is None, skipping generate name trace.")
|
||||
return
|
||||
|
||||
metadata = {
|
||||
"project_name": self.project,
|
||||
"message_id": trace_info.message_id,
|
||||
"status": trace_info.message_data.status,
|
||||
"status_message": trace_info.message_data.error or "",
|
||||
"level": "ERROR" if trace_info.message_data.error else "DEFAULT",
|
||||
}
|
||||
metadata.update(trace_info.metadata)
|
||||
metadata = wrap_span_metadata(
|
||||
trace_info.metadata,
|
||||
trace_id=trace_info.trace_id or "",
|
||||
message_id=trace_info.message_id or "",
|
||||
status=trace_info.message_data.status or "",
|
||||
status_message=trace_info.message_data.error or "",
|
||||
level="ERROR" if trace_info.message_data.error else "DEFAULT",
|
||||
trace_entity_type="generate_name",
|
||||
model_provider=trace_info.message_data.model_provider or "",
|
||||
model_id=trace_info.message_data.model_id or "",
|
||||
conversation_id=trace_info.conversation_id or "",
|
||||
tenant_id=trace_info.tenant_id,
|
||||
)
|
||||
|
||||
dify_trace_id = trace_info.trace_id or trace_info.message_id or trace_info.conversation_id
|
||||
self.ensure_root_span(dify_trace_id)
|
||||
|
|
@ -641,13 +706,13 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
|||
span = self.tracer.start_span(
|
||||
name=TraceTaskName.GENERATE_NAME_TRACE.value,
|
||||
attributes={
|
||||
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.outputs, ensure_ascii=False),
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
|
||||
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
|
||||
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
|
||||
"start_time": trace_info.start_time.isoformat() if trace_info.start_time else "",
|
||||
"end_time": trace_info.end_time.isoformat() if trace_info.end_time else "",
|
||||
SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.inputs),
|
||||
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
|
||||
SpanAttributes.OUTPUT_VALUE: safe_json_dumps(trace_info.outputs),
|
||||
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
|
||||
SpanAttributes.METADATA: safe_json_dumps(metadata),
|
||||
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
|
||||
},
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=root_span_context,
|
||||
|
|
@ -688,32 +753,85 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
|||
raise ValueError(f"[Arize/Phoenix] API check failed: {str(e)}")
|
||||
|
||||
def get_project_url(self):
|
||||
"""Build a redirect URL that forwards the user to the correct project for Arize/Phoenix."""
|
||||
try:
|
||||
if self.arize_phoenix_config.endpoint == "https://otlp.arize.com":
|
||||
return "https://app.arize.com/"
|
||||
else:
|
||||
return f"{self.arize_phoenix_config.endpoint}/projects/"
|
||||
project_name = self.arize_phoenix_config.project
|
||||
endpoint = self.arize_phoenix_config.endpoint.rstrip("/")
|
||||
|
||||
# Arize
|
||||
if isinstance(self.arize_phoenix_config, ArizeConfig):
|
||||
return f"https://app.arize.com/?redirect_project_name={project_name}"
|
||||
|
||||
# Phoenix
|
||||
return f"{endpoint}/projects/?redirect_project_name={project_name}"
|
||||
|
||||
except Exception as e:
|
||||
logger.info("[Arize/Phoenix] Get run url failed: %s", str(e), exc_info=True)
|
||||
raise ValueError(f"[Arize/Phoenix] Get run url failed: {str(e)}")
|
||||
logger.info("[Arize/Phoenix] Failed to construct project URL: %s", str(e), exc_info=True)
|
||||
raise ValueError(f"[Arize/Phoenix] Failed to construct project URL: {str(e)}")
|
||||
|
||||
def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]:
|
||||
"""Helper method to construct LLM attributes with passed prompts."""
|
||||
attributes = {}
|
||||
"""Construct LLM attributes with passed prompts for Arize/Phoenix."""
|
||||
attributes: dict[str, str] = {}
|
||||
|
||||
def set_attribute(path: str, value: object) -> None:
|
||||
"""Store an attribute safely as a string."""
|
||||
if value is None:
|
||||
return
|
||||
try:
|
||||
if isinstance(value, (dict, list)):
|
||||
value = safe_json_dumps(value)
|
||||
attributes[path] = str(value)
|
||||
except Exception:
|
||||
attributes[path] = str(value)
|
||||
|
||||
def set_message_attribute(message_index: int, key: str, value: object) -> None:
|
||||
path = f"{SpanAttributes.LLM_INPUT_MESSAGES}.{message_index}.{key}"
|
||||
set_attribute(path, value)
|
||||
|
||||
def set_tool_call_attributes(message_index: int, tool_index: int, tool_call: dict | object | None) -> None:
|
||||
"""Extract and assign tool call details safely."""
|
||||
if not tool_call:
|
||||
return
|
||||
|
||||
def safe_get(obj, key, default=None):
|
||||
if isinstance(obj, dict):
|
||||
return obj.get(key, default)
|
||||
return getattr(obj, key, default)
|
||||
|
||||
function_obj = safe_get(tool_call, "function", {})
|
||||
function_name = safe_get(function_obj, "name", "")
|
||||
function_args = safe_get(function_obj, "arguments", {})
|
||||
call_id = safe_get(tool_call, "id", "")
|
||||
|
||||
base_path = (
|
||||
f"{SpanAttributes.LLM_INPUT_MESSAGES}."
|
||||
f"{message_index}.{MessageAttributes.MESSAGE_TOOL_CALLS}.{tool_index}"
|
||||
)
|
||||
|
||||
set_attribute(f"{base_path}.{ToolCallAttributes.TOOL_CALL_FUNCTION_NAME}", function_name)
|
||||
set_attribute(f"{base_path}.{ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON}", function_args)
|
||||
set_attribute(f"{base_path}.{ToolCallAttributes.TOOL_CALL_ID}", call_id)
|
||||
|
||||
# Handle list of messages
|
||||
if isinstance(prompts, list):
|
||||
for i, msg in enumerate(prompts):
|
||||
if isinstance(msg, dict):
|
||||
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.content"] = msg.get("text", "")
|
||||
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.role"] = msg.get("role", "user")
|
||||
# todo: handle assistant and tool role messages, as they don't always
|
||||
# have a text field, but may have a tool_calls field instead
|
||||
# e.g. 'tool_calls': [{'id': '98af3a29-b066-45a5-b4b1-46c74ddafc58',
|
||||
# 'type': 'function', 'function': {'name': 'current_time', 'arguments': '{}'}}]}
|
||||
elif isinstance(prompts, dict):
|
||||
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = json.dumps(prompts)
|
||||
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
|
||||
elif isinstance(prompts, str):
|
||||
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = prompts
|
||||
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
|
||||
for message_index, message in enumerate(prompts):
|
||||
if not isinstance(message, dict):
|
||||
continue
|
||||
|
||||
role = message.get("role", "user")
|
||||
content = message.get("text") or message.get("content") or ""
|
||||
|
||||
set_message_attribute(message_index, MessageAttributes.MESSAGE_ROLE, role)
|
||||
set_message_attribute(message_index, MessageAttributes.MESSAGE_CONTENT, content)
|
||||
|
||||
tool_calls = message.get("tool_calls") or []
|
||||
if isinstance(tool_calls, list):
|
||||
for tool_index, tool_call in enumerate(tool_calls):
|
||||
set_tool_call_attributes(message_index, tool_index, tool_call)
|
||||
|
||||
# Handle single dict or plain string prompt
|
||||
elif isinstance(prompts, (dict, str)):
|
||||
set_message_attribute(0, MessageAttributes.MESSAGE_CONTENT, prompts)
|
||||
set_message_attribute(0, MessageAttributes.MESSAGE_ROLE, "user")
|
||||
|
||||
return attributes
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ from core.trigger.errors import (
|
|||
plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL))
|
||||
_plugin_daemon_timeout_config = cast(
|
||||
float | httpx.Timeout | None,
|
||||
getattr(dify_config, "PLUGIN_DAEMON_TIMEOUT", 300.0),
|
||||
getattr(dify_config, "PLUGIN_DAEMON_TIMEOUT", 600.0),
|
||||
)
|
||||
plugin_daemon_request_timeout: httpx.Timeout | None
|
||||
if _plugin_daemon_timeout_config is None:
|
||||
|
|
|
|||
|
|
@ -90,13 +90,17 @@ class Jieba(BaseKeyword):
|
|||
sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table or {}, query, k)
|
||||
|
||||
documents = []
|
||||
|
||||
segment_query_stmt = db.session.query(DocumentSegment).where(
|
||||
DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id.in_(sorted_chunk_indices)
|
||||
)
|
||||
if document_ids_filter:
|
||||
segment_query_stmt = segment_query_stmt.where(DocumentSegment.document_id.in_(document_ids_filter))
|
||||
|
||||
segments = db.session.execute(segment_query_stmt).scalars().all()
|
||||
segment_map = {segment.index_node_id: segment for segment in segments}
|
||||
for chunk_index in sorted_chunk_indices:
|
||||
segment_query = db.session.query(DocumentSegment).where(
|
||||
DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index
|
||||
)
|
||||
if document_ids_filter:
|
||||
segment_query = segment_query.where(DocumentSegment.document_id.in_(document_ids_filter))
|
||||
segment = segment_query.first()
|
||||
segment = segment_map.get(chunk_index)
|
||||
|
||||
if segment:
|
||||
documents.append(
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from sqlalchemy import select
|
|||
from sqlalchemy.orm import Session, load_only
|
||||
|
||||
from configs import dify_config
|
||||
from core.db.session_factory import session_factory
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
|
||||
|
|
@ -138,37 +139,47 @@ class RetrievalService:
|
|||
|
||||
@classmethod
|
||||
def _deduplicate_documents(cls, documents: list[Document]) -> list[Document]:
|
||||
"""Deduplicate documents based on doc_id to avoid duplicate chunks in hybrid search."""
|
||||
"""Deduplicate documents in O(n) while preserving first-seen order.
|
||||
|
||||
Rules:
|
||||
- For provider == "dify" and metadata["doc_id"] exists: keep the doc with the highest
|
||||
metadata["score"] among duplicates; if a later duplicate has no score, ignore it.
|
||||
- For non-dify documents (or dify without doc_id): deduplicate by content key
|
||||
(provider, page_content), keeping the first occurrence.
|
||||
"""
|
||||
if not documents:
|
||||
return documents
|
||||
|
||||
unique_documents = []
|
||||
seen_doc_ids = set()
|
||||
# Map of dedup key -> chosen Document
|
||||
chosen: dict[tuple, Document] = {}
|
||||
# Preserve the order of first appearance of each dedup key
|
||||
order: list[tuple] = []
|
||||
|
||||
for document in documents:
|
||||
# For dify provider documents, use doc_id for deduplication
|
||||
if document.provider == "dify" and document.metadata is not None and "doc_id" in document.metadata:
|
||||
doc_id = document.metadata["doc_id"]
|
||||
if doc_id not in seen_doc_ids:
|
||||
seen_doc_ids.add(doc_id)
|
||||
unique_documents.append(document)
|
||||
# If duplicate, keep the one with higher score
|
||||
elif "score" in document.metadata:
|
||||
# Find existing document with same doc_id and compare scores
|
||||
for i, existing_doc in enumerate(unique_documents):
|
||||
if (
|
||||
existing_doc.metadata
|
||||
and existing_doc.metadata.get("doc_id") == doc_id
|
||||
and existing_doc.metadata.get("score", 0) < document.metadata.get("score", 0)
|
||||
):
|
||||
unique_documents[i] = document
|
||||
break
|
||||
for doc in documents:
|
||||
is_dify = doc.provider == "dify"
|
||||
doc_id = (doc.metadata or {}).get("doc_id") if is_dify else None
|
||||
|
||||
if is_dify and doc_id:
|
||||
key = ("dify", doc_id)
|
||||
if key not in chosen:
|
||||
chosen[key] = doc
|
||||
order.append(key)
|
||||
else:
|
||||
# Only replace if the new one has a score and it's strictly higher
|
||||
if "score" in doc.metadata:
|
||||
new_score = float(doc.metadata.get("score", 0.0))
|
||||
old_score = float(chosen[key].metadata.get("score", 0.0)) if chosen[key].metadata else 0.0
|
||||
if new_score > old_score:
|
||||
chosen[key] = doc
|
||||
else:
|
||||
# For non-dify documents, use content-based deduplication
|
||||
if document not in unique_documents:
|
||||
unique_documents.append(document)
|
||||
# Content-based dedup for non-dify or dify without doc_id
|
||||
content_key = (doc.provider or "dify", doc.page_content)
|
||||
if content_key not in chosen:
|
||||
chosen[content_key] = doc
|
||||
order.append(content_key)
|
||||
# If duplicate content appears, we keep the first occurrence (no score comparison)
|
||||
|
||||
return unique_documents
|
||||
return [chosen[k] for k in order]
|
||||
|
||||
@classmethod
|
||||
def _get_dataset(cls, dataset_id: str) -> Dataset | None:
|
||||
|
|
@ -371,58 +382,96 @@ class RetrievalService:
|
|||
include_segment_ids = set()
|
||||
segment_child_map = {}
|
||||
segment_file_map = {}
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
# Process documents
|
||||
for document in documents:
|
||||
segment_id = None
|
||||
attachment_info = None
|
||||
child_chunk = None
|
||||
document_id = document.metadata.get("document_id")
|
||||
if document_id not in dataset_documents:
|
||||
continue
|
||||
|
||||
dataset_document = dataset_documents[document_id]
|
||||
if not dataset_document:
|
||||
continue
|
||||
valid_dataset_documents = {}
|
||||
image_doc_ids = []
|
||||
child_index_node_ids = []
|
||||
index_node_ids = []
|
||||
doc_to_document_map = {}
|
||||
for document in documents:
|
||||
document_id = document.metadata.get("document_id")
|
||||
if document_id not in dataset_documents:
|
||||
continue
|
||||
|
||||
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||
# Handle parent-child documents
|
||||
if document.metadata.get("doc_type") == DocType.IMAGE:
|
||||
attachment_info_dict = cls.get_segment_attachment_info(
|
||||
dataset_document.dataset_id,
|
||||
dataset_document.tenant_id,
|
||||
document.metadata.get("doc_id") or "",
|
||||
session,
|
||||
)
|
||||
if attachment_info_dict:
|
||||
attachment_info = attachment_info_dict["attachment_info"]
|
||||
segment_id = attachment_info_dict["segment_id"]
|
||||
else:
|
||||
child_index_node_id = document.metadata.get("doc_id")
|
||||
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id)
|
||||
child_chunk = session.scalar(child_chunk_stmt)
|
||||
dataset_document = dataset_documents[document_id]
|
||||
if not dataset_document:
|
||||
continue
|
||||
valid_dataset_documents[document_id] = dataset_document
|
||||
|
||||
if not child_chunk:
|
||||
continue
|
||||
segment_id = child_chunk.segment_id
|
||||
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||
doc_id = document.metadata.get("doc_id") or ""
|
||||
doc_to_document_map[doc_id] = document
|
||||
if document.metadata.get("doc_type") == DocType.IMAGE:
|
||||
image_doc_ids.append(doc_id)
|
||||
else:
|
||||
child_index_node_ids.append(doc_id)
|
||||
else:
|
||||
doc_id = document.metadata.get("doc_id") or ""
|
||||
doc_to_document_map[doc_id] = document
|
||||
if document.metadata.get("doc_type") == DocType.IMAGE:
|
||||
image_doc_ids.append(doc_id)
|
||||
else:
|
||||
index_node_ids.append(doc_id)
|
||||
|
||||
if not segment_id:
|
||||
continue
|
||||
image_doc_ids = [i for i in image_doc_ids if i]
|
||||
child_index_node_ids = [i for i in child_index_node_ids if i]
|
||||
index_node_ids = [i for i in index_node_ids if i]
|
||||
|
||||
segment = (
|
||||
session.query(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.dataset_id == dataset_document.dataset_id,
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.id == segment_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
segment_ids = []
|
||||
index_node_segments: list[DocumentSegment] = []
|
||||
segments: list[DocumentSegment] = []
|
||||
attachment_map = {}
|
||||
child_chunk_map = {}
|
||||
doc_segment_map = {}
|
||||
|
||||
if not segment:
|
||||
continue
|
||||
with session_factory.create_session() as session:
|
||||
attachments = cls.get_segment_attachment_infos(image_doc_ids, session)
|
||||
|
||||
for attachment in attachments:
|
||||
segment_ids.append(attachment["segment_id"])
|
||||
attachment_map[attachment["segment_id"]] = attachment
|
||||
doc_segment_map[attachment["segment_id"]] = attachment["attachment_id"]
|
||||
|
||||
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id.in_(child_index_node_ids))
|
||||
child_index_nodes = session.execute(child_chunk_stmt).scalars().all()
|
||||
|
||||
for i in child_index_nodes:
|
||||
segment_ids.append(i.segment_id)
|
||||
child_chunk_map[i.segment_id] = i
|
||||
doc_segment_map[i.segment_id] = i.index_node_id
|
||||
|
||||
if index_node_ids:
|
||||
document_segment_stmt = select(DocumentSegment).where(
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.index_node_id.in_(index_node_ids),
|
||||
)
|
||||
index_node_segments = session.execute(document_segment_stmt).scalars().all() # type: ignore
|
||||
for index_node_segment in index_node_segments:
|
||||
doc_segment_map[index_node_segment.id] = index_node_segment.index_node_id
|
||||
if segment_ids:
|
||||
document_segment_stmt = select(DocumentSegment).where(
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.id.in_(segment_ids),
|
||||
)
|
||||
segments = session.execute(document_segment_stmt).scalars().all() # type: ignore
|
||||
|
||||
if index_node_segments:
|
||||
segments.extend(index_node_segments)
|
||||
|
||||
for segment in segments:
|
||||
doc_id = doc_segment_map.get(segment.id)
|
||||
child_chunk = child_chunk_map.get(segment.id)
|
||||
attachment_info = attachment_map.get(segment.id)
|
||||
|
||||
if doc_id:
|
||||
document = doc_to_document_map[doc_id]
|
||||
ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get(
|
||||
document.metadata.get("document_id")
|
||||
)
|
||||
|
||||
if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||
if segment.id not in include_segment_ids:
|
||||
include_segment_ids.add(segment.id)
|
||||
if child_chunk:
|
||||
|
|
@ -430,10 +479,10 @@ class RetrievalService:
|
|||
"id": child_chunk.id,
|
||||
"content": child_chunk.content,
|
||||
"position": child_chunk.position,
|
||||
"score": document.metadata.get("score", 0.0),
|
||||
"score": document.metadata.get("score", 0.0) if document else 0.0,
|
||||
}
|
||||
map_detail = {
|
||||
"max_score": document.metadata.get("score", 0.0),
|
||||
"max_score": document.metadata.get("score", 0.0) if document else 0.0,
|
||||
"child_chunks": [child_chunk_detail],
|
||||
}
|
||||
segment_child_map[segment.id] = map_detail
|
||||
|
|
@ -452,13 +501,14 @@ class RetrievalService:
|
|||
"score": document.metadata.get("score", 0.0),
|
||||
}
|
||||
if segment.id in segment_child_map:
|
||||
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
|
||||
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail) # type: ignore
|
||||
segment_child_map[segment.id]["max_score"] = max(
|
||||
segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
|
||||
segment_child_map[segment.id]["max_score"],
|
||||
document.metadata.get("score", 0.0) if document else 0.0,
|
||||
)
|
||||
else:
|
||||
segment_child_map[segment.id] = {
|
||||
"max_score": document.metadata.get("score", 0.0),
|
||||
"max_score": document.metadata.get("score", 0.0) if document else 0.0,
|
||||
"child_chunks": [child_chunk_detail],
|
||||
}
|
||||
if attachment_info:
|
||||
|
|
@ -467,46 +517,11 @@ class RetrievalService:
|
|||
else:
|
||||
segment_file_map[segment.id] = [attachment_info]
|
||||
else:
|
||||
# Handle normal documents
|
||||
segment = None
|
||||
if document.metadata.get("doc_type") == DocType.IMAGE:
|
||||
attachment_info_dict = cls.get_segment_attachment_info(
|
||||
dataset_document.dataset_id,
|
||||
dataset_document.tenant_id,
|
||||
document.metadata.get("doc_id") or "",
|
||||
session,
|
||||
)
|
||||
if attachment_info_dict:
|
||||
attachment_info = attachment_info_dict["attachment_info"]
|
||||
segment_id = attachment_info_dict["segment_id"]
|
||||
document_segment_stmt = select(DocumentSegment).where(
|
||||
DocumentSegment.dataset_id == dataset_document.dataset_id,
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.id == segment_id,
|
||||
)
|
||||
segment = session.scalar(document_segment_stmt)
|
||||
if segment:
|
||||
segment_file_map[segment.id] = [attachment_info]
|
||||
else:
|
||||
index_node_id = document.metadata.get("doc_id")
|
||||
if not index_node_id:
|
||||
continue
|
||||
document_segment_stmt = select(DocumentSegment).where(
|
||||
DocumentSegment.dataset_id == dataset_document.dataset_id,
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.index_node_id == index_node_id,
|
||||
)
|
||||
segment = session.scalar(document_segment_stmt)
|
||||
|
||||
if not segment:
|
||||
continue
|
||||
if segment.id not in include_segment_ids:
|
||||
include_segment_ids.add(segment.id)
|
||||
record = {
|
||||
"segment": segment,
|
||||
"score": document.metadata.get("score"), # type: ignore
|
||||
"score": document.metadata.get("score", 0.0), # type: ignore
|
||||
}
|
||||
if attachment_info:
|
||||
segment_file_map[segment.id] = [attachment_info]
|
||||
|
|
@ -522,7 +537,7 @@ class RetrievalService:
|
|||
for record in records:
|
||||
if record["segment"].id in segment_child_map:
|
||||
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore
|
||||
record["score"] = segment_child_map[record["segment"].id]["max_score"]
|
||||
record["score"] = segment_child_map[record["segment"].id]["max_score"] # type: ignore
|
||||
if record["segment"].id in segment_file_map:
|
||||
record["files"] = segment_file_map[record["segment"].id] # type: ignore[assignment]
|
||||
|
||||
|
|
@ -565,6 +580,8 @@ class RetrievalService:
|
|||
flask_app: Flask,
|
||||
retrieval_method: RetrievalMethod,
|
||||
dataset: Dataset,
|
||||
all_documents: list[Document],
|
||||
exceptions: list[str],
|
||||
query: str | None = None,
|
||||
top_k: int = 4,
|
||||
score_threshold: float | None = 0.0,
|
||||
|
|
@ -573,8 +590,6 @@ class RetrievalService:
|
|||
weights: dict | None = None,
|
||||
document_ids_filter: list[str] | None = None,
|
||||
attachment_id: str | None = None,
|
||||
all_documents: list[Document] = [],
|
||||
exceptions: list[str] = [],
|
||||
):
|
||||
if not query and not attachment_id:
|
||||
return
|
||||
|
|
@ -696,3 +711,37 @@ class RetrievalService:
|
|||
}
|
||||
return {"attachment_info": attachment_info, "segment_id": attachment_binding.segment_id}
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_segment_attachment_infos(cls, attachment_ids: list[str], session: Session) -> list[dict[str, Any]]:
|
||||
attachment_infos = []
|
||||
upload_files = session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all()
|
||||
if upload_files:
|
||||
upload_file_ids = [upload_file.id for upload_file in upload_files]
|
||||
attachment_bindings = (
|
||||
session.query(SegmentAttachmentBinding)
|
||||
.where(SegmentAttachmentBinding.attachment_id.in_(upload_file_ids))
|
||||
.all()
|
||||
)
|
||||
attachment_binding_map = {binding.attachment_id: binding for binding in attachment_bindings}
|
||||
|
||||
if attachment_bindings:
|
||||
for upload_file in upload_files:
|
||||
attachment_binding = attachment_binding_map.get(upload_file.id)
|
||||
attachment_info = {
|
||||
"id": upload_file.id,
|
||||
"name": upload_file.name,
|
||||
"extension": "." + upload_file.extension,
|
||||
"mime_type": upload_file.mime_type,
|
||||
"source_url": sign_upload_file(upload_file.id, upload_file.extension),
|
||||
"size": upload_file.size,
|
||||
}
|
||||
if attachment_binding:
|
||||
attachment_infos.append(
|
||||
{
|
||||
"attachment_id": attachment_binding.attachment_id,
|
||||
"attachment_info": attachment_info,
|
||||
"segment_id": attachment_binding.segment_id,
|
||||
}
|
||||
)
|
||||
return attachment_infos
|
||||
|
|
|
|||
|
|
@ -0,0 +1,407 @@
|
|||
"""InterSystems IRIS vector database implementation for Dify.
|
||||
|
||||
This module provides vector storage and retrieval using IRIS native VECTOR type
|
||||
with HNSW indexing for efficient similarity search.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from configs import dify_config
|
||||
from configs.middleware.vdb.iris_config import IrisVectorConfig
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.embedding.embedding_base import Embeddings
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import iris
|
||||
else:
|
||||
try:
|
||||
import iris
|
||||
except ImportError:
|
||||
iris = None # type: ignore[assignment]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Singleton connection pool to minimize IRIS license usage
|
||||
_pool_lock = threading.Lock()
|
||||
_pool_instance: IrisConnectionPool | None = None
|
||||
|
||||
|
||||
def get_iris_pool(config: IrisVectorConfig) -> IrisConnectionPool:
|
||||
"""Get or create the global IRIS connection pool (singleton pattern)."""
|
||||
global _pool_instance # pylint: disable=global-statement
|
||||
with _pool_lock:
|
||||
if _pool_instance is None:
|
||||
logger.info("Initializing IRIS connection pool")
|
||||
_pool_instance = IrisConnectionPool(config)
|
||||
return _pool_instance
|
||||
|
||||
|
||||
class IrisConnectionPool:
|
||||
"""Thread-safe connection pool for IRIS database."""
|
||||
|
||||
def __init__(self, config: IrisVectorConfig) -> None:
|
||||
self.config = config
|
||||
self._pool: list[Any] = []
|
||||
self._lock = threading.Lock()
|
||||
self._min_size = config.IRIS_MIN_CONNECTION
|
||||
self._max_size = config.IRIS_MAX_CONNECTION
|
||||
self._in_use = 0
|
||||
self._schemas_initialized: set[str] = set() # Cache for initialized schemas
|
||||
self._initialize_pool()
|
||||
|
||||
def _initialize_pool(self) -> None:
|
||||
for _ in range(self._min_size):
|
||||
self._pool.append(self._create_connection())
|
||||
|
||||
def _create_connection(self) -> Any:
|
||||
return iris.connect(
|
||||
hostname=self.config.IRIS_HOST,
|
||||
port=self.config.IRIS_SUPER_SERVER_PORT,
|
||||
namespace=self.config.IRIS_DATABASE,
|
||||
username=self.config.IRIS_USER,
|
||||
password=self.config.IRIS_PASSWORD,
|
||||
)
|
||||
|
||||
def get_connection(self) -> Any:
|
||||
"""Get a connection from pool or create new if available."""
|
||||
with self._lock:
|
||||
if self._pool:
|
||||
conn = self._pool.pop()
|
||||
self._in_use += 1
|
||||
return conn
|
||||
if self._in_use < self._max_size:
|
||||
conn = self._create_connection()
|
||||
self._in_use += 1
|
||||
return conn
|
||||
raise RuntimeError("Connection pool exhausted")
|
||||
|
||||
def return_connection(self, conn: Any) -> None:
|
||||
"""Return connection to pool after validating it."""
|
||||
if not conn:
|
||||
return
|
||||
|
||||
# Validate connection health
|
||||
is_valid = False
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT 1")
|
||||
cursor.close()
|
||||
is_valid = True
|
||||
except (OSError, RuntimeError) as e:
|
||||
logger.debug("Connection validation failed: %s", e)
|
||||
try:
|
||||
conn.close()
|
||||
except (OSError, RuntimeError):
|
||||
pass
|
||||
|
||||
with self._lock:
|
||||
self._pool.append(conn if is_valid else self._create_connection())
|
||||
self._in_use -= 1
|
||||
|
||||
def ensure_schema_exists(self, schema: str) -> None:
|
||||
"""Ensure schema exists in IRIS database.
|
||||
|
||||
This method is idempotent and thread-safe. It uses a memory cache to avoid
|
||||
redundant database queries for already-verified schemas.
|
||||
|
||||
Args:
|
||||
schema: Schema name to ensure exists
|
||||
|
||||
Raises:
|
||||
Exception: If schema creation fails
|
||||
"""
|
||||
# Fast path: check cache first (no lock needed for read-only set lookup)
|
||||
if schema in self._schemas_initialized:
|
||||
return
|
||||
|
||||
# Slow path: acquire lock and check again (double-checked locking)
|
||||
with self._lock:
|
||||
if schema in self._schemas_initialized:
|
||||
return
|
||||
|
||||
# Get a connection to check/create schema
|
||||
conn = self._pool[0] if self._pool else self._create_connection()
|
||||
cursor = conn.cursor()
|
||||
try:
|
||||
# Check if schema exists using INFORMATION_SCHEMA
|
||||
check_sql = """
|
||||
SELECT COUNT(*) FROM INFORMATION_SCHEMA.SCHEMATA
|
||||
WHERE SCHEMA_NAME = ?
|
||||
"""
|
||||
cursor.execute(check_sql, (schema,)) # Must be tuple or list
|
||||
exists = cursor.fetchone()[0] > 0
|
||||
|
||||
if not exists:
|
||||
# Schema doesn't exist, create it
|
||||
cursor.execute(f"CREATE SCHEMA {schema}")
|
||||
conn.commit()
|
||||
logger.info("Created schema: %s", schema)
|
||||
else:
|
||||
logger.debug("Schema already exists: %s", schema)
|
||||
|
||||
# Add to cache to skip future checks
|
||||
self._schemas_initialized.add(schema)
|
||||
|
||||
except Exception as e:
|
||||
conn.rollback()
|
||||
logger.exception("Failed to ensure schema %s exists", schema)
|
||||
raise
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
def close_all(self) -> None:
|
||||
"""Close all connections (application shutdown only)."""
|
||||
with self._lock:
|
||||
for conn in self._pool:
|
||||
try:
|
||||
conn.close()
|
||||
except (OSError, RuntimeError):
|
||||
pass
|
||||
self._pool.clear()
|
||||
self._in_use = 0
|
||||
self._schemas_initialized.clear()
|
||||
|
||||
|
||||
class IrisVector(BaseVector):
|
||||
"""IRIS vector database implementation using native VECTOR type and HNSW indexing."""
|
||||
|
||||
def __init__(self, collection_name: str, config: IrisVectorConfig) -> None:
|
||||
super().__init__(collection_name)
|
||||
self.config = config
|
||||
self.table_name = f"embedding_{collection_name}".upper()
|
||||
self.schema = config.IRIS_SCHEMA or "dify"
|
||||
self.pool = get_iris_pool(config)
|
||||
|
||||
def get_type(self) -> str:
|
||||
return VectorType.IRIS
|
||||
|
||||
@contextmanager
|
||||
def _get_cursor(self):
|
||||
"""Context manager for database cursor with connection pooling."""
|
||||
conn = self.pool.get_connection()
|
||||
cursor = conn.cursor()
|
||||
try:
|
||||
yield cursor
|
||||
conn.commit()
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
cursor.close()
|
||||
self.pool.return_connection(conn)
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]:
|
||||
dimension = len(embeddings[0])
|
||||
self._create_collection(dimension)
|
||||
return self.add_texts(texts, embeddings)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **_kwargs) -> list[str]:
|
||||
"""Add documents with embeddings to the collection."""
|
||||
added_ids = []
|
||||
with self._get_cursor() as cursor:
|
||||
for i, doc in enumerate(documents):
|
||||
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) if doc.metadata else str(uuid.uuid4())
|
||||
metadata = json.dumps(doc.metadata) if doc.metadata else "{}"
|
||||
embedding_str = json.dumps(embeddings[i])
|
||||
|
||||
sql = f"INSERT INTO {self.schema}.{self.table_name} (id, text, meta, embedding) VALUES (?, ?, ?, ?)"
|
||||
cursor.execute(sql, (doc_id, doc.page_content, metadata, embedding_str))
|
||||
added_ids.append(doc_id)
|
||||
|
||||
return added_ids
|
||||
|
||||
def text_exists(self, id: str) -> bool: # pylint: disable=redefined-builtin
|
||||
try:
|
||||
with self._get_cursor() as cursor:
|
||||
sql = f"SELECT 1 FROM {self.schema}.{self.table_name} WHERE id = ?"
|
||||
cursor.execute(sql, (id,))
|
||||
return cursor.fetchone() is not None
|
||||
except (OSError, RuntimeError, ValueError):
|
||||
return False
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
if not ids:
|
||||
return
|
||||
|
||||
with self._get_cursor() as cursor:
|
||||
placeholders = ",".join(["?" for _ in ids])
|
||||
sql = f"DELETE FROM {self.schema}.{self.table_name} WHERE id IN ({placeholders})"
|
||||
cursor.execute(sql, ids)
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||
"""Delete documents by metadata field (JSON LIKE pattern matching)."""
|
||||
with self._get_cursor() as cursor:
|
||||
pattern = f'%"{key}": "{value}"%'
|
||||
sql = f"DELETE FROM {self.schema}.{self.table_name} WHERE meta LIKE ?"
|
||||
cursor.execute(sql, (pattern,))
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
"""Search similar documents using VECTOR_COSINE with HNSW index."""
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
embedding_str = json.dumps(query_vector)
|
||||
|
||||
with self._get_cursor() as cursor:
|
||||
sql = f"""
|
||||
SELECT TOP {top_k} id, text, meta, VECTOR_COSINE(embedding, ?) as score
|
||||
FROM {self.schema}.{self.table_name}
|
||||
ORDER BY score DESC
|
||||
"""
|
||||
cursor.execute(sql, (embedding_str,))
|
||||
|
||||
docs = []
|
||||
for row in cursor.fetchall():
|
||||
if len(row) >= 4:
|
||||
text, meta_str, score = row[1], row[2], float(row[3])
|
||||
if score >= score_threshold:
|
||||
metadata = json.loads(meta_str) if meta_str else {}
|
||||
metadata["score"] = score
|
||||
docs.append(Document(page_content=text, metadata=metadata))
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
"""Search documents by full-text using iFind index or fallback to LIKE search."""
|
||||
top_k = kwargs.get("top_k", 5)
|
||||
|
||||
with self._get_cursor() as cursor:
|
||||
if self.config.IRIS_TEXT_INDEX:
|
||||
# Use iFind full-text search with index
|
||||
text_index_name = f"idx_{self.table_name}_text"
|
||||
sql = f"""
|
||||
SELECT TOP {top_k} id, text, meta
|
||||
FROM {self.schema}.{self.table_name}
|
||||
WHERE %ID %FIND search_index({text_index_name}, ?)
|
||||
"""
|
||||
cursor.execute(sql, (query,))
|
||||
else:
|
||||
# Fallback to LIKE search (inefficient for large datasets)
|
||||
query_pattern = f"%{query}%"
|
||||
sql = f"""
|
||||
SELECT TOP {top_k} id, text, meta
|
||||
FROM {self.schema}.{self.table_name}
|
||||
WHERE text LIKE ?
|
||||
"""
|
||||
cursor.execute(sql, (query_pattern,))
|
||||
|
||||
docs = []
|
||||
for row in cursor.fetchall():
|
||||
if len(row) >= 3:
|
||||
metadata = json.loads(row[2]) if row[2] else {}
|
||||
docs.append(Document(page_content=row[1], metadata=metadata))
|
||||
|
||||
if not docs:
|
||||
logger.info("Full-text search for '%s' returned no results", query)
|
||||
|
||||
return docs
|
||||
|
||||
def delete(self) -> None:
|
||||
"""Delete the entire collection (drop table - permanent)."""
|
||||
with self._get_cursor() as cursor:
|
||||
sql = f"DROP TABLE {self.schema}.{self.table_name}"
|
||||
cursor.execute(sql)
|
||||
|
||||
def _create_collection(self, dimension: int) -> None:
|
||||
"""Create table with VECTOR column and HNSW index.
|
||||
|
||||
Uses Redis lock to prevent concurrent creation attempts across multiple
|
||||
API server instances (api, worker, worker_beat).
|
||||
"""
|
||||
cache_key = f"vector_indexing_{self._collection_name}"
|
||||
lock_name = f"{cache_key}_lock"
|
||||
|
||||
with redis_client.lock(lock_name, timeout=20): # pylint: disable=not-context-manager
|
||||
if redis_client.get(cache_key):
|
||||
return
|
||||
|
||||
# Ensure schema exists (idempotent, cached after first call)
|
||||
self.pool.ensure_schema_exists(self.schema)
|
||||
|
||||
with self._get_cursor() as cursor:
|
||||
# Create table with VECTOR column
|
||||
sql = f"""
|
||||
CREATE TABLE {self.schema}.{self.table_name} (
|
||||
id VARCHAR(255) PRIMARY KEY,
|
||||
text CLOB,
|
||||
meta CLOB,
|
||||
embedding VECTOR(DOUBLE, {dimension})
|
||||
)
|
||||
"""
|
||||
logger.info("Creating table: %s.%s", self.schema, self.table_name)
|
||||
cursor.execute(sql)
|
||||
|
||||
# Create HNSW index for vector similarity search
|
||||
index_name = f"idx_{self.table_name}_embedding"
|
||||
sql_index = (
|
||||
f"CREATE INDEX {index_name} ON {self.schema}.{self.table_name} "
|
||||
"(embedding) AS HNSW(Distance='Cosine')"
|
||||
)
|
||||
logger.info("Creating HNSW index: %s", index_name)
|
||||
cursor.execute(sql_index)
|
||||
logger.info("HNSW index created successfully: %s", index_name)
|
||||
|
||||
# Create full-text search index if enabled
|
||||
logger.info(
|
||||
"IRIS_TEXT_INDEX config value: %s (type: %s)",
|
||||
self.config.IRIS_TEXT_INDEX,
|
||||
type(self.config.IRIS_TEXT_INDEX),
|
||||
)
|
||||
if self.config.IRIS_TEXT_INDEX:
|
||||
text_index_name = f"idx_{self.table_name}_text"
|
||||
language = self.config.IRIS_TEXT_INDEX_LANGUAGE
|
||||
# Fixed: Removed extra parentheses and corrected syntax
|
||||
sql_text_index = f"""
|
||||
CREATE INDEX {text_index_name} ON {self.schema}.{self.table_name} (text)
|
||||
AS %iFind.Index.Basic
|
||||
(LANGUAGE = '{language}', LOWER = 1, INDEXOPTION = 0)
|
||||
"""
|
||||
logger.info("Creating text index: %s with language: %s", text_index_name, language)
|
||||
logger.info("SQL for text index: %s", sql_text_index)
|
||||
cursor.execute(sql_text_index)
|
||||
logger.info("Text index created successfully: %s", text_index_name)
|
||||
else:
|
||||
logger.warning("Text index creation skipped - IRIS_TEXT_INDEX is disabled")
|
||||
|
||||
redis_client.set(cache_key, 1, ex=3600)
|
||||
|
||||
|
||||
class IrisVectorFactory(AbstractVectorFactory):
|
||||
"""Factory for creating IrisVector instances."""
|
||||
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> IrisVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
index_struct_dict = self.gen_index_struct_dict(VectorType.IRIS, collection_name)
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
|
||||
return IrisVector(
|
||||
collection_name=collection_name,
|
||||
config=IrisVectorConfig(
|
||||
IRIS_HOST=dify_config.IRIS_HOST,
|
||||
IRIS_SUPER_SERVER_PORT=dify_config.IRIS_SUPER_SERVER_PORT,
|
||||
IRIS_USER=dify_config.IRIS_USER,
|
||||
IRIS_PASSWORD=dify_config.IRIS_PASSWORD,
|
||||
IRIS_DATABASE=dify_config.IRIS_DATABASE,
|
||||
IRIS_SCHEMA=dify_config.IRIS_SCHEMA,
|
||||
IRIS_CONNECTION_URL=dify_config.IRIS_CONNECTION_URL,
|
||||
IRIS_MIN_CONNECTION=dify_config.IRIS_MIN_CONNECTION,
|
||||
IRIS_MAX_CONNECTION=dify_config.IRIS_MAX_CONNECTION,
|
||||
IRIS_TEXT_INDEX=dify_config.IRIS_TEXT_INDEX,
|
||||
IRIS_TEXT_INDEX_LANGUAGE=dify_config.IRIS_TEXT_INDEX_LANGUAGE,
|
||||
),
|
||||
)
|
||||
|
|
@ -289,7 +289,8 @@ class OracleVector(BaseVector):
|
|||
words = pseg.cut(query)
|
||||
current_entity = ""
|
||||
for word, pos in words:
|
||||
if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名,ns: 地名,nt: 机构名
|
||||
# `nr`: Person, `ns`: Location, `nt`: Organization
|
||||
if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}:
|
||||
current_entity += word
|
||||
else:
|
||||
if current_entity:
|
||||
|
|
|
|||
|
|
@ -213,7 +213,7 @@ class VastbaseVector(BaseVector):
|
|||
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
|
||||
# Vastbase 支持的向量维度取值范围为 [1,16000]
|
||||
# Vastbase supports vector dimensions in the range [1, 16,000]
|
||||
if dimension <= 16000:
|
||||
cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
|
|
|||
|
|
@ -163,7 +163,7 @@ class Vector:
|
|||
from core.rag.datasource.vdb.lindorm.lindorm_vector import LindormVectorStoreFactory
|
||||
|
||||
return LindormVectorStoreFactory
|
||||
case VectorType.OCEANBASE:
|
||||
case VectorType.OCEANBASE | VectorType.SEEKDB:
|
||||
from core.rag.datasource.vdb.oceanbase.oceanbase_vector import OceanBaseVectorFactory
|
||||
|
||||
return OceanBaseVectorFactory
|
||||
|
|
@ -187,6 +187,10 @@ class Vector:
|
|||
from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaVectorFactory
|
||||
|
||||
return ClickzettaVectorFactory
|
||||
case VectorType.IRIS:
|
||||
from core.rag.datasource.vdb.iris.iris_vector import IrisVectorFactory
|
||||
|
||||
return IrisVectorFactory
|
||||
case _:
|
||||
raise ValueError(f"Vector store {vector_type} is not supported.")
|
||||
|
||||
|
|
|
|||
|
|
@ -27,8 +27,10 @@ class VectorType(StrEnum):
|
|||
UPSTASH = "upstash"
|
||||
TIDB_ON_QDRANT = "tidb_on_qdrant"
|
||||
OCEANBASE = "oceanbase"
|
||||
SEEKDB = "seekdb"
|
||||
OPENGAUSS = "opengauss"
|
||||
TABLESTORE = "tablestore"
|
||||
HUAWEI_CLOUD = "huawei_cloud"
|
||||
MATRIXONE = "matrixone"
|
||||
CLICKZETTA = "clickzetta"
|
||||
IRIS = "iris"
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ class NotionInfo(BaseModel):
|
|||
"""
|
||||
|
||||
credential_id: str | None = None
|
||||
notion_workspace_id: str
|
||||
notion_workspace_id: str | None = ""
|
||||
notion_obj_id: str
|
||||
notion_page_type: str
|
||||
document: Document | None = None
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
"""Abstract interface for document loader implementations."""
|
||||
|
||||
import os
|
||||
from typing import cast
|
||||
from typing import TypedDict
|
||||
|
||||
import pandas as pd
|
||||
from openpyxl import load_workbook
|
||||
|
|
@ -10,6 +10,12 @@ from core.rag.extractor.extractor_base import BaseExtractor
|
|||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
class Candidate(TypedDict):
|
||||
idx: int
|
||||
count: int
|
||||
map: dict[int, str]
|
||||
|
||||
|
||||
class ExcelExtractor(BaseExtractor):
|
||||
"""Load Excel files.
|
||||
|
||||
|
|
@ -30,32 +36,38 @@ class ExcelExtractor(BaseExtractor):
|
|||
file_extension = os.path.splitext(self._file_path)[-1].lower()
|
||||
|
||||
if file_extension == ".xlsx":
|
||||
wb = load_workbook(self._file_path, data_only=True)
|
||||
for sheet_name in wb.sheetnames:
|
||||
sheet = wb[sheet_name]
|
||||
data = sheet.values
|
||||
cols = next(data, None)
|
||||
if cols is None:
|
||||
continue
|
||||
df = pd.DataFrame(data, columns=cols)
|
||||
|
||||
df.dropna(how="all", inplace=True)
|
||||
|
||||
for index, row in df.iterrows():
|
||||
page_content = []
|
||||
for col_index, (k, v) in enumerate(row.items()):
|
||||
if pd.notna(v):
|
||||
cell = sheet.cell(
|
||||
row=cast(int, index) + 2, column=col_index + 1
|
||||
) # +2 to account for header and 1-based index
|
||||
if cell.hyperlink:
|
||||
value = f"[{v}]({cell.hyperlink.target})"
|
||||
page_content.append(f'"{k}":"{value}"')
|
||||
else:
|
||||
page_content.append(f'"{k}":"{v}"')
|
||||
documents.append(
|
||||
Document(page_content=";".join(page_content), metadata={"source": self._file_path})
|
||||
)
|
||||
wb = load_workbook(self._file_path, read_only=True, data_only=True)
|
||||
try:
|
||||
for sheet_name in wb.sheetnames:
|
||||
sheet = wb[sheet_name]
|
||||
header_row_idx, column_map, max_col_idx = self._find_header_and_columns(sheet)
|
||||
if not column_map:
|
||||
continue
|
||||
start_row = header_row_idx + 1
|
||||
for row in sheet.iter_rows(min_row=start_row, max_col=max_col_idx, values_only=False):
|
||||
if all(cell.value is None for cell in row):
|
||||
continue
|
||||
page_content = []
|
||||
for col_idx, cell in enumerate(row):
|
||||
value = cell.value
|
||||
if col_idx in column_map:
|
||||
col_name = column_map[col_idx]
|
||||
if hasattr(cell, "hyperlink") and cell.hyperlink:
|
||||
target = getattr(cell.hyperlink, "target", None)
|
||||
if target:
|
||||
value = f"[{value}]({target})"
|
||||
if value is None:
|
||||
value = ""
|
||||
elif not isinstance(value, str):
|
||||
value = str(value)
|
||||
value = value.strip().replace('"', '\\"')
|
||||
page_content.append(f'"{col_name}":"{value}"')
|
||||
if page_content:
|
||||
documents.append(
|
||||
Document(page_content=";".join(page_content), metadata={"source": self._file_path})
|
||||
)
|
||||
finally:
|
||||
wb.close()
|
||||
|
||||
elif file_extension == ".xls":
|
||||
excel_file = pd.ExcelFile(self._file_path, engine="xlrd")
|
||||
|
|
@ -63,9 +75,9 @@ class ExcelExtractor(BaseExtractor):
|
|||
df = excel_file.parse(sheet_name=excel_sheet_name)
|
||||
df.dropna(how="all", inplace=True)
|
||||
|
||||
for _, row in df.iterrows():
|
||||
for _, series_row in df.iterrows():
|
||||
page_content = []
|
||||
for k, v in row.items():
|
||||
for k, v in series_row.items():
|
||||
if pd.notna(v):
|
||||
page_content.append(f'"{k}":"{v}"')
|
||||
documents.append(
|
||||
|
|
@ -75,3 +87,61 @@ class ExcelExtractor(BaseExtractor):
|
|||
raise ValueError(f"Unsupported file extension: {file_extension}")
|
||||
|
||||
return documents
|
||||
|
||||
def _find_header_and_columns(self, sheet, scan_rows=10) -> tuple[int, dict[int, str], int]:
|
||||
"""
|
||||
Scan first N rows to find the most likely header row.
|
||||
Returns:
|
||||
header_row_idx: 1-based index of the header row
|
||||
column_map: Dict mapping 0-based column index to column name
|
||||
max_col_idx: 1-based index of the last valid column (for iter_rows boundary)
|
||||
"""
|
||||
# Store potential candidates: (row_index, non_empty_count, column_map)
|
||||
candidates: list[Candidate] = []
|
||||
|
||||
# Limit scan to avoid performance issues on huge files
|
||||
# We iterate manually to control the read scope
|
||||
for current_row_idx, row in enumerate(sheet.iter_rows(min_row=1, max_row=scan_rows, values_only=True), start=1):
|
||||
# Filter out empty cells and build a temp map for this row
|
||||
# col_idx is 0-based
|
||||
row_map = {}
|
||||
for col_idx, cell_value in enumerate(row):
|
||||
if cell_value is not None and str(cell_value).strip():
|
||||
row_map[col_idx] = str(cell_value).strip().replace('"', '\\"')
|
||||
|
||||
if not row_map:
|
||||
continue
|
||||
|
||||
non_empty_count = len(row_map)
|
||||
|
||||
# Header selection heuristic (implemented):
|
||||
# - Prefer the first row with at least 2 non-empty columns.
|
||||
# - Fallback: choose the row with the most non-empty columns
|
||||
# (tie-breaker: smaller row index).
|
||||
candidates.append({"idx": current_row_idx, "count": non_empty_count, "map": row_map})
|
||||
|
||||
if not candidates:
|
||||
return 0, {}, 0
|
||||
|
||||
# Choose the best candidate header row.
|
||||
|
||||
best_candidate: Candidate | None = None
|
||||
|
||||
# Strategy: prefer the first row with >= 2 non-empty columns; otherwise fallback.
|
||||
|
||||
for cand in candidates:
|
||||
if cand["count"] >= 2:
|
||||
best_candidate = cand
|
||||
break
|
||||
|
||||
# Fallback: if no row has >= 2 columns, or all have 1, just take the one with max columns
|
||||
if not best_candidate:
|
||||
# Sort by count desc, then index asc
|
||||
candidates.sort(key=lambda x: (-x["count"], x["idx"]))
|
||||
best_candidate = candidates[0]
|
||||
|
||||
# Determine max_col_idx (1-based for openpyxl)
|
||||
# It is the index of the last valid column in our map + 1
|
||||
max_col_idx = max(best_candidate["map"].keys()) + 1
|
||||
|
||||
return best_candidate["idx"], best_candidate["map"], max_col_idx
|
||||
|
|
|
|||
|
|
@ -166,7 +166,7 @@ class ExtractProcessor:
|
|||
elif extract_setting.datasource_type == DatasourceType.NOTION:
|
||||
assert extract_setting.notion_info is not None, "notion_info is required"
|
||||
extractor = NotionExtractor(
|
||||
notion_workspace_id=extract_setting.notion_info.notion_workspace_id,
|
||||
notion_workspace_id=extract_setting.notion_info.notion_workspace_id or "",
|
||||
notion_obj_id=extract_setting.notion_info.notion_obj_id,
|
||||
notion_page_type=extract_setting.notion_info.notion_page_type,
|
||||
document_model=extract_setting.notion_info.document,
|
||||
|
|
|
|||
|
|
@ -45,6 +45,6 @@ def detect_file_encodings(file_path: str, timeout: int = 5, sample_size: int = 1
|
|||
except concurrent.futures.TimeoutError:
|
||||
raise TimeoutError(f"Timeout reached while detecting encoding for {file_path}")
|
||||
|
||||
if all(encoding["encoding"] is None for encoding in encodings):
|
||||
if all(encoding.encoding is None for encoding in encodings):
|
||||
raise RuntimeError(f"Could not detect encoding for {file_path}")
|
||||
return [FileEncoding(**enc) for enc in encodings if enc["encoding"] is not None]
|
||||
return [enc for enc in encodings if enc.encoding is not None]
|
||||
|
|
|
|||
|
|
@ -83,23 +83,46 @@ class WordExtractor(BaseExtractor):
|
|||
def _extract_images_from_docx(self, doc):
|
||||
image_count = 0
|
||||
image_map = {}
|
||||
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
|
||||
|
||||
for rel in doc.part.rels.values():
|
||||
for r_id, rel in doc.part.rels.items():
|
||||
if "image" in rel.target_ref:
|
||||
image_count += 1
|
||||
if rel.is_external:
|
||||
url = rel.target_ref
|
||||
response = ssrf_proxy.get(url)
|
||||
if not self._is_valid_url(url):
|
||||
continue
|
||||
try:
|
||||
response = ssrf_proxy.get(url)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to download image from URL: %s: %s", url, str(e))
|
||||
continue
|
||||
if response.status_code == 200:
|
||||
image_ext = mimetypes.guess_extension(response.headers["Content-Type"])
|
||||
image_ext = mimetypes.guess_extension(response.headers.get("Content-Type", ""))
|
||||
if image_ext is None:
|
||||
continue
|
||||
file_uuid = str(uuid.uuid4())
|
||||
file_key = "image_files/" + self.tenant_id + "/" + file_uuid + "." + image_ext
|
||||
file_key = "image_files/" + self.tenant_id + "/" + file_uuid + image_ext
|
||||
mime_type, _ = mimetypes.guess_type(file_key)
|
||||
storage.save(file_key, response.content)
|
||||
else:
|
||||
continue
|
||||
# save file to db
|
||||
upload_file = UploadFile(
|
||||
tenant_id=self.tenant_id,
|
||||
storage_type=dify_config.STORAGE_TYPE,
|
||||
key=file_key,
|
||||
name=file_key,
|
||||
size=0,
|
||||
extension=str(image_ext),
|
||||
mime_type=mime_type or "",
|
||||
created_by=self.user_id,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_at=naive_utc_now(),
|
||||
used=True,
|
||||
used_by=self.user_id,
|
||||
used_at=naive_utc_now(),
|
||||
)
|
||||
db.session.add(upload_file)
|
||||
image_map[r_id] = f""
|
||||
else:
|
||||
image_ext = rel.target_ref.split(".")[-1]
|
||||
if image_ext is None:
|
||||
|
|
@ -110,27 +133,25 @@ class WordExtractor(BaseExtractor):
|
|||
mime_type, _ = mimetypes.guess_type(file_key)
|
||||
|
||||
storage.save(file_key, rel.target_part.blob)
|
||||
# save file to db
|
||||
upload_file = UploadFile(
|
||||
tenant_id=self.tenant_id,
|
||||
storage_type=dify_config.STORAGE_TYPE,
|
||||
key=file_key,
|
||||
name=file_key,
|
||||
size=0,
|
||||
extension=str(image_ext),
|
||||
mime_type=mime_type or "",
|
||||
created_by=self.user_id,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_at=naive_utc_now(),
|
||||
used=True,
|
||||
used_by=self.user_id,
|
||||
used_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
db.session.add(upload_file)
|
||||
db.session.commit()
|
||||
image_map[rel.target_part] = f""
|
||||
|
||||
# save file to db
|
||||
upload_file = UploadFile(
|
||||
tenant_id=self.tenant_id,
|
||||
storage_type=dify_config.STORAGE_TYPE,
|
||||
key=file_key,
|
||||
name=file_key,
|
||||
size=0,
|
||||
extension=str(image_ext),
|
||||
mime_type=mime_type or "",
|
||||
created_by=self.user_id,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_at=naive_utc_now(),
|
||||
used=True,
|
||||
used_by=self.user_id,
|
||||
used_at=naive_utc_now(),
|
||||
)
|
||||
db.session.add(upload_file)
|
||||
image_map[rel.target_part] = f""
|
||||
db.session.commit()
|
||||
return image_map
|
||||
|
||||
def _table_to_markdown(self, table, image_map):
|
||||
|
|
@ -186,11 +207,17 @@ class WordExtractor(BaseExtractor):
|
|||
image_id = blip.get("{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed")
|
||||
if not image_id:
|
||||
continue
|
||||
image_part = paragraph.part.rels[image_id].target_part
|
||||
|
||||
if image_part in image_map:
|
||||
image_link = image_map[image_part]
|
||||
paragraph_content.append(image_link)
|
||||
rel = paragraph.part.rels.get(image_id)
|
||||
if rel is None:
|
||||
continue
|
||||
# For external images, use image_id as key; for internal, use target_part
|
||||
if rel.is_external:
|
||||
if image_id in image_map:
|
||||
paragraph_content.append(image_map[image_id])
|
||||
else:
|
||||
image_part = rel.target_part
|
||||
if image_part in image_map:
|
||||
paragraph_content.append(image_map[image_part])
|
||||
else:
|
||||
paragraph_content.append(run.text)
|
||||
return "".join(paragraph_content).strip()
|
||||
|
|
@ -227,6 +254,18 @@ class WordExtractor(BaseExtractor):
|
|||
|
||||
def parse_paragraph(paragraph):
|
||||
paragraph_content = []
|
||||
|
||||
def append_image_link(image_id, has_drawing):
|
||||
"""Helper to append image link from image_map based on relationship type."""
|
||||
rel = doc.part.rels[image_id]
|
||||
if rel.is_external:
|
||||
if image_id in image_map and not has_drawing:
|
||||
paragraph_content.append(image_map[image_id])
|
||||
else:
|
||||
image_part = rel.target_part
|
||||
if image_part in image_map and not has_drawing:
|
||||
paragraph_content.append(image_map[image_part])
|
||||
|
||||
for run in paragraph.runs:
|
||||
if hasattr(run.element, "tag") and isinstance(run.element.tag, str) and run.element.tag.endswith("r"):
|
||||
# Process drawing type images
|
||||
|
|
@ -243,10 +282,18 @@ class WordExtractor(BaseExtractor):
|
|||
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed"
|
||||
)
|
||||
if embed_id:
|
||||
image_part = doc.part.related_parts.get(embed_id)
|
||||
if image_part in image_map:
|
||||
has_drawing = True
|
||||
paragraph_content.append(image_map[image_part])
|
||||
rel = doc.part.rels.get(embed_id)
|
||||
if rel is not None and rel.is_external:
|
||||
# External image: use embed_id as key
|
||||
if embed_id in image_map:
|
||||
has_drawing = True
|
||||
paragraph_content.append(image_map[embed_id])
|
||||
else:
|
||||
# Internal image: use target_part as key
|
||||
image_part = doc.part.related_parts.get(embed_id)
|
||||
if image_part in image_map:
|
||||
has_drawing = True
|
||||
paragraph_content.append(image_map[image_part])
|
||||
# Process pict type images
|
||||
shape_elements = run.element.findall(
|
||||
".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}pict"
|
||||
|
|
@ -261,9 +308,7 @@ class WordExtractor(BaseExtractor):
|
|||
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id"
|
||||
)
|
||||
if image_id and image_id in doc.part.rels:
|
||||
image_part = doc.part.rels[image_id].target_part
|
||||
if image_part in image_map and not has_drawing:
|
||||
paragraph_content.append(image_map[image_part])
|
||||
append_image_link(image_id, has_drawing)
|
||||
# Find imagedata element in VML
|
||||
image_data = shape.find(".//{urn:schemas-microsoft-com:vml}imagedata")
|
||||
if image_data is not None:
|
||||
|
|
@ -271,9 +316,7 @@ class WordExtractor(BaseExtractor):
|
|||
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id"
|
||||
)
|
||||
if image_id and image_id in doc.part.rels:
|
||||
image_part = doc.part.rels[image_id].target_part
|
||||
if image_part in image_map and not has_drawing:
|
||||
paragraph_content.append(image_map[image_part])
|
||||
append_image_link(image_id, has_drawing)
|
||||
if run.text.strip():
|
||||
paragraph_content.append(run.text.strip())
|
||||
return "".join(paragraph_content) if paragraph_content else ""
|
||||
|
|
|
|||
|
|
@ -15,3 +15,4 @@ class MetadataDataSource(StrEnum):
|
|||
notion_import = "notion"
|
||||
local_file = "file_upload"
|
||||
online_document = "online_document"
|
||||
online_drive = "online_drive"
|
||||
|
|
|
|||
|
|
@ -231,7 +231,7 @@ class BaseIndexProcessor(ABC):
|
|||
|
||||
if not filename:
|
||||
parsed_url = urlparse(image_url)
|
||||
# unquote 处理 URL 中的中文
|
||||
# Decode percent-encoded characters in the URL path.
|
||||
path = unquote(parsed_url.path)
|
||||
filename = os.path.basename(path)
|
||||
|
||||
|
|
|
|||
|
|
@ -151,20 +151,14 @@ class DatasetRetrieval:
|
|||
if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features:
|
||||
planning_strategy = PlanningStrategy.ROUTER
|
||||
available_datasets = []
|
||||
for dataset_id in dataset_ids:
|
||||
# get dataset from dataset id
|
||||
dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
|
||||
dataset = db.session.scalar(dataset_stmt)
|
||||
|
||||
# pass if dataset is not available
|
||||
if not dataset:
|
||||
dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id.in_(dataset_ids))
|
||||
datasets: list[Dataset] = db.session.execute(dataset_stmt).scalars().all() # type: ignore
|
||||
for dataset in datasets:
|
||||
if dataset.available_document_count == 0 and dataset.provider != "external":
|
||||
continue
|
||||
|
||||
# pass if dataset is not available
|
||||
if dataset and dataset.available_document_count == 0 and dataset.provider != "external":
|
||||
continue
|
||||
|
||||
available_datasets.append(dataset)
|
||||
|
||||
if inputs:
|
||||
inputs = {key: str(value) for key, value in inputs.items()}
|
||||
else:
|
||||
|
|
@ -282,26 +276,35 @@ class DatasetRetrieval:
|
|||
)
|
||||
context_files.append(attachment_info)
|
||||
if show_retrieve_source:
|
||||
dataset_ids = [record.segment.dataset_id for record in records]
|
||||
document_ids = [record.segment.document_id for record in records]
|
||||
dataset_document_stmt = select(DatasetDocument).where(
|
||||
DatasetDocument.id.in_(document_ids),
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
)
|
||||
documents = db.session.execute(dataset_document_stmt).scalars().all() # type: ignore
|
||||
dataset_stmt = select(Dataset).where(
|
||||
Dataset.id.in_(dataset_ids),
|
||||
)
|
||||
datasets = db.session.execute(dataset_stmt).scalars().all() # type: ignore
|
||||
dataset_map = {i.id: i for i in datasets}
|
||||
document_map = {i.id: i for i in documents}
|
||||
for record in records:
|
||||
segment = record.segment
|
||||
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
|
||||
dataset_document_stmt = select(DatasetDocument).where(
|
||||
DatasetDocument.id == segment.document_id,
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
)
|
||||
document = db.session.scalar(dataset_document_stmt)
|
||||
if dataset and document:
|
||||
dataset_item = dataset_map.get(segment.dataset_id)
|
||||
document_item = document_map.get(segment.document_id)
|
||||
if dataset_item and document_item:
|
||||
source = RetrievalSourceMetadata(
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset.name,
|
||||
document_id=document.id,
|
||||
document_name=document.name,
|
||||
data_source_type=document.data_source_type,
|
||||
dataset_id=dataset_item.id,
|
||||
dataset_name=dataset_item.name,
|
||||
document_id=document_item.id,
|
||||
document_name=document_item.name,
|
||||
data_source_type=document_item.data_source_type,
|
||||
segment_id=segment.id,
|
||||
retriever_from=invoke_from.to_source(),
|
||||
score=record.score or 0.0,
|
||||
doc_metadata=document.doc_metadata,
|
||||
doc_metadata=document_item.doc_metadata,
|
||||
)
|
||||
|
||||
if invoke_from.to_source() == "dev":
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import codecs
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -52,7 +53,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
|
|||
def __init__(self, fixed_separator: str = "\n\n", separators: list[str] | None = None, **kwargs: Any):
|
||||
"""Create a new TextSplitter."""
|
||||
super().__init__(**kwargs)
|
||||
self._fixed_separator = fixed_separator
|
||||
self._fixed_separator = codecs.decode(fixed_separator, "unicode_escape")
|
||||
self._separators = separators or ["\n\n", "\n", "。", ". ", " ", ""]
|
||||
|
||||
def split_text(self, text: str) -> list[str]:
|
||||
|
|
@ -94,7 +95,8 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
|
|||
splits = re.split(r" +", text)
|
||||
else:
|
||||
splits = text.split(separator)
|
||||
splits = [item + separator if i < len(splits) else item for i, item in enumerate(splits)]
|
||||
if self._keep_separator:
|
||||
splits = [s + separator for s in splits[:-1]] + splits[-1:]
|
||||
else:
|
||||
splits = list(text)
|
||||
if separator == "\n":
|
||||
|
|
@ -103,7 +105,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
|
|||
splits = [s for s in splits if (s not in {"", "\n"})]
|
||||
_good_splits = []
|
||||
_good_splits_lengths = [] # cache the lengths of the splits
|
||||
_separator = separator if self._keep_separator else ""
|
||||
_separator = "" if self._keep_separator else separator
|
||||
s_lens = self._length_function(splits)
|
||||
if separator != "":
|
||||
for s, s_len in zip(splits, s_lens):
|
||||
|
|
|
|||
|
|
@ -101,6 +101,8 @@ class ToolFileMessageTransformer:
|
|||
meta = message.meta or {}
|
||||
|
||||
mimetype = meta.get("mime_type", "application/octet-stream")
|
||||
if not mimetype:
|
||||
mimetype = "application/octet-stream"
|
||||
# get filename from meta
|
||||
filename = meta.get("filename", None)
|
||||
# if message is str, encode it to bytes
|
||||
|
|
|
|||
|
|
@ -140,6 +140,10 @@ class GraphEngine:
|
|||
pause_handler = PauseCommandHandler()
|
||||
self._command_processor.register_handler(PauseCommand, pause_handler)
|
||||
|
||||
# === Extensibility ===
|
||||
# Layers allow plugins to extend engine functionality
|
||||
self._layers: list[GraphEngineLayer] = []
|
||||
|
||||
# === Worker Pool Setup ===
|
||||
# Capture Flask app context for worker threads
|
||||
flask_app: Flask | None = None
|
||||
|
|
@ -158,6 +162,7 @@ class GraphEngine:
|
|||
ready_queue=self._ready_queue,
|
||||
event_queue=self._event_queue,
|
||||
graph=self._graph,
|
||||
layers=self._layers,
|
||||
flask_app=flask_app,
|
||||
context_vars=context_vars,
|
||||
min_workers=self._min_workers,
|
||||
|
|
@ -196,10 +201,6 @@ class GraphEngine:
|
|||
event_emitter=self._event_manager,
|
||||
)
|
||||
|
||||
# === Extensibility ===
|
||||
# Layers allow plugins to extend engine functionality
|
||||
self._layers: list[GraphEngineLayer] = []
|
||||
|
||||
# === Validation ===
|
||||
# Ensure all nodes share the same GraphRuntimeState instance
|
||||
self._validate_graph_state_consistency()
|
||||
|
|
|
|||
|
|
@ -8,9 +8,11 @@ with middleware-like components that can observe events and interact with execut
|
|||
from .base import GraphEngineLayer
|
||||
from .debug_logging import DebugLoggingLayer
|
||||
from .execution_limits import ExecutionLimitsLayer
|
||||
from .observability import ObservabilityLayer
|
||||
|
||||
__all__ = [
|
||||
"DebugLoggingLayer",
|
||||
"ExecutionLimitsLayer",
|
||||
"GraphEngineLayer",
|
||||
"ObservabilityLayer",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from abc import ABC, abstractmethod
|
|||
|
||||
from core.workflow.graph_engine.protocols.command_channel import CommandChannel
|
||||
from core.workflow.graph_events import GraphEngineEvent
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.runtime import ReadOnlyGraphRuntimeState
|
||||
|
||||
|
||||
|
|
@ -83,3 +84,29 @@ class GraphEngineLayer(ABC):
|
|||
error: The exception that caused execution to fail, or None if successful
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_node_run_start(self, node: Node) -> None: # noqa: B027
|
||||
"""
|
||||
Called immediately before a node begins execution.
|
||||
|
||||
Layers can override to inject behavior (e.g., start spans) prior to node execution.
|
||||
The node's execution ID is available via `node._node_execution_id` and will be
|
||||
consistent with all events emitted by this node execution.
|
||||
|
||||
Args:
|
||||
node: The node instance about to be executed
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_node_run_end(self, node: Node, error: Exception | None) -> None: # noqa: B027
|
||||
"""
|
||||
Called after a node finishes execution.
|
||||
|
||||
The node's execution ID is available via `node._node_execution_id` and matches
|
||||
the `id` field in all events emitted by this node execution.
|
||||
|
||||
Args:
|
||||
node: The node instance that just finished execution
|
||||
error: Exception instance if the node failed, otherwise None
|
||||
"""
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -0,0 +1,61 @@
|
|||
"""
|
||||
Node-level OpenTelemetry parser interfaces and defaults.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Protocol
|
||||
|
||||
from opentelemetry.trace import Span
|
||||
from opentelemetry.trace.status import Status, StatusCode
|
||||
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||
|
||||
|
||||
class NodeOTelParser(Protocol):
|
||||
"""Parser interface for node-specific OpenTelemetry enrichment."""
|
||||
|
||||
def parse(self, *, node: Node, span: "Span", error: Exception | None) -> None: ...
|
||||
|
||||
|
||||
class DefaultNodeOTelParser:
|
||||
"""Fallback parser used when no node-specific parser is registered."""
|
||||
|
||||
def parse(self, *, node: Node, span: "Span", error: Exception | None) -> None:
|
||||
span.set_attribute("node.id", node.id)
|
||||
if node.execution_id:
|
||||
span.set_attribute("node.execution_id", node.execution_id)
|
||||
if hasattr(node, "node_type") and node.node_type:
|
||||
span.set_attribute("node.type", node.node_type.value)
|
||||
|
||||
if error:
|
||||
span.record_exception(error)
|
||||
span.set_status(Status(StatusCode.ERROR, str(error)))
|
||||
else:
|
||||
span.set_status(Status(StatusCode.OK))
|
||||
|
||||
|
||||
class ToolNodeOTelParser:
|
||||
"""Parser for tool nodes that captures tool-specific metadata."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._delegate = DefaultNodeOTelParser()
|
||||
|
||||
def parse(self, *, node: Node, span: "Span", error: Exception | None) -> None:
|
||||
self._delegate.parse(node=node, span=span, error=error)
|
||||
|
||||
tool_data = getattr(node, "_node_data", None)
|
||||
if not isinstance(tool_data, ToolNodeData):
|
||||
return
|
||||
|
||||
span.set_attribute("tool.provider.id", tool_data.provider_id)
|
||||
span.set_attribute("tool.provider.type", tool_data.provider_type.value)
|
||||
span.set_attribute("tool.provider.name", tool_data.provider_name)
|
||||
span.set_attribute("tool.name", tool_data.tool_name)
|
||||
span.set_attribute("tool.label", tool_data.tool_label)
|
||||
if tool_data.plugin_unique_identifier:
|
||||
span.set_attribute("tool.plugin.id", tool_data.plugin_unique_identifier)
|
||||
if tool_data.credential_id:
|
||||
span.set_attribute("tool.credential.id", tool_data.credential_id)
|
||||
if tool_data.tool_configurations:
|
||||
span.set_attribute("tool.config", json.dumps(tool_data.tool_configurations, ensure_ascii=False))
|
||||
|
|
@ -0,0 +1,169 @@
|
|||
"""
|
||||
Observability layer for GraphEngine.
|
||||
|
||||
This layer creates OpenTelemetry spans for node execution, enabling distributed
|
||||
tracing of workflow execution. It establishes OTel context during node execution
|
||||
so that automatic instrumentation (HTTP requests, DB queries, etc.) automatically
|
||||
associates with the node span.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import cast, final
|
||||
|
||||
from opentelemetry import context as context_api
|
||||
from opentelemetry.trace import Span, SpanKind, Tracer, get_tracer, set_span_in_context
|
||||
from typing_extensions import override
|
||||
|
||||
from configs import dify_config
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_engine.layers.node_parsers import (
|
||||
DefaultNodeOTelParser,
|
||||
NodeOTelParser,
|
||||
ToolNodeOTelParser,
|
||||
)
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from extensions.otel.runtime import is_instrument_flag_enabled
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _NodeSpanContext:
|
||||
span: "Span"
|
||||
token: object
|
||||
|
||||
|
||||
@final
|
||||
class ObservabilityLayer(GraphEngineLayer):
|
||||
"""
|
||||
Layer that creates OpenTelemetry spans for node execution.
|
||||
|
||||
This layer:
|
||||
- Creates a span when a node starts execution
|
||||
- Establishes OTel context so automatic instrumentation associates with the span
|
||||
- Sets complete attributes and status when node execution ends
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._node_contexts: dict[str, _NodeSpanContext] = {}
|
||||
self._parsers: dict[NodeType, NodeOTelParser] = {}
|
||||
self._default_parser: NodeOTelParser = cast(NodeOTelParser, DefaultNodeOTelParser())
|
||||
self._is_disabled: bool = False
|
||||
self._tracer: Tracer | None = None
|
||||
self._build_parser_registry()
|
||||
self._init_tracer()
|
||||
|
||||
def _init_tracer(self) -> None:
|
||||
"""Initialize OpenTelemetry tracer in constructor."""
|
||||
if not (dify_config.ENABLE_OTEL or is_instrument_flag_enabled()):
|
||||
self._is_disabled = True
|
||||
return
|
||||
|
||||
try:
|
||||
self._tracer = get_tracer(__name__)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to get OpenTelemetry tracer: %s", e)
|
||||
self._is_disabled = True
|
||||
|
||||
def _build_parser_registry(self) -> None:
|
||||
"""Initialize parser registry for node types."""
|
||||
self._parsers = {
|
||||
NodeType.TOOL: ToolNodeOTelParser(),
|
||||
}
|
||||
|
||||
def _get_parser(self, node: Node) -> NodeOTelParser:
|
||||
node_type = getattr(node, "node_type", None)
|
||||
if isinstance(node_type, NodeType):
|
||||
return self._parsers.get(node_type, self._default_parser)
|
||||
return self._default_parser
|
||||
|
||||
@override
|
||||
def on_graph_start(self) -> None:
|
||||
"""Called when graph execution starts."""
|
||||
self._node_contexts.clear()
|
||||
|
||||
@override
|
||||
def on_node_run_start(self, node: Node) -> None:
|
||||
"""
|
||||
Called when a node starts execution.
|
||||
|
||||
Creates a span and establishes OTel context for automatic instrumentation.
|
||||
"""
|
||||
if self._is_disabled:
|
||||
return
|
||||
|
||||
try:
|
||||
if not self._tracer:
|
||||
return
|
||||
|
||||
execution_id = node.execution_id
|
||||
if not execution_id:
|
||||
return
|
||||
|
||||
parent_context = context_api.get_current()
|
||||
span = self._tracer.start_span(
|
||||
f"{node.title}",
|
||||
kind=SpanKind.INTERNAL,
|
||||
context=parent_context,
|
||||
)
|
||||
|
||||
new_context = set_span_in_context(span)
|
||||
token = context_api.attach(new_context)
|
||||
|
||||
self._node_contexts[execution_id] = _NodeSpanContext(span=span, token=token)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Failed to create OpenTelemetry span for node %s: %s", node.id, e)
|
||||
|
||||
@override
|
||||
def on_node_run_end(self, node: Node, error: Exception | None) -> None:
|
||||
"""
|
||||
Called when a node finishes execution.
|
||||
|
||||
Sets complete attributes, records exceptions, and ends the span.
|
||||
"""
|
||||
if self._is_disabled:
|
||||
return
|
||||
|
||||
try:
|
||||
execution_id = node.execution_id
|
||||
if not execution_id:
|
||||
return
|
||||
node_context = self._node_contexts.get(execution_id)
|
||||
if not node_context:
|
||||
return
|
||||
|
||||
span = node_context.span
|
||||
parser = self._get_parser(node)
|
||||
try:
|
||||
parser.parse(node=node, span=span, error=error)
|
||||
span.end()
|
||||
finally:
|
||||
token = node_context.token
|
||||
if token is not None:
|
||||
try:
|
||||
context_api.detach(token)
|
||||
except Exception:
|
||||
logger.warning("Failed to detach OpenTelemetry token: %s", token)
|
||||
self._node_contexts.pop(execution_id, None)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Failed to end OpenTelemetry span for node %s: %s", node.id, e)
|
||||
|
||||
@override
|
||||
def on_event(self, event) -> None:
|
||||
"""Not used in this layer."""
|
||||
pass
|
||||
|
||||
@override
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
"""Called when graph execution ends."""
|
||||
if self._node_contexts:
|
||||
logger.warning(
|
||||
"ObservabilityLayer: %d node spans were not properly ended",
|
||||
len(self._node_contexts),
|
||||
)
|
||||
self._node_contexts.clear()
|
||||
|
|
@ -9,6 +9,7 @@ import contextvars
|
|||
import queue
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import final
|
||||
from uuid import uuid4
|
||||
|
|
@ -17,6 +18,7 @@ from flask import Flask
|
|||
from typing_extensions import override
|
||||
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
|
|
@ -39,6 +41,7 @@ class Worker(threading.Thread):
|
|||
ready_queue: ReadyQueue,
|
||||
event_queue: queue.Queue[GraphNodeEventBase],
|
||||
graph: Graph,
|
||||
layers: Sequence[GraphEngineLayer],
|
||||
worker_id: int = 0,
|
||||
flask_app: Flask | None = None,
|
||||
context_vars: contextvars.Context | None = None,
|
||||
|
|
@ -50,6 +53,7 @@ class Worker(threading.Thread):
|
|||
ready_queue: Ready queue containing node IDs ready for execution
|
||||
event_queue: Queue for pushing execution events
|
||||
graph: Graph containing nodes to execute
|
||||
layers: Graph engine layers for node execution hooks
|
||||
worker_id: Unique identifier for this worker
|
||||
flask_app: Optional Flask application for context preservation
|
||||
context_vars: Optional context variables to preserve in worker thread
|
||||
|
|
@ -63,6 +67,7 @@ class Worker(threading.Thread):
|
|||
self._context_vars = context_vars
|
||||
self._stop_event = threading.Event()
|
||||
self._last_task_time = time.time()
|
||||
self._layers = layers if layers is not None else []
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Signal the worker to stop processing."""
|
||||
|
|
@ -122,20 +127,51 @@ class Worker(threading.Thread):
|
|||
Args:
|
||||
node: The node instance to execute
|
||||
"""
|
||||
# Execute the node with preserved context if Flask app is provided
|
||||
node.ensure_execution_id()
|
||||
|
||||
error: Exception | None = None
|
||||
|
||||
if self._flask_app and self._context_vars:
|
||||
with preserve_flask_contexts(
|
||||
flask_app=self._flask_app,
|
||||
context_vars=self._context_vars,
|
||||
):
|
||||
# Execute the node
|
||||
self._invoke_node_run_start_hooks(node)
|
||||
try:
|
||||
node_events = node.run()
|
||||
for event in node_events:
|
||||
self._event_queue.put(event)
|
||||
except Exception as exc:
|
||||
error = exc
|
||||
raise
|
||||
finally:
|
||||
self._invoke_node_run_end_hooks(node, error)
|
||||
else:
|
||||
self._invoke_node_run_start_hooks(node)
|
||||
try:
|
||||
node_events = node.run()
|
||||
for event in node_events:
|
||||
# Forward event to dispatcher immediately for streaming
|
||||
self._event_queue.put(event)
|
||||
else:
|
||||
# Execute without context preservation
|
||||
node_events = node.run()
|
||||
for event in node_events:
|
||||
# Forward event to dispatcher immediately for streaming
|
||||
self._event_queue.put(event)
|
||||
except Exception as exc:
|
||||
error = exc
|
||||
raise
|
||||
finally:
|
||||
self._invoke_node_run_end_hooks(node, error)
|
||||
|
||||
def _invoke_node_run_start_hooks(self, node: Node) -> None:
|
||||
"""Invoke on_node_run_start hooks for all layers."""
|
||||
for layer in self._layers:
|
||||
try:
|
||||
layer.on_node_run_start(node)
|
||||
except Exception:
|
||||
# Silently ignore layer errors to prevent disrupting node execution
|
||||
continue
|
||||
|
||||
def _invoke_node_run_end_hooks(self, node: Node, error: Exception | None) -> None:
|
||||
"""Invoke on_node_run_end hooks for all layers."""
|
||||
for layer in self._layers:
|
||||
try:
|
||||
layer.on_node_run_end(node, error)
|
||||
except Exception:
|
||||
# Silently ignore layer errors to prevent disrupting node execution
|
||||
continue
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from configs import dify_config
|
|||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import GraphNodeEventBase
|
||||
|
||||
from ..layers.base import GraphEngineLayer
|
||||
from ..ready_queue import ReadyQueue
|
||||
from ..worker import Worker
|
||||
|
||||
|
|
@ -39,6 +40,7 @@ class WorkerPool:
|
|||
ready_queue: ReadyQueue,
|
||||
event_queue: queue.Queue[GraphNodeEventBase],
|
||||
graph: Graph,
|
||||
layers: list[GraphEngineLayer],
|
||||
flask_app: "Flask | None" = None,
|
||||
context_vars: "Context | None" = None,
|
||||
min_workers: int | None = None,
|
||||
|
|
@ -53,6 +55,7 @@ class WorkerPool:
|
|||
ready_queue: Ready queue for nodes ready for execution
|
||||
event_queue: Queue for worker events
|
||||
graph: The workflow graph
|
||||
layers: Graph engine layers for node execution hooks
|
||||
flask_app: Optional Flask app for context preservation
|
||||
context_vars: Optional context variables
|
||||
min_workers: Minimum number of workers
|
||||
|
|
@ -65,6 +68,7 @@ class WorkerPool:
|
|||
self._graph = graph
|
||||
self._flask_app = flask_app
|
||||
self._context_vars = context_vars
|
||||
self._layers = layers
|
||||
|
||||
# Scaling parameters with defaults
|
||||
self._min_workers = min_workers or dify_config.GRAPH_ENGINE_MIN_WORKERS
|
||||
|
|
@ -144,6 +148,7 @@ class WorkerPool:
|
|||
ready_queue=self._ready_queue,
|
||||
event_queue=self._event_queue,
|
||||
graph=self._graph,
|
||||
layers=self._layers,
|
||||
worker_id=worker_id,
|
||||
flask_app=self._flask_app,
|
||||
context_vars=self._context_vars,
|
||||
|
|
|
|||
|
|
@ -244,6 +244,15 @@ class Node(Generic[NodeDataT]):
|
|||
def graph_init_params(self) -> "GraphInitParams":
|
||||
return self._graph_init_params
|
||||
|
||||
@property
|
||||
def execution_id(self) -> str:
|
||||
return self._node_execution_id
|
||||
|
||||
def ensure_execution_id(self) -> str:
|
||||
if not self._node_execution_id:
|
||||
self._node_execution_id = str(uuid4())
|
||||
return self._node_execution_id
|
||||
|
||||
def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT:
|
||||
return cast(NodeDataT, self._node_data_type.model_validate(data))
|
||||
|
||||
|
|
@ -256,14 +265,12 @@ class Node(Generic[NodeDataT]):
|
|||
raise NotImplementedError
|
||||
|
||||
def run(self) -> Generator[GraphNodeEventBase, None, None]:
|
||||
# Generate a single node execution ID to use for all events
|
||||
if not self._node_execution_id:
|
||||
self._node_execution_id = str(uuid4())
|
||||
execution_id = self.ensure_execution_id()
|
||||
self._start_at = naive_utc_now()
|
||||
|
||||
# Create and push start event with required fields
|
||||
start_event = NodeRunStartedEvent(
|
||||
id=self._node_execution_id,
|
||||
id=execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.title,
|
||||
|
|
@ -321,7 +328,7 @@ class Node(Generic[NodeDataT]):
|
|||
if isinstance(event, NodeEventBase): # pyright: ignore[reportUnnecessaryIsInstance]
|
||||
yield self._dispatch(event)
|
||||
elif isinstance(event, GraphNodeEventBase) and not event.in_iteration_id and not event.in_loop_id: # pyright: ignore[reportUnnecessaryIsInstance]
|
||||
event.id = self._node_execution_id
|
||||
event.id = self.execution_id
|
||||
yield event
|
||||
else:
|
||||
yield event
|
||||
|
|
@ -333,7 +340,7 @@ class Node(Generic[NodeDataT]):
|
|||
error_type="WorkflowNodeError",
|
||||
)
|
||||
yield NodeRunFailedEvent(
|
||||
id=self._node_execution_id,
|
||||
id=self.execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
|
|
@ -512,7 +519,7 @@ class Node(Generic[NodeDataT]):
|
|||
match result.status:
|
||||
case WorkflowNodeExecutionStatus.FAILED:
|
||||
return NodeRunFailedEvent(
|
||||
id=self._node_execution_id,
|
||||
id=self.execution_id,
|
||||
node_id=self.id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
|
|
@ -521,7 +528,7 @@ class Node(Generic[NodeDataT]):
|
|||
)
|
||||
case WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
return NodeRunSucceededEvent(
|
||||
id=self._node_execution_id,
|
||||
id=self.execution_id,
|
||||
node_id=self.id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
|
|
@ -537,7 +544,7 @@ class Node(Generic[NodeDataT]):
|
|||
@_dispatch.register
|
||||
def _(self, event: StreamChunkEvent) -> NodeRunStreamChunkEvent:
|
||||
return NodeRunStreamChunkEvent(
|
||||
id=self._node_execution_id,
|
||||
id=self.execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
selector=event.selector,
|
||||
|
|
@ -550,7 +557,7 @@ class Node(Generic[NodeDataT]):
|
|||
match event.node_run_result.status:
|
||||
case WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
return NodeRunSucceededEvent(
|
||||
id=self._node_execution_id,
|
||||
id=self.execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
|
|
@ -558,7 +565,7 @@ class Node(Generic[NodeDataT]):
|
|||
)
|
||||
case WorkflowNodeExecutionStatus.FAILED:
|
||||
return NodeRunFailedEvent(
|
||||
id=self._node_execution_id,
|
||||
id=self.execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
|
|
@ -573,7 +580,7 @@ class Node(Generic[NodeDataT]):
|
|||
@_dispatch.register
|
||||
def _(self, event: PauseRequestedEvent) -> NodeRunPauseRequestedEvent:
|
||||
return NodeRunPauseRequestedEvent(
|
||||
id=self._node_execution_id,
|
||||
id=self.execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.PAUSED),
|
||||
|
|
@ -583,7 +590,7 @@ class Node(Generic[NodeDataT]):
|
|||
@_dispatch.register
|
||||
def _(self, event: AgentLogEvent) -> NodeRunAgentLogEvent:
|
||||
return NodeRunAgentLogEvent(
|
||||
id=self._node_execution_id,
|
||||
id=self.execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
message_id=event.message_id,
|
||||
|
|
@ -599,7 +606,7 @@ class Node(Generic[NodeDataT]):
|
|||
@_dispatch.register
|
||||
def _(self, event: LoopStartedEvent) -> NodeRunLoopStartedEvent:
|
||||
return NodeRunLoopStartedEvent(
|
||||
id=self._node_execution_id,
|
||||
id=self.execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.node_data.title,
|
||||
|
|
@ -612,7 +619,7 @@ class Node(Generic[NodeDataT]):
|
|||
@_dispatch.register
|
||||
def _(self, event: LoopNextEvent) -> NodeRunLoopNextEvent:
|
||||
return NodeRunLoopNextEvent(
|
||||
id=self._node_execution_id,
|
||||
id=self.execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.node_data.title,
|
||||
|
|
@ -623,7 +630,7 @@ class Node(Generic[NodeDataT]):
|
|||
@_dispatch.register
|
||||
def _(self, event: LoopSucceededEvent) -> NodeRunLoopSucceededEvent:
|
||||
return NodeRunLoopSucceededEvent(
|
||||
id=self._node_execution_id,
|
||||
id=self.execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.node_data.title,
|
||||
|
|
@ -637,7 +644,7 @@ class Node(Generic[NodeDataT]):
|
|||
@_dispatch.register
|
||||
def _(self, event: LoopFailedEvent) -> NodeRunLoopFailedEvent:
|
||||
return NodeRunLoopFailedEvent(
|
||||
id=self._node_execution_id,
|
||||
id=self.execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.node_data.title,
|
||||
|
|
@ -652,7 +659,7 @@ class Node(Generic[NodeDataT]):
|
|||
@_dispatch.register
|
||||
def _(self, event: IterationStartedEvent) -> NodeRunIterationStartedEvent:
|
||||
return NodeRunIterationStartedEvent(
|
||||
id=self._node_execution_id,
|
||||
id=self.execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.node_data.title,
|
||||
|
|
@ -665,7 +672,7 @@ class Node(Generic[NodeDataT]):
|
|||
@_dispatch.register
|
||||
def _(self, event: IterationNextEvent) -> NodeRunIterationNextEvent:
|
||||
return NodeRunIterationNextEvent(
|
||||
id=self._node_execution_id,
|
||||
id=self.execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.node_data.title,
|
||||
|
|
@ -676,7 +683,7 @@ class Node(Generic[NodeDataT]):
|
|||
@_dispatch.register
|
||||
def _(self, event: IterationSucceededEvent) -> NodeRunIterationSucceededEvent:
|
||||
return NodeRunIterationSucceededEvent(
|
||||
id=self._node_execution_id,
|
||||
id=self.execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.node_data.title,
|
||||
|
|
@ -690,7 +697,7 @@ class Node(Generic[NodeDataT]):
|
|||
@_dispatch.register
|
||||
def _(self, event: IterationFailedEvent) -> NodeRunIterationFailedEvent:
|
||||
return NodeRunIterationFailedEvent(
|
||||
id=self._node_execution_id,
|
||||
id=self.execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.node_data.title,
|
||||
|
|
@ -705,7 +712,7 @@ class Node(Generic[NodeDataT]):
|
|||
@_dispatch.register
|
||||
def _(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent:
|
||||
return NodeRunRetrieverResourceEvent(
|
||||
id=self._node_execution_id,
|
||||
id=self.execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
retriever_resources=event.retriever_resources,
|
||||
|
|
|
|||
|
|
@ -86,6 +86,11 @@ class Executor:
|
|||
node_data.authorization.config.api_key = variable_pool.convert_template(
|
||||
node_data.authorization.config.api_key
|
||||
).text
|
||||
# Validate that API key is not empty after template conversion
|
||||
if not node_data.authorization.config.api_key or not node_data.authorization.config.api_key.strip():
|
||||
raise AuthorizationConfigError(
|
||||
"API key is required for authorization but was empty. Please provide a valid API key."
|
||||
)
|
||||
|
||||
self.url = node_data.url
|
||||
self.method = node_data.method
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import json
|
||||
from typing import Any
|
||||
|
||||
from jsonschema import Draft7Validator, ValidationError
|
||||
|
|
@ -42,15 +43,25 @@ class StartNode(Node[StartNodeData]):
|
|||
if value is None and variable.required:
|
||||
raise ValueError(f"{key} is required in input form")
|
||||
|
||||
if not isinstance(value, dict):
|
||||
raise ValueError(f"{key} must be a JSON object")
|
||||
|
||||
schema = variable.json_schema
|
||||
if not schema:
|
||||
continue
|
||||
|
||||
if not value:
|
||||
continue
|
||||
|
||||
try:
|
||||
Draft7Validator(schema).validate(value)
|
||||
json_schema = json.loads(schema)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"{schema} must be a valid JSON object")
|
||||
|
||||
try:
|
||||
json_value = json.loads(value)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"{value} must be a valid JSON object")
|
||||
|
||||
try:
|
||||
Draft7Validator(json_schema).validate(json_value)
|
||||
except ValidationError as e:
|
||||
raise ValueError(f"JSON object for '{key}' does not match schema: {e.message}")
|
||||
node_inputs[key] = value
|
||||
node_inputs[key] = json_value
|
||||
|
|
|
|||
|
|
@ -1,14 +1,22 @@
|
|||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.file import FileTransferMethod
|
||||
from core.variables.types import SegmentType
|
||||
from core.variables.variables import FileVariable
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import NodeExecutionType, NodeType
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from factories import file_factory
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
|
||||
from .entities import ContentType, WebhookData
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TriggerWebhookNode(Node[WebhookData]):
|
||||
node_type = NodeType.TRIGGER_WEBHOOK
|
||||
|
|
@ -60,6 +68,34 @@ class TriggerWebhookNode(Node[WebhookData]):
|
|||
outputs=outputs,
|
||||
)
|
||||
|
||||
def generate_file_var(self, param_name: str, file: dict):
|
||||
related_id = file.get("related_id")
|
||||
transfer_method_value = file.get("transfer_method")
|
||||
if transfer_method_value:
|
||||
transfer_method = FileTransferMethod.value_of(transfer_method_value)
|
||||
match transfer_method:
|
||||
case FileTransferMethod.LOCAL_FILE | FileTransferMethod.REMOTE_URL:
|
||||
file["upload_file_id"] = related_id
|
||||
case FileTransferMethod.TOOL_FILE:
|
||||
file["tool_file_id"] = related_id
|
||||
case FileTransferMethod.DATASOURCE_FILE:
|
||||
file["datasource_file_id"] = related_id
|
||||
|
||||
try:
|
||||
file_obj = file_factory.build_from_mapping(
|
||||
mapping=file,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
file_segment = build_segment_with_type(SegmentType.FILE, file_obj)
|
||||
return FileVariable(name=param_name, value=file_segment.value, selector=[self.id, param_name])
|
||||
except ValueError:
|
||||
logger.error(
|
||||
"Failed to build FileVariable for webhook file parameter %s",
|
||||
param_name,
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
|
||||
def _extract_configured_outputs(self, webhook_inputs: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Extract outputs based on node configuration from webhook inputs."""
|
||||
outputs = {}
|
||||
|
|
@ -107,18 +143,33 @@ class TriggerWebhookNode(Node[WebhookData]):
|
|||
outputs[param_name] = str(webhook_data.get("body", {}).get("raw", ""))
|
||||
continue
|
||||
elif self.node_data.content_type == ContentType.BINARY:
|
||||
outputs[param_name] = webhook_data.get("body", {}).get("raw", b"")
|
||||
raw_data: dict = webhook_data.get("body", {}).get("raw", {})
|
||||
file_var = self.generate_file_var(param_name, raw_data)
|
||||
if file_var:
|
||||
outputs[param_name] = file_var
|
||||
else:
|
||||
outputs[param_name] = raw_data
|
||||
continue
|
||||
|
||||
if param_type == "file":
|
||||
# Get File object (already processed by webhook controller)
|
||||
file_obj = webhook_data.get("files", {}).get(param_name)
|
||||
outputs[param_name] = file_obj
|
||||
files = webhook_data.get("files", {})
|
||||
if files and isinstance(files, dict):
|
||||
file = files.get(param_name)
|
||||
if file and isinstance(file, dict):
|
||||
file_var = self.generate_file_var(param_name, file)
|
||||
if file_var:
|
||||
outputs[param_name] = file_var
|
||||
else:
|
||||
outputs[param_name] = files
|
||||
else:
|
||||
outputs[param_name] = files
|
||||
else:
|
||||
outputs[param_name] = files
|
||||
else:
|
||||
# Get regular body parameter
|
||||
outputs[param_name] = webhook_data.get("body", {}).get(param_name)
|
||||
|
||||
# Include raw webhook data for debugging/advanced use
|
||||
outputs["_webhook_raw"] = webhook_data
|
||||
|
||||
return outputs
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue