mirror of https://github.com/langgenius/dify.git

Merge branch 'main' into feat/grouping-branching

commit 035f51ad58
@@ -1,13 +1,13 @@
 ---
-name: Dify Frontend Testing
-description: Generate Jest + React Testing Library tests for Dify frontend components, hooks, and utilities. Triggers on testing, spec files, coverage, Jest, RTL, unit tests, integration tests, or write/review test requests.
+name: frontend-testing
+description: Generate Vitest + React Testing Library tests for Dify frontend components, hooks, and utilities. Triggers on testing, spec files, coverage, Vitest, RTL, unit tests, integration tests, or write/review test requests.
 ---
 
 # Dify Frontend Testing Skill
 
 This skill enables Claude to generate high-quality, comprehensive frontend tests for the Dify project following established conventions and best practices.
 
-> **⚠️ Authoritative Source**: This skill is derived from `web/testing/testing.md`. When in doubt, always refer to that document as the canonical specification.
+> **⚠️ Authoritative Source**: This skill is derived from `web/testing/testing.md`. Use Vitest mock/timer APIs (`vi.*`).
 
 ## When to Apply This Skill
@@ -15,7 +15,7 @@ Apply this skill when the user:
 - Asks to **write tests** for a component, hook, or utility
 - Asks to **review existing tests** for completeness
-- Mentions **Jest**, **React Testing Library**, **RTL**, or **spec files**
+- Mentions **Vitest**, **React Testing Library**, **RTL**, or **spec files**
 - Requests **test coverage** improvement
 - Uses `pnpm analyze-component` output as context
 - Mentions **testing**, **unit tests**, or **integration tests** for frontend code
@@ -33,9 +33,9 @@ Apply this skill when the user:
 | Tool | Version | Purpose |
 |------|---------|---------|
-| Jest | 29.7 | Test runner |
+| Vitest | 4.0.16 | Test runner |
 | React Testing Library | 16.0 | Component testing |
-| happy-dom | - | Test environment |
+| jsdom | - | Test environment |
 | nock | 14.0 | HTTP mocking |
 | TypeScript | 5.x | Type safety |
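For orientation alongside this stack table, a minimal Vitest config for a jsdom + RTL setup might look like the sketch below. This is an assumption-laden illustration, not Dify's actual file — the canonical settings live in `web/vitest.config.ts`.

```typescript
// Hypothetical sketch only — see web/vitest.config.ts for the real configuration.
import { defineConfig } from 'vitest/config'
import react from '@vitejs/plugin-react'

export default defineConfig({
  plugins: [react()],
  test: {
    environment: 'jsdom', // matches the stack table above
    globals: true, // describe/it/expect available without imports
    setupFiles: ['./vitest.setup.ts'], // global mocks (react-i18next, next/image) load here
    coverage: {
      reporter: ['text', 'json', 'json-summary'], // json-summary feeds the CI coverage step
    },
  },
})
```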
@@ -46,7 +46,7 @@ Apply this skill when the user:
 pnpm test
 
 # Watch mode
-pnpm test -- --watch
+pnpm test:watch
 
 # Run specific file
 pnpm test -- path/to/file.spec.tsx
@@ -77,9 +77,9 @@ import Component from './index'
 // import { ChildComponent } from './child-component'
 
 // ✅ Mock external dependencies only
-jest.mock('@/service/api')
-jest.mock('next/navigation', () => ({
-  useRouter: () => ({ push: jest.fn() }),
+vi.mock('@/service/api')
+vi.mock('next/navigation', () => ({
+  useRouter: () => ({ push: vi.fn() }),
   usePathname: () => '/test',
 }))
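One caveat when converting mocks like the ones above: as in Jest, `vi.mock` factories are hoisted above imports, but Vitest requires any variable the factory captures to be created with `vi.hoisted` (a real Vitest API; Jest instead allowed lazily-bound `mock`-prefixed variables). A minimal sketch:

```typescript
// vi.hoisted runs before the hoisted vi.mock factory, so mockPush is safe to capture.
const { mockPush } = vi.hoisted(() => ({ mockPush: vi.fn() }))

vi.mock('next/navigation', () => ({
  useRouter: () => ({ push: mockPush }),
  usePathname: () => '/test',
}))
```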
@@ -88,7 +88,7 @@ let mockSharedState = false
 describe('ComponentName', () => {
   beforeEach(() => {
-    jest.clearAllMocks() // ✅ Reset mocks BEFORE each test
+    vi.clearAllMocks() // ✅ Reset mocks BEFORE each test
     mockSharedState = false // ✅ Reset shared state
   })
@@ -117,7 +117,7 @@ describe('ComponentName', () => {
   // User Interactions
   describe('User Interactions', () => {
     it('should handle click events', () => {
-      const handleClick = jest.fn()
+      const handleClick = vi.fn()
       render(<Component onClick={handleClick} />)
 
       fireEvent.click(screen.getByRole('button'))
@@ -178,7 +178,7 @@ Process in this order for multi-file testing:
 - **500+ lines**: Consider splitting before testing
 - **Many dependencies**: Extract logic into hooks first
 
-> 📖 See `guides/workflow.md` for complete workflow details and todo list format.
+> 📖 See `references/workflow.md` for complete workflow details and todo list format.
 
 ## Testing Strategy
@@ -289,17 +289,18 @@ For each test file generated, aim for:
 - ✅ **>95%** branch coverage
 - ✅ **>95%** line coverage
 
-> **Note**: For multi-file directories, process one file at a time with full coverage each. See `guides/workflow.md`.
+> **Note**: For multi-file directories, process one file at a time with full coverage each. See `references/workflow.md`.
 
 ## Detailed Guides
 
 For more detailed information, refer to:
 
-- `guides/workflow.md` - **Incremental testing workflow** (MUST READ for multi-file testing)
-- `guides/mocking.md` - Mock patterns and best practices
-- `guides/async-testing.md` - Async operations and API calls
-- `guides/domain-components.md` - Workflow, Dataset, Configuration testing
-- `guides/common-patterns.md` - Frequently used testing patterns
+- `references/workflow.md` - **Incremental testing workflow** (MUST READ for multi-file testing)
+- `references/mocking.md` - Mock patterns and best practices
+- `references/async-testing.md` - Async operations and API calls
+- `references/domain-components.md` - Workflow, Dataset, Configuration testing
+- `references/common-patterns.md` - Frequently used testing patterns
+- `references/checklist.md` - Test generation checklist and validation steps
 
 ## Authoritative References
@@ -315,7 +316,7 @@ For more detailed information, refer to:
 ### Project Configuration
 
-- `web/jest.config.ts` - Jest configuration
-- `web/jest.setup.ts` - Test environment setup
+- `web/vitest.config.ts` - Vitest configuration
+- `web/vitest.setup.ts` - Test environment setup
 - `web/testing/analyze-component.js` - Component analysis tool
-- `web/__mocks__/react-i18next.ts` - Shared i18n mock (auto-loaded by Jest, no explicit mock needed; override locally only for custom translations)
+- Modules are not mocked automatically. Global mocks live in `web/vitest.setup.ts` (for example `react-i18next`, `next/image`); mock other modules like `ky` or `mime` locally in test files.
@@ -23,14 +23,14 @@ import userEvent from '@testing-library/user-event'
 // ============================================================================
 // Mocks
 // ============================================================================
-// WHY: Mocks must be hoisted to top of file (Jest requirement).
+// WHY: Mocks must be hoisted to top of file (Vitest requirement).
 // They run BEFORE imports, so keep them before component imports.
 
 // i18n (automatically mocked)
-// WHY: Shared mock at web/__mocks__/react-i18next.ts is auto-loaded by Jest
+// WHY: Global mock in web/vitest.setup.ts is auto-loaded by Vitest setup
 // No explicit mock needed - it returns translation keys as-is
 // Override only if custom translations are required:
-// jest.mock('react-i18next', () => ({
+// vi.mock('react-i18next', () => ({
 //   useTranslation: () => ({
 //     t: (key: string) => {
 //       const customTranslations: Record<string, string> = {
@@ -43,17 +43,17 @@ import userEvent from '@testing-library/user-event'
 // Router (if component uses useRouter, usePathname, useSearchParams)
 // WHY: Isolates tests from Next.js routing, enables testing navigation behavior
-// const mockPush = jest.fn()
-// jest.mock('next/navigation', () => ({
+// const mockPush = vi.fn()
+// vi.mock('next/navigation', () => ({
 //   useRouter: () => ({ push: mockPush }),
 //   usePathname: () => '/test-path',
 // }))
 
 // API services (if component fetches data)
 // WHY: Prevents real network calls, enables testing all states (loading/success/error)
-// jest.mock('@/service/api')
+// vi.mock('@/service/api')
 // import * as api from '@/service/api'
-// const mockedApi = api as jest.Mocked<typeof api>
+// const mockedApi = vi.mocked(api)
 
 // Shared mock state (for portal/dropdown components)
 // WHY: Portal components like PortalToFollowElem need shared state between
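As a usage note, `vi.mocked(api)` above is purely a typing helper; each test still sets behavior explicitly. A hedged sketch — `fetchData` stands in for whatever `@/service/api` actually exports:

```typescript
// Illustrative names — adapt to the real service module's exports.
import * as api from '@/service/api'

vi.mock('@/service/api')
const mockedApi = vi.mocked(api)

it('should render fetched data', async () => {
  mockedApi.fetchData.mockResolvedValue({ data: ['item'] })

  render(<Component />)

  // findBy* queries wait for the async update triggered by the resolved mock
  expect(await screen.findByText('item')).toBeInTheDocument()
  expect(mockedApi.fetchData).toHaveBeenCalledTimes(1)
})
```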
@@ -98,7 +98,7 @@ describe('ComponentName', () => {
   // - Prevents mock call history from leaking between tests
   // - MUST be beforeEach (not afterEach) to reset BEFORE assertions like toHaveBeenCalledTimes
   beforeEach(() => {
-    jest.clearAllMocks()
+    vi.clearAllMocks()
     // Reset shared mock state if used (CRITICAL for portal/dropdown tests)
     // mockOpenState = false
   })
@@ -155,7 +155,7 @@ describe('ComponentName', () => {
     // - userEvent simulates real user behavior (focus, hover, then click)
     // - fireEvent is lower-level, doesn't trigger all browser events
     // const user = userEvent.setup()
-    // const handleClick = jest.fn()
+    // const handleClick = vi.fn()
     // render(<ComponentName onClick={handleClick} />)
     //
     // await user.click(screen.getByRole('button'))
@@ -165,7 +165,7 @@ describe('ComponentName', () => {
     it('should call onChange when value changes', async () => {
       // const user = userEvent.setup()
-      // const handleChange = jest.fn()
+      // const handleChange = vi.fn()
       // render(<ComponentName onChange={handleChange} />)
       //
       // await user.type(screen.getByRole('textbox'), 'new value')
@@ -15,9 +15,9 @@ import { renderHook, act, waitFor } from '@testing-library/react'
 // ============================================================================
 
 // API services (if hook fetches data)
-// jest.mock('@/service/api')
+// vi.mock('@/service/api')
 // import * as api from '@/service/api'
-// const mockedApi = api as jest.Mocked<typeof api>
+// const mockedApi = vi.mocked(api)
 
 // ============================================================================
 // Test Helpers
@@ -38,7 +38,7 @@ import { renderHook, act, waitFor } from '@testing-library/react'
 describe('useHookName', () => {
   beforeEach(() => {
-    jest.clearAllMocks()
+    vi.clearAllMocks()
   })
 
   // --------------------------------------------------------------------------
@@ -145,7 +145,7 @@ describe('useHookName', () => {
   // --------------------------------------------------------------------------
   describe('Side Effects', () => {
     it('should call callback when value changes', () => {
-      // const callback = jest.fn()
+      // const callback = vi.fn()
       // const { result } = renderHook(() => useHookName({ onChange: callback }))
       //
       // act(() => {
@@ -156,9 +156,9 @@ describe('useHookName', () => {
     })
 
     it('should cleanup on unmount', () => {
-      // const cleanup = jest.fn()
-      // jest.spyOn(window, 'addEventListener')
-      // jest.spyOn(window, 'removeEventListener')
+      // const cleanup = vi.fn()
+      // vi.spyOn(window, 'addEventListener')
+      // vi.spyOn(window, 'removeEventListener')
       //
       // const { unmount } = renderHook(() => useHookName())
       //
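Filled in, the cleanup test sketched above could read as follows — a minimal sketch, assuming `useHookName` registers window listeners on mount and removes them on unmount:

```typescript
// A minimal sketch, assuming useHookName adds window listeners on mount.
it('should cleanup on unmount', () => {
  const addSpy = vi.spyOn(window, 'addEventListener')
  const removeSpy = vi.spyOn(window, 'removeEventListener')

  const { unmount } = renderHook(() => useHookName())
  unmount()

  // Whatever the hook registered should be deregistered again
  expect(removeSpy).toHaveBeenCalledTimes(addSpy.mock.calls.length)
})
```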
@@ -49,7 +49,7 @@ import userEvent from '@testing-library/user-event'
 it('should submit form', async () => {
   const user = userEvent.setup()
-  const onSubmit = jest.fn()
+  const onSubmit = vi.fn()
 
   render(<Form onSubmit={onSubmit} />)
@@ -77,15 +77,15 @@ it('should submit form', async () => {
 ```typescript
 describe('Debounced Search', () => {
   beforeEach(() => {
-    jest.useFakeTimers()
+    vi.useFakeTimers()
   })
 
   afterEach(() => {
-    jest.useRealTimers()
+    vi.useRealTimers()
   })
 
   it('should debounce search input', async () => {
-    const onSearch = jest.fn()
+    const onSearch = vi.fn()
     render(<SearchInput onSearch={onSearch} debounceMs={300} />)
 
     // Type in the input
@@ -95,7 +95,7 @@ describe('Debounced Search', () => {
     expect(onSearch).not.toHaveBeenCalled()
 
     // Advance timers
-    jest.advanceTimersByTime(300)
+    vi.advanceTimersByTime(300)
 
     // Now search is called
     expect(onSearch).toHaveBeenCalledWith('query')
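Note that `userEvent.setup()` and fake timers need to be wired together, otherwise `user.type()` can stall waiting on timer callbacks that never fire. The `advanceTimers` option is part of the `@testing-library/user-event` API:

```typescript
// Let user-event advance Vitest's fake timers during its internal typing delays.
const user = userEvent.setup({ advanceTimers: vi.advanceTimersByTime })

await user.type(screen.getByRole('textbox'), 'query')
vi.advanceTimersByTime(300) // then flush the debounce window
```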
@@ -107,8 +107,8 @@ describe('Debounced Search', () => {
 ```typescript
 it('should retry on failure', async () => {
-  jest.useFakeTimers()
-  const fetchData = jest.fn()
+  vi.useFakeTimers()
+  const fetchData = vi.fn()
     .mockRejectedValueOnce(new Error('Network error'))
     .mockResolvedValueOnce({ data: 'success' })
@@ -120,7 +120,7 @@ it('should retry on failure', async () => {
   })
 
   // Advance timer for retry
-  jest.advanceTimersByTime(1000)
+  vi.advanceTimersByTime(1000)
 
   // Second call succeeds
   await waitFor(() => {
@@ -128,7 +128,7 @@ it('should retry on failure', async () => {
     expect(screen.getByText('success')).toBeInTheDocument()
   })
 
-  jest.useRealTimers()
+  vi.useRealTimers()
 })
 ```
@@ -136,19 +136,19 @@ it('should retry on failure', async () => {
 ```typescript
 // Run all pending timers
-jest.runAllTimers()
+vi.runAllTimers()
 
 // Run only pending timers (not new ones created during execution)
-jest.runOnlyPendingTimers()
+vi.runOnlyPendingTimers()
 
 // Advance by specific time
-jest.advanceTimersByTime(1000)
+vi.advanceTimersByTime(1000)
 
 // Get current fake time
-jest.now()
+Date.now()
 
 // Clear all timers
-jest.clearAllTimers()
+vi.clearAllTimers()
 ```
 
 ## API Testing Patterns
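Vitest also ships async variants of these timer helpers that flush pending microtasks between ticks — useful for the retry pattern above, where a timer callback awaits a Promise:

```typescript
// Async variants await promise callbacks scheduled by the timers,
// which the plain (synchronous) versions do not.
await vi.advanceTimersByTimeAsync(1000)
await vi.runAllTimersAsync()
```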
@@ -158,7 +158,7 @@ jest.clearAllTimers()
 ```typescript
 describe('DataFetcher', () => {
   beforeEach(() => {
-    jest.clearAllMocks()
+    vi.clearAllMocks()
   })
 
   it('should show loading state', () => {
@@ -241,7 +241,7 @@ it('should submit form and show success', async () => {
 ```typescript
 it('should fetch data on mount', async () => {
-  const fetchData = jest.fn().mockResolvedValue({ data: 'test' })
+  const fetchData = vi.fn().mockResolvedValue({ data: 'test' })
 
   render(<ComponentWithEffect fetchData={fetchData} />)
@@ -255,7 +255,7 @@ it('should fetch data on mount', async () => {
 ```typescript
 it('should refetch when id changes', async () => {
-  const fetchData = jest.fn().mockResolvedValue({ data: 'test' })
+  const fetchData = vi.fn().mockResolvedValue({ data: 'test' })
 
   const { rerender } = render(<ComponentWithEffect id="1" fetchData={fetchData} />)
@@ -276,8 +276,8 @@ it('should refetch when id changes', async () => {
 ```typescript
 it('should cleanup subscription on unmount', () => {
-  const subscribe = jest.fn()
-  const unsubscribe = jest.fn()
+  const subscribe = vi.fn()
+  const unsubscribe = vi.fn()
   subscribe.mockReturnValue(unsubscribe)
 
   const { unmount } = render(<SubscriptionComponent subscribe={subscribe} />)
@@ -332,14 +332,14 @@ expect(description).toBeInTheDocument()
 ```typescript
 // Bad - fake timers don't work well with real Promises
-jest.useFakeTimers()
+vi.useFakeTimers()
 await waitFor(() => {
   expect(screen.getByText('Data')).toBeInTheDocument()
 }) // May timeout!
 
 // Good - use runAllTimers or advanceTimersByTime
-jest.useFakeTimers()
+vi.useFakeTimers()
 render(<Component />)
-jest.runAllTimers()
+vi.runAllTimers()
 expect(screen.getByText('Data')).toBeInTheDocument()
 ```
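Where manual advancing is awkward, Vitest's fake timers can also be told to creep forward automatically so `waitFor`'s internal polling keeps working:

```typescript
// shouldAdvanceTime lets real elapsed time nudge the fake clock,
// so waitFor's polling interval still fires (a Vitest fake-timer option).
vi.useFakeTimers({ shouldAdvanceTime: true })
```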
@@ -74,9 +74,9 @@ Use this checklist when generating or reviewing tests for Dify frontend components
 ### Mocks
 
 - [ ] **DO NOT mock base components** (`@/app/components/base/*`)
-- [ ] `jest.clearAllMocks()` in `beforeEach` (not `afterEach`)
+- [ ] `vi.clearAllMocks()` in `beforeEach` (not `afterEach`)
 - [ ] Shared mock state reset in `beforeEach`
-- [ ] i18n uses shared mock (auto-loaded); only override locally for custom translations
+- [ ] i18n uses global mock (auto-loaded in `web/vitest.setup.ts`); only override locally for custom translations
 - [ ] Router mocks match actual Next.js API
 - [ ] Mocks reflect actual component conditional behavior
 - [ ] Only mock: API services, complex context providers, third-party libs
@@ -132,10 +132,10 @@ For the current file being tested:
 ```typescript
 // ❌ Mock doesn't match actual behavior
-jest.mock('./Component', () => () => <div>Mocked</div>)
+vi.mock('./Component', () => () => <div>Mocked</div>)
 
 // ✅ Mock matches actual conditional logic
-jest.mock('./Component', () => ({ isOpen }: any) =>
+vi.mock('./Component', () => ({ isOpen }: any) =>
   isOpen ? <div>Content</div> : null
 )
 ```
@@ -145,7 +145,7 @@ jest.mock('./Component', () => ({ isOpen }: any) =>
 ```typescript
 // ❌ Shared state not reset
 let mockState = false
-jest.mock('./useHook', () => () => mockState)
+vi.mock('./useHook', () => () => mockState)
 
 // ✅ Reset in beforeEach
 beforeEach(() => {
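Completing the ✅ side of this pattern for clarity — the reset belongs in `beforeEach` so every test starts from the same baseline:

```typescript
// ✅ Reset both the mock call registry and the module-level shared state
beforeEach(() => {
  vi.clearAllMocks()
  mockState = false
})
```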
@@ -192,7 +192,7 @@ pnpm test -- path/to/file.spec.tsx
 pnpm test -- --coverage path/to/file.spec.tsx
 
 # Watch mode
-pnpm test -- --watch path/to/file.spec.tsx
+pnpm test:watch -- path/to/file.spec.tsx
 
 # Update snapshots (use sparingly)
 pnpm test -- -u path/to/file.spec.tsx
@@ -126,7 +126,7 @@ describe('Counter', () => {
 describe('ControlledInput', () => {
   it('should call onChange with new value', async () => {
     const user = userEvent.setup()
-    const handleChange = jest.fn()
+    const handleChange = vi.fn()
 
     render(<ControlledInput value="" onChange={handleChange} />)
@@ -136,7 +136,7 @@ describe('ControlledInput', () => {
   })
 
   it('should display controlled value', () => {
-    render(<ControlledInput value="controlled" onChange={jest.fn()} />)
+    render(<ControlledInput value="controlled" onChange={vi.fn()} />)
 
     expect(screen.getByRole('textbox')).toHaveValue('controlled')
   })
@@ -195,7 +195,7 @@ describe('ItemList', () => {
   it('should handle item selection', async () => {
     const user = userEvent.setup()
-    const onSelect = jest.fn()
+    const onSelect = vi.fn()
 
     render(<ItemList items={items} onSelect={onSelect} />)
@@ -217,20 +217,20 @@ describe('ItemList', () => {
 ```typescript
 describe('Modal', () => {
   it('should not render when closed', () => {
-    render(<Modal isOpen={false} onClose={jest.fn()} />)
+    render(<Modal isOpen={false} onClose={vi.fn()} />)
 
     expect(screen.queryByRole('dialog')).not.toBeInTheDocument()
   })
 
   it('should render when open', () => {
-    render(<Modal isOpen={true} onClose={jest.fn()} />)
+    render(<Modal isOpen={true} onClose={vi.fn()} />)
 
     expect(screen.getByRole('dialog')).toBeInTheDocument()
   })
 
   it('should call onClose when clicking overlay', async () => {
     const user = userEvent.setup()
-    const handleClose = jest.fn()
+    const handleClose = vi.fn()
 
     render(<Modal isOpen={true} onClose={handleClose} />)
@@ -241,7 +241,7 @@ describe('Modal', () => {
   it('should call onClose when pressing Escape', async () => {
     const user = userEvent.setup()
-    const handleClose = jest.fn()
+    const handleClose = vi.fn()
 
     render(<Modal isOpen={true} onClose={handleClose} />)
@@ -254,7 +254,7 @@ describe('Modal', () => {
     const user = userEvent.setup()
 
     render(
-      <Modal isOpen={true} onClose={jest.fn()}>
+      <Modal isOpen={true} onClose={vi.fn()}>
         <button>First</button>
         <button>Second</button>
       </Modal>
@@ -279,7 +279,7 @@ describe('Modal', () => {
 describe('LoginForm', () => {
   it('should submit valid form', async () => {
     const user = userEvent.setup()
-    const onSubmit = jest.fn()
+    const onSubmit = vi.fn()
 
     render(<LoginForm onSubmit={onSubmit} />)
@@ -296,7 +296,7 @@ describe('LoginForm', () => {
   it('should show validation errors', async () => {
     const user = userEvent.setup()
 
-    render(<LoginForm onSubmit={jest.fn()} />)
+    render(<LoginForm onSubmit={vi.fn()} />)
 
     // Submit empty form
     await user.click(screen.getByRole('button', { name: /sign in/i }))
@@ -308,7 +308,7 @@ describe('LoginForm', () => {
   it('should validate email format', async () => {
     const user = userEvent.setup()
 
-    render(<LoginForm onSubmit={jest.fn()} />)
+    render(<LoginForm onSubmit={vi.fn()} />)
 
     await user.type(screen.getByLabelText(/email/i), 'invalid-email')
     await user.click(screen.getByRole('button', { name: /sign in/i }))
@@ -318,7 +318,7 @@ describe('LoginForm', () => {
   it('should disable submit button while submitting', async () => {
     const user = userEvent.setup()
-    const onSubmit = jest.fn(() => new Promise(resolve => setTimeout(resolve, 100)))
+    const onSubmit = vi.fn(() => new Promise(resolve => setTimeout(resolve, 100)))
 
     render(<LoginForm onSubmit={onSubmit} />)
@@ -407,7 +407,7 @@ it('test 1', () => {
 // Good - cleanup is automatic with RTL, but reset mocks
 beforeEach(() => {
-  jest.clearAllMocks()
+  vi.clearAllMocks()
 })
 ```
@@ -23,7 +23,7 @@ import NodeConfigPanel from './node-config-panel'
 import { createMockNode, createMockWorkflowContext } from '@/__mocks__/workflow'
 
 // Mock workflow context
-jest.mock('@/app/components/workflow/hooks', () => ({
+vi.mock('@/app/components/workflow/hooks', () => ({
   useWorkflowStore: () => mockWorkflowStore,
   useNodesInteractions: () => mockNodesInteractions,
 }))
@@ -31,21 +31,21 @@ jest.mock('@/app/components/workflow/hooks', () => ({
 let mockWorkflowStore = {
   nodes: [],
   edges: [],
-  updateNode: jest.fn(),
+  updateNode: vi.fn(),
 }
 
 let mockNodesInteractions = {
-  handleNodeSelect: jest.fn(),
-  handleNodeDelete: jest.fn(),
+  handleNodeSelect: vi.fn(),
+  handleNodeDelete: vi.fn(),
 }
 
 describe('NodeConfigPanel', () => {
   beforeEach(() => {
-    jest.clearAllMocks()
+    vi.clearAllMocks()
     mockWorkflowStore = {
       nodes: [],
       edges: [],
-      updateNode: jest.fn(),
+      updateNode: vi.fn(),
     }
   })
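A hedged example of exercising a component against this mocked store — the node shape and the `NodeConfigPanel` props here are simplified illustrations, not Dify's real types:

```typescript
// Illustrative only: a simplified node object and prop contract.
it('should update the selected node', async () => {
  mockWorkflowStore.nodes = [{ id: 'node-1', data: { title: 'LLM' } }]

  render(<NodeConfigPanel nodeId="node-1" />)
  fireEvent.change(screen.getByRole('textbox'), { target: { value: 'LLM 2' } })

  expect(mockWorkflowStore.updateNode).toHaveBeenCalledWith(
    'node-1',
    expect.objectContaining({ title: 'LLM 2' }),
  )
})
```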
@@ -161,23 +161,23 @@ import { render, screen, fireEvent, waitFor } from '@testing-library/react'
 import userEvent from '@testing-library/user-event'
 import DocumentUploader from './document-uploader'
 
-jest.mock('@/service/datasets', () => ({
-  uploadDocument: jest.fn(),
-  parseDocument: jest.fn(),
+vi.mock('@/service/datasets', () => ({
+  uploadDocument: vi.fn(),
+  parseDocument: vi.fn(),
 }))
 
 import * as datasetService from '@/service/datasets'
-const mockedService = datasetService as jest.Mocked<typeof datasetService>
+const mockedService = vi.mocked(datasetService)
 
 describe('DocumentUploader', () => {
   beforeEach(() => {
-    jest.clearAllMocks()
+    vi.clearAllMocks()
   })
 
   describe('File Upload', () => {
     it('should accept valid file types', async () => {
       const user = userEvent.setup()
-      const onUpload = jest.fn()
+      const onUpload = vi.fn()
       mockedService.uploadDocument.mockResolvedValue({ id: 'doc-1' })
 
       render(<DocumentUploader onUpload={onUpload} />)
@@ -326,14 +326,14 @@ describe('DocumentList', () => {
   describe('Search & Filtering', () => {
     it('should filter by search query', async () => {
       const user = userEvent.setup()
-      jest.useFakeTimers()
+      vi.useFakeTimers()
 
       render(<DocumentList datasetId="ds-1" />)
 
       await user.type(screen.getByPlaceholderText(/search/i), 'test query')
 
       // Debounce
-      jest.advanceTimersByTime(300)
+      vi.advanceTimersByTime(300)
 
       await waitFor(() => {
         expect(mockedService.getDocuments).toHaveBeenCalledWith(
@@ -342,7 +342,7 @@ describe('DocumentList', () => {
         )
       })
 
-      jest.useRealTimers()
+      vi.useRealTimers()
     })
   })
 })
@@ -367,13 +367,13 @@ import { render, screen, fireEvent, waitFor } from '@testing-library/react'
 import userEvent from '@testing-library/user-event'
 import AppConfigForm from './app-config-form'
 
-jest.mock('@/service/apps', () => ({
-  updateAppConfig: jest.fn(),
-  getAppConfig: jest.fn(),
+vi.mock('@/service/apps', () => ({
+  updateAppConfig: vi.fn(),
+  getAppConfig: vi.fn(),
 }))
 
 import * as appService from '@/service/apps'
-const mockedService = appService as jest.Mocked<typeof appService>
+const mockedService = vi.mocked(appService)
 
 describe('AppConfigForm', () => {
   const defaultConfig = {
@@ -384,7 +384,7 @@ describe('AppConfigForm', () => {
   }
 
   beforeEach(() => {
-    jest.clearAllMocks()
+    vi.clearAllMocks()
     mockedService.getAppConfig.mockResolvedValue(defaultConfig)
   })
@@ -19,8 +19,8 @@
 ```typescript
 // ❌ WRONG: Don't mock base components
-jest.mock('@/app/components/base/loading', () => () => <div>Loading</div>)
-jest.mock('@/app/components/base/button', () => ({ children }: any) => <button>{children}</button>)
+vi.mock('@/app/components/base/loading', () => () => <div>Loading</div>)
+vi.mock('@/app/components/base/button', () => ({ children }: any) => <button>{children}</button>)
 
 // ✅ CORRECT: Import and use real base components
 import Loading from '@/app/components/base/loading'
@@ -41,20 +41,23 @@ Only mock these categories:
 | Location | Purpose |
 |----------|---------|
-| `web/__mocks__/` | Reusable mocks shared across multiple test files |
-| Test file | Test-specific mocks, inline with `jest.mock()` |
+| `web/vitest.setup.ts` | Global mocks shared by all tests (for example `react-i18next`, `next/image`) |
+| `web/__mocks__/` | Reusable mock factories shared across multiple test files |
+| Test file | Test-specific mocks, inline with `vi.mock()` |
 
+Modules are not mocked automatically. Use `vi.mock` in test files, or add global mocks in `web/vitest.setup.ts`.
+
 ## Essential Mocks
 
-### 1. i18n (Auto-loaded via Shared Mock)
+### 1. i18n (Auto-loaded via Global Mock)
 
-A shared mock is available at `web/__mocks__/react-i18next.ts` and is auto-loaded by Jest.
+A global mock is defined in `web/vitest.setup.ts` and is auto-loaded by Vitest setup.
 **No explicit mock needed** for most tests - it returns translation keys as-is.
 
 For tests requiring custom translations, override the mock:
 
 ```typescript
-jest.mock('react-i18next', () => ({
+vi.mock('react-i18next', () => ({
   useTranslation: () => ({
     t: (key: string) => {
       const translations: Record<string, string> = {
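For reference, a global mock registered in the setup file might look like the sketch below; the real `web/vitest.setup.ts` is authoritative, so treat the exact shape here as an assumption:

```typescript
// Hypothetical vitest.setup.ts excerpt — returns translation keys as-is.
import { vi } from 'vitest'

vi.mock('react-i18next', () => ({
  useTranslation: () => ({
    t: (key: string) => key,
    i18n: { language: 'en', changeLanguage: vi.fn() },
  }),
}))
```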
@@ -69,15 +72,15 @@ jest.mock('react-i18next', () => ({
 ### 2. Next.js Router
 
 ```typescript
-const mockPush = jest.fn()
-const mockReplace = jest.fn()
+const mockPush = vi.fn()
+const mockReplace = vi.fn()
 
-jest.mock('next/navigation', () => ({
+vi.mock('next/navigation', () => ({
   useRouter: () => ({
     push: mockPush,
     replace: mockReplace,
-    back: jest.fn(),
-    prefetch: jest.fn(),
+    back: vi.fn(),
+    prefetch: vi.fn(),
   }),
   usePathname: () => '/current-path',
   useSearchParams: () => new URLSearchParams('?key=value'),
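In a test, the shared spies declared above are what you assert on. A short usage sketch — the button label and route string are illustrative:

```typescript
it('navigates via the mocked router', async () => {
  const user = userEvent.setup()
  render(<Component />)

  await user.click(screen.getByRole('button', { name: /open/i }))

  expect(mockPush).toHaveBeenCalledWith('/detail/1') // route is an assumption
})
```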
@@ -85,7 +88,7 @@ jest.mock('next/navigation', () => ({
 describe('Component', () => {
   beforeEach(() => {
-    jest.clearAllMocks()
+    vi.clearAllMocks()
   })
 
   it('should navigate on click', () => {
@@ -102,7 +105,7 @@ describe('Component', () => {
 // ⚠️ Important: Use shared state for components that depend on each other
 let mockPortalOpenState = false
 
-jest.mock('@/app/components/base/portal-to-follow-elem', () => ({
+vi.mock('@/app/components/base/portal-to-follow-elem', () => ({
   PortalToFollowElem: ({ children, open, ...props }: any) => {
     mockPortalOpenState = open || false // Update shared state
     return <div data-testid="portal" data-open={open}>{children}</div>
@@ -119,7 +122,7 @@ jest.mock('@/app/components/base/portal-to-follow-elem', () => ({
 describe('Component', () => {
   beforeEach(() => {
-    jest.clearAllMocks()
+    vi.clearAllMocks()
     mockPortalOpenState = false // ✅ Reset shared state
   })
 })
@@ -130,13 +133,13 @@ describe('Component', () => {
 ```typescript
 import * as api from '@/service/api'
 
-jest.mock('@/service/api')
+vi.mock('@/service/api')
 
-const mockedApi = api as jest.Mocked<typeof api>
+const mockedApi = vi.mocked(api)
 
 describe('Component', () => {
   beforeEach(() => {
-    jest.clearAllMocks()
+    vi.clearAllMocks()
 
     // Setup default mock implementation
     mockedApi.fetchData.mockResolvedValue({ data: [] })
@@ -243,13 +246,13 @@ describe('Component with Context', () => {
 ```typescript
 // SWR
-jest.mock('swr', () => ({
+vi.mock('swr', () => ({
   __esModule: true,
-  default: jest.fn(),
+  default: vi.fn(),
 }))
 
 import useSWR from 'swr'
-const mockedUseSWR = useSWR as jest.Mock
+const mockedUseSWR = vi.mocked(useSWR)
 
 describe('Component with SWR', () => {
   it('should show loading state', () => {
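The return shape to feed `mockedUseSWR` mirrors what components destructure from `useSWR`; adjust the fields to what the component actually reads:

```typescript
// Typical SWR consumer shape — data/error/isLoading plus a mutate stub.
mockedUseSWR.mockReturnValue({
  data: undefined,
  error: undefined,
  isLoading: true,
  mutate: vi.fn(),
} as any)
```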
@@ -0,0 +1 @@
+../.claude/skills
@@ -6,7 +6,7 @@ cd web && pnpm install
 pipx install uv
 
 echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
-echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor\"" >> ~/.bashrc
+echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention\"" >> ~/.bashrc
 echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev\"" >> ~/.bashrc
 echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
 echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc
@@ -6,6 +6,12 @@
 * @crazywoola @laipz8200 @Yeuoly
 
+# CODEOWNERS file
+.github/CODEOWNERS @laipz8200 @crazywoola
+
+# Docs
+docs/ @crazywoola
+
 # Backend (default owner, more specific rules below will override)
 api/ @QuantumGhost
@@ -116,11 +122,17 @@ api/controllers/console/feature.py @GarfieldDai @GareArc
 api/controllers/web/feature.py @GarfieldDai @GareArc
 
 # Backend - Database Migrations
-api/migrations/ @snakevash @laipz8200
+api/migrations/ @snakevash @laipz8200 @MRZHUH
 
+# Backend - Vector DB Middleware
+api/configs/middleware/vdb/* @JohnJyong
+
 # Frontend
 web/ @iamjoel
 
+# Frontend - Web Tests
+.github/workflows/web-tests.yml @iamjoel
+
 # Frontend - App - Orchestration
 web/app/components/workflow/ @iamjoel @zxhlyh
 web/app/components/workflow-app/ @iamjoel @zxhlyh
@@ -192,6 +204,7 @@ web/app/components/plugins/marketplace/ @iamjoel @Yessenia-d
 web/app/signin/ @douxc @iamjoel
 web/app/signup/ @douxc @iamjoel
+web/app/reset-password/ @douxc @iamjoel
 
 web/app/install/ @douxc @iamjoel
 web/app/init/ @douxc @iamjoel
 web/app/forgot-password/ @douxc @iamjoel
@@ -232,3 +245,6 @@ web/app/education-apply/ @iamjoel @zxhlyh
 # Frontend - Workspace
 web/app/components/header/account-dropdown/workplace-selector/ @iamjoel @zxhlyh
 
+# Docker
+docker/* @laipz8200
@@ -66,7 +66,7 @@ jobs:
       # mdformat breaks YAML front matter in markdown files. Add --exclude for directories containing YAML front matter.
       - name: mdformat
         run: |
-          uvx --python 3.13 mdformat . --exclude ".claude/skills/**"
+          uvx --python 3.13 mdformat . --exclude ".claude/skills/**/SKILL.md"
 
       - name: Install pnpm
         uses: pnpm/action-setup@v4
@@ -79,7 +79,7 @@ jobs:
         with:
           node-version: 22
           cache: pnpm
-          cache-dependency-path: ./web/package.json
+          cache-dependency-path: ./web/pnpm-lock.yaml
 
       - name: Web dependencies
         working-directory: ./web
@@ -90,7 +90,7 @@ jobs:
         with:
           node-version: 22
           cache: pnpm
-          cache-dependency-path: ./web/package.json
+          cache-dependency-path: ./web/pnpm-lock.yaml
 
       - name: Web dependencies
         if: steps.changed-files.outputs.any_changed == 'true'
@@ -55,7 +55,7 @@ jobs:
         with:
           node-version: 'lts/*'
           cache: pnpm
-          cache-dependency-path: ./web/package.json
+          cache-dependency-path: ./web/pnpm-lock.yaml
 
       - name: Install dependencies
         if: env.FILES_CHANGED == 'true'
@@ -13,6 +13,7 @@ jobs:
     runs-on: ubuntu-latest
     defaults:
       run:
+        shell: bash
         working-directory: ./web
 
     steps:
@@ -21,14 +22,7 @@ jobs:
         with:
           persist-credentials: false
 
-      - name: Check changed files
-        id: changed-files
-        uses: tj-actions/changed-files@v46
-        with:
-          files: web/**
-
       - name: Install pnpm
-        if: steps.changed-files.outputs.any_changed == 'true'
        uses: pnpm/action-setup@v4
         with:
           package_json_file: web/package.json
@@ -36,23 +30,342 @@ jobs:
       - name: Setup Node.js
         uses: actions/setup-node@v4
-        if: steps.changed-files.outputs.any_changed == 'true'
         with:
           node-version: 22
           cache: pnpm
-          cache-dependency-path: ./web/package.json
+          cache-dependency-path: ./web/pnpm-lock.yaml
 
       - name: Install dependencies
-        if: steps.changed-files.outputs.any_changed == 'true'
-        working-directory: ./web
         run: pnpm install --frozen-lockfile
 
       - name: Check i18n types synchronization
-        if: steps.changed-files.outputs.any_changed == 'true'
-        working-directory: ./web
         run: pnpm run check:i18n-types
 
       - name: Run tests
-        if: steps.changed-files.outputs.any_changed == 'true'
-        working-directory: ./web
-        run: pnpm test
+        run: pnpm test --coverage
+
+      - name: Coverage Summary
+        if: always()
+        id: coverage-summary
+        run: |
+          set -eo pipefail
+
+          COVERAGE_FILE="coverage/coverage-final.json"
+          COVERAGE_SUMMARY_FILE="coverage/coverage-summary.json"
+
+          if [ ! -f "$COVERAGE_FILE" ] && [ ! -f "$COVERAGE_SUMMARY_FILE" ]; then
+            echo "has_coverage=false" >> "$GITHUB_OUTPUT"
+            echo "### 🚨 Test Coverage Report :test_tube:" >> "$GITHUB_STEP_SUMMARY"
+            echo "Coverage data not found. Ensure Vitest runs with coverage enabled." >> "$GITHUB_STEP_SUMMARY"
+            exit 0
+          fi
+
+          echo "has_coverage=true" >> "$GITHUB_OUTPUT"
+
+          node <<'NODE' >> "$GITHUB_STEP_SUMMARY"
+          const fs = require('fs');
+          const path = require('path');
+          let libCoverage = null;
+
+          try {
+            libCoverage = require('istanbul-lib-coverage');
+          } catch (error) {
+            libCoverage = null;
+          }
+
+          const summaryPath = path.join('coverage', 'coverage-summary.json');
+          const finalPath = path.join('coverage', 'coverage-final.json');
+
+          const hasSummary = fs.existsSync(summaryPath);
+          const hasFinal = fs.existsSync(finalPath);
+
+          if (!hasSummary && !hasFinal) {
+            console.log('### Test Coverage Summary :test_tube:');
+            console.log('');
+            console.log('No coverage data found.');
+            process.exit(0);
+          }
+
+          const summary = hasSummary
+            ? JSON.parse(fs.readFileSync(summaryPath, 'utf8'))
+            : null;
+          const coverage = hasFinal
+            ? JSON.parse(fs.readFileSync(finalPath, 'utf8'))
+            : null;
+
+          const getLineCoverageFromStatements = (statementMap, statementHits) => {
+            const lineHits = {};
+
+            if (!statementMap || !statementHits) {
+              return lineHits;
+            }
+
+            Object.entries(statementMap).forEach(([key, statement]) => {
+              const line = statement?.start?.line;
+              if (!line) {
+                return;
+              }
+              const hits = statementHits[key] ?? 0;
+              const previous = lineHits[line];
+              lineHits[line] = previous === undefined ? hits : Math.max(previous, hits);
+            });
+
+            return lineHits;
+          };
+
+          const getFileCoverage = (entry) => (
+            libCoverage ? libCoverage.createFileCoverage(entry) : null
+          );
+
+          const getLineHits = (entry, fileCoverage) => {
+            const lineHits = entry.l ?? {};
+            if (Object.keys(lineHits).length > 0) {
+              return lineHits;
+            }
+            if (fileCoverage) {
+              return fileCoverage.getLineCoverage();
+            }
+            return getLineCoverageFromStatements(entry.statementMap ?? {}, entry.s ?? {});
+          };
+
+          const getUncoveredLines = (entry, fileCoverage, lineHits) => {
+            if (lineHits && Object.keys(lineHits).length > 0) {
+              return Object.entries(lineHits)
+                .filter(([, count]) => count === 0)
+                .map(([line]) => Number(line))
+                .sort((a, b) => a - b);
+            }
+            if (fileCoverage) {
+              return fileCoverage.getUncoveredLines();
+            }
+            return [];
+          };
+
+          const totals = {
+            lines: { covered: 0, total: 0 },
+            statements: { covered: 0, total: 0 },
+            branches: { covered: 0, total: 0 },
+            functions: { covered: 0, total: 0 },
+          };
+          const fileSummaries = [];
+
+          if (summary) {
+            const totalEntry = summary.total ?? {};
+            ['lines', 'statements', 'branches', 'functions'].forEach((key) => {
+              if (totalEntry[key]) {
+                totals[key].covered = totalEntry[key].covered ?? 0;
+                totals[key].total = totalEntry[key].total ?? 0;
+              }
+            });
+
+            Object.entries(summary)
+              .filter(([file]) => file !== 'total')
+              .forEach(([file, data]) => {
+                fileSummaries.push({
+                  file,
+                  pct: data.lines?.pct ?? data.statements?.pct ?? 0,
+                  lines: {
+                    covered: data.lines?.covered ?? 0,
+                    total: data.lines?.total ?? 0,
+                  },
+                });
+              });
+          } else if (coverage) {
+            Object.entries(coverage).forEach(([file, entry]) => {
+              const fileCoverage = getFileCoverage(entry);
+              const lineHits = getLineHits(entry, fileCoverage);
+              const statementHits = entry.s ?? {};
+              const branchHits = entry.b ?? {};
+              const functionHits = entry.f ?? {};
+
+              const lineTotal = Object.keys(lineHits).length;
+              const lineCovered = Object.values(lineHits).filter((n) => n > 0).length;
+
+              const statementTotal = Object.keys(statementHits).length;
+              const statementCovered = Object.values(statementHits).filter((n) => n > 0).length;
+
+              const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0);
+              const branchCovered = Object.values(branchHits).reduce(
+                (acc, branches) => acc + branches.filter((n) => n > 0).length,
+                0,
+              );
+
+              const functionTotal = Object.keys(functionHits).length;
+              const functionCovered = Object.values(functionHits).filter((n) => n > 0).length;
+
+              totals.lines.total += lineTotal;
+              totals.lines.covered += lineCovered;
+              totals.statements.total += statementTotal;
+              totals.statements.covered += statementCovered;
+              totals.branches.total += branchTotal;
+              totals.branches.covered += branchCovered;
+              totals.functions.total += functionTotal;
+              totals.functions.covered += functionCovered;
+
+              const pct = (covered, tot) => (tot > 0 ? (covered / tot) * 100 : 0);
+
+              fileSummaries.push({
+                file,
+                pct: pct(lineCovered || statementCovered, lineTotal || statementTotal),
+                lines: {
+                  covered: lineCovered || statementCovered,
+                  total: lineTotal || statementTotal,
+                },
+              });
+            });
+          }
+
+          const pct = (covered, tot) => (tot > 0 ? ((covered / tot) * 100).toFixed(2) : '0.00');
+
+          console.log('### Test Coverage Summary :test_tube:');
+          console.log('');
+          console.log('| Metric | Coverage | Covered / Total |');
+          console.log('|--------|----------|-----------------|');
+          console.log(`| Lines | ${pct(totals.lines.covered, totals.lines.total)}% | ${totals.lines.covered} / ${totals.lines.total} |`);
+          console.log(`| Statements | ${pct(totals.statements.covered, totals.statements.total)}% | ${totals.statements.covered} / ${totals.statements.total} |`);
+          console.log(`| Branches | ${pct(totals.branches.covered, totals.branches.total)}% | ${totals.branches.covered} / ${totals.branches.total} |`);
+          console.log(`| Functions | ${pct(totals.functions.covered, totals.functions.total)}% | ${totals.functions.covered} / ${totals.functions.total} |`);
+
+          console.log('');
+          console.log('<details><summary>File coverage (lowest lines first)</summary>');
+          console.log('');
+          console.log('```');
+          fileSummaries
+            .sort((a, b) => (a.pct - b.pct) || (b.lines.total - a.lines.total))
+            .slice(0, 25)
+            .forEach(({ file, pct, lines }) => {
+              console.log(`${pct.toFixed(2)}%\t${lines.covered}/${lines.total}\t${file}`);
+            });
+          console.log('```');
+          console.log('</details>');
+
+          if (coverage) {
+            const pctValue = (covered, tot) => {
+              if (tot === 0) {
+                return '0';
+              }
+              return ((covered / tot) * 100)
+                .toFixed(2)
+                .replace(/\.?0+$/, '');
+            };
+
+            const formatLineRanges = (lines) => {
+              if (lines.length === 0) {
+                return '';
+              }
+              const ranges = [];
+              let start = lines[0];
+              let end = lines[0];
+
+              for (let i = 1; i < lines.length; i += 1) {
+                const current = lines[i];
+                if (current === end + 1) {
+                  end = current;
+                  continue;
+                }
+                ranges.push(start === end ? `${start}` : `${start}-${end}`);
+                start = current;
+                end = current;
+              }
+              ranges.push(start === end ? `${start}` : `${start}-${end}`);
+              return ranges.join(',');
+            };
+
+            const tableTotals = {
+              statements: { covered: 0, total: 0 },
+              branches: { covered: 0, total: 0 },
+              functions: { covered: 0, total: 0 },
+              lines: { covered: 0, total: 0 },
+            };
+            const tableRows = Object.entries(coverage)
+              .map(([file, entry]) => {
+                const fileCoverage = getFileCoverage(entry);
+                const lineHits = getLineHits(entry, fileCoverage);
+                const statementHits = entry.s ?? {};
+                const branchHits = entry.b ?? {};
+                const functionHits = entry.f ?? {};
+
+                const lineTotal = Object.keys(lineHits).length;
+                const lineCovered = Object.values(lineHits).filter((n) => n > 0).length;
+                const statementTotal = Object.keys(statementHits).length;
+                const statementCovered = Object.values(statementHits).filter((n) => n > 0).length;
+                const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0);
+                const branchCovered = Object.values(branchHits).reduce(
+                  (acc, branches) => acc + branches.filter((n) => n > 0).length,
+                  0,
+                );
+                const functionTotal = Object.keys(functionHits).length;
+                const functionCovered = Object.values(functionHits).filter((n) => n > 0).length;
+
+                tableTotals.lines.total += lineTotal;
+                tableTotals.lines.covered += lineCovered;
+                tableTotals.statements.total += statementTotal;
+                tableTotals.statements.covered += statementCovered;
+                tableTotals.branches.total += branchTotal;
+                tableTotals.branches.covered += branchCovered;
+                tableTotals.functions.total += functionTotal;
+                tableTotals.functions.covered += functionCovered;
+
+                const uncoveredLines = getUncoveredLines(entry, fileCoverage, lineHits);
+
+                const filePath = entry.path ?? file;
+                const relativePath = path.isAbsolute(filePath)
+                  ? path.relative(process.cwd(), filePath)
+                  : filePath;
+
+                return {
+                  file: relativePath || file,
+                  statements: pctValue(statementCovered, statementTotal),
+                  branches: pctValue(branchCovered, branchTotal),
+                  functions: pctValue(functionCovered, functionTotal),
+                  lines: pctValue(lineCovered, lineTotal),
+                  uncovered: formatLineRanges(uncoveredLines),
+                };
+              })
+              .sort((a, b) => a.file.localeCompare(b.file));
+
+            const columns = [
+              { key: 'file', header: 'File', align: 'left' },
+              { key: 'statements', header: '% Stmts', align: 'right' },
+              { key: 'branches', header: '% Branch', align: 'right' },
+              { key: 'functions', header: '% Funcs', align: 'right' },
+              { key: 'lines', header: '% Lines', align: 'right' },
+              { key: 'uncovered', header: 'Uncovered Line #s', align: 'left' },
+            ];
+
+            const allFilesRow = {
+              file: 'All files',
+              statements: pctValue(tableTotals.statements.covered, tableTotals.statements.total),
+              branches: pctValue(tableTotals.branches.covered, tableTotals.branches.total),
+              functions: pctValue(tableTotals.functions.covered, tableTotals.functions.total),
+              lines: pctValue(tableTotals.lines.covered, tableTotals.lines.total),
+              uncovered: '',
+            };
+
+            const rowsForOutput = [allFilesRow, ...tableRows];
+            const formatRow = (row) => `| ${columns
+              .map(({ key }) => String(row[key] ?? ''))
+              .join(' | ')} |`;
+            const headerRow = `| ${columns.map(({ header }) => header).join(' | ')} |`;
+            const dividerRow = `| ${columns
+              .map(({ align }) => (align === 'right' ? '---:' : ':---'))
+              .join(' | ')} |`;
+
+            console.log('');
+            console.log('<details><summary>Vitest coverage table</summary>');
+            console.log('');
+            console.log(headerRow);
+            console.log(dividerRow);
+            rowsForOutput.forEach((row) => console.log(formatRow(row)));
+            console.log('</details>');
+          }
+          NODE
+
+      - name: Upload Coverage Artifact
+        if: steps.coverage-summary.outputs.has_coverage == 'true'
+        uses: actions/upload-artifact@v4
+        with:
+          name: web-coverage-report
+          path: web/coverage
+          retention-days: 30
+          if-no-files-found: error
@@ -37,7 +37,7 @@
             "-c",
             "1",
             "-Q",
-            "dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor",
+            "dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention",
             "--loglevel",
             "INFO"
         ],
@@ -690,3 +690,8 @@ ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE=5
 ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR=20
 # Maximum number of concurrent annotation import tasks per tenant
 ANNOTATION_IMPORT_MAX_CONCURRENT=5
+
+# Sandbox expired records clean configuration
+SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21
+SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000
+SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30
@@ -84,7 +84,7 @@
 1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
 
    ```bash
-   uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor
+   uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention
    ```
 
 Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service:
@@ -218,7 +218,7 @@ class PluginConfig(BaseSettings):
     PLUGIN_DAEMON_TIMEOUT: PositiveFloat | None = Field(
         description="Timeout in seconds for requests to the plugin daemon (set to None to disable)",
-        default=300.0,
+        default=600.0,
     )
 
     INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key")
@@ -1270,6 +1270,21 @@ class TenantIsolatedTaskQueueConfig(BaseSettings):
     )
 
 
+class SandboxExpiredRecordsCleanConfig(BaseSettings):
+    SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD: NonNegativeInt = Field(
+        description="Graceful period in days for sandbox records clean after subscription expiration",
+        default=21,
+    )
+    SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE: PositiveInt = Field(
+        description="Maximum number of records to process in each batch",
+        default=1000,
+    )
+    SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS: PositiveInt = Field(
+        description="Retention days for sandbox expired workflow_run records and message records",
+        default=30,
+    )
+
+
 class FeatureConfig(
     # place the configs in alphabet order
     AppExecutionConfig,
@@ -1295,6 +1310,7 @@ class FeatureConfig(
     PositionConfig,
     RagEtlConfig,
     RepositoryConfig,
+    SandboxExpiredRecordsCleanConfig,
     SecurityConfig,
     TenantIsolatedTaskQueueConfig,
     ToolConfig,
@@ -146,7 +146,7 @@ class DatasetUpdatePayload(BaseModel):
     embedding_model: str | None = None
     embedding_model_provider: str | None = None
     retrieval_model: dict[str, Any] | None = None
-    partial_member_list: list[str] | None = None
+    partial_member_list: list[dict[str, str]] | None = None
     external_retrieval_model: dict[str, Any] | None = None
     external_knowledge_id: str | None = None
     external_knowledge_api_id: str | None = None
@@ -40,7 +40,7 @@ from .. import console_ns
 logger = logging.getLogger(__name__)
 
 
-class CompletionMessagePayload(BaseModel):
+class CompletionMessageExplorePayload(BaseModel):
     inputs: dict[str, Any]
     query: str = ""
     files: list[dict[str, Any]] | None = None
@@ -71,7 +71,7 @@ class ChatMessagePayload(BaseModel):
         raise ValueError("must be a valid UUID") from exc
 
 
-register_schema_models(console_ns, CompletionMessagePayload, ChatMessagePayload)
+register_schema_models(console_ns, CompletionMessageExplorePayload, ChatMessagePayload)
 
 
 # define completion api for user
@@ -80,13 +80,13 @@ register_schema_models(console_ns, CompletionMessagePayload, ChatMessagePayload)
     endpoint="installed_app_completion",
 )
 class CompletionApi(InstalledAppResource):
-    @console_ns.expect(console_ns.models[CompletionMessagePayload.__name__])
+    @console_ns.expect(console_ns.models[CompletionMessageExplorePayload.__name__])
     def post(self, installed_app):
         app_model = installed_app.app
         if app_model.mode != AppMode.COMPLETION:
             raise NotCompletionAppError()
 
-        payload = CompletionMessagePayload.model_validate(console_ns.payload or {})
+        payload = CompletionMessageExplorePayload.model_validate(console_ns.payload or {})
         args = payload.model_dump(exclude_none=True)
 
         streaming = payload.response_mode == "streaming"
@@ -1,5 +1,4 @@
 from typing import Any
-from uuid import UUID
 
 from flask import request
 from flask_restx import marshal_with
@@ -13,6 +12,7 @@ from controllers.console.explore.wraps import InstalledAppResource
 from core.app.entities.app_invoke_entities import InvokeFrom
 from extensions.ext_database import db
 from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
+from libs.helper import UUIDStrOrEmpty
 from libs.login import current_user
 from models import Account
 from models.model import AppMode
@@ -24,7 +24,7 @@ from .. import console_ns
 class ConversationListQuery(BaseModel):
-    last_id: UUID | None = None
+    last_id: UUIDStrOrEmpty | None = None
     limit: int = Field(default=20, ge=1, le=100)
     pinned: bool | None = None
@@ -2,7 +2,8 @@ import logging
 from typing import Any
 
 from flask import request
-from flask_restx import Resource, inputs, marshal_with, reqparse
+from flask_restx import Resource, marshal_with
+from pydantic import BaseModel
 from sqlalchemy import and_, select
 from werkzeug.exceptions import BadRequest, Forbidden, NotFound
@@ -18,6 +19,15 @@ from services.account_service import TenantService
 from services.enterprise.enterprise_service import EnterpriseService
 from services.feature_service import FeatureService
 
 
+class InstalledAppCreatePayload(BaseModel):
+    app_id: str
+
+
+class InstalledAppUpdatePayload(BaseModel):
+    is_pinned: bool | None = None
+
+
 logger = logging.getLogger(__name__)
@@ -105,26 +115,25 @@ class InstalledAppsListApi(Resource):
     @account_initialization_required
     @cloud_edition_billing_resource_check("apps")
     def post(self):
-        parser = reqparse.RequestParser().add_argument("app_id", type=str, required=True, help="Invalid app_id")
-        args = parser.parse_args()
+        payload = InstalledAppCreatePayload.model_validate(console_ns.payload or {})
 
-        recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == args["app_id"]).first()
+        recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == payload.app_id).first()
         if recommended_app is None:
-            raise NotFound("App not found")
+            raise NotFound("Recommended app not found")
 
         _, current_tenant_id = current_account_with_tenant()
 
-        app = db.session.query(App).where(App.id == args["app_id"]).first()
+        app = db.session.query(App).where(App.id == payload.app_id).first()
 
         if app is None:
-            raise NotFound("App not found")
+            raise NotFound("App entity not found")
 
         if not app.is_public:
             raise Forbidden("You can't install a non-public app")
 
         installed_app = (
             db.session.query(InstalledApp)
-            .where(and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id))
+            .where(and_(InstalledApp.app_id == payload.app_id, InstalledApp.tenant_id == current_tenant_id))
             .first()
         )
@ -133,7 +142,7 @@ class InstalledAppsListApi(Resource):
|
|||
recommended_app.install_count += 1
|
||||
|
||||
new_installed_app = InstalledApp(
|
||||
app_id=args["app_id"],
|
||||
app_id=payload.app_id,
|
||||
tenant_id=current_tenant_id,
|
||||
app_owner_tenant_id=app.tenant_id,
|
||||
is_pinned=False,
|
||||
|
|
@ -163,12 +172,11 @@ class InstalledAppApi(InstalledAppResource):
|
|||
return {"result": "success", "message": "App uninstalled successfully"}, 204
|
||||
|
||||
def patch(self, installed_app):
|
||||
parser = reqparse.RequestParser().add_argument("is_pinned", type=inputs.boolean)
|
||||
args = parser.parse_args()
|
||||
payload = InstalledAppUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
commit_args = False
|
||||
if "is_pinned" in args:
|
||||
installed_app.is_pinned = args["is_pinned"]
|
||||
if payload.is_pinned is not None:
|
||||
installed_app.is_pinned = payload.is_pinned
|
||||
commit_args = True
|
||||
|
||||
if commit_args:
|
||||
|
|
|
|||
|
|
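One subtlety in the `patch` hunk above: with `reqparse`, a declared argument is always present in the parsed dict, so `"is_pinned" in args` was effectively always true; the Pydantic version uses `None` as the "field not sent" sentinel. A small hedged sketch of those semantics (illustrative model, not the commit's code — note an explicit JSON `null` is indistinguishable from omission under this scheme):

```python
# Partial-update semantics with a None sentinel (illustrative names).
from pydantic import BaseModel


class UpdatePayload(BaseModel):
    is_pinned: bool | None = None


def apply_update(payload: UpdatePayload, current: bool) -> bool:
    # Only touch the stored value when the client actually sent the field.
    if payload.is_pinned is not None:
        return payload.is_pinned
    return current


assert apply_update(UpdatePayload(), current=True) is True
assert apply_update(UpdatePayload(is_pinned=False), current=True) is False
```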
@@ -1,31 +1,40 @@
+from typing import Literal
+
 from flask import request
-from flask_restx import Resource, marshal_with, reqparse
+from flask_restx import Resource, marshal_with
+from pydantic import BaseModel, Field
 from werkzeug.exceptions import Forbidden

+from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
 from fields.tag_fields import dataset_tag_fields
 from libs.login import current_account_with_tenant, login_required
 from models.model import Tag
 from services.tag_service import TagService


-def _validate_name(name):
-    if not name or len(name) < 1 or len(name) > 50:
-        raise ValueError("Name must be between 1 to 50 characters.")
-    return name
+class TagBasePayload(BaseModel):
+    name: str = Field(description="Tag name", min_length=1, max_length=50)
+    type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")


-parser_tags = (
-    reqparse.RequestParser()
-    .add_argument(
-        "name",
-        nullable=False,
-        required=True,
-        help="Name must be between 1 to 50 characters.",
-        type=_validate_name,
-    )
-    .add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.")
+class TagBindingPayload(BaseModel):
+    tag_ids: list[str] = Field(description="Tag IDs to bind")
+    target_id: str = Field(description="Target ID to bind tags to")
+    type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
+
+
+class TagBindingRemovePayload(BaseModel):
+    tag_id: str = Field(description="Tag ID to remove")
+    target_id: str = Field(description="Target ID to unbind tag from")
+    type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
+
+
+register_schema_models(
+    console_ns,
+    TagBasePayload,
+    TagBindingPayload,
+    TagBindingRemovePayload,
)

@@ -43,7 +52,7 @@ class TagListApi(Resource):

         return tags, 200

-    @console_ns.expect(parser_tags)
+    @console_ns.expect(console_ns.models[TagBasePayload.__name__])
     @setup_required
     @login_required
     @account_initialization_required

@@ -53,22 +62,17 @@ class TagListApi(Resource):
         if not (current_user.has_edit_permission or current_user.is_dataset_editor):
             raise Forbidden()

-        args = parser_tags.parse_args()
-        tag = TagService.save_tags(args)
+        payload = TagBasePayload.model_validate(console_ns.payload or {})
+        tag = TagService.save_tags(payload.model_dump())

         response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}

         return response, 200


-parser_tag_id = reqparse.RequestParser().add_argument(
-    "name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name
-)
-
-
 @console_ns.route("/tags/<uuid:tag_id>")
 class TagUpdateDeleteApi(Resource):
-    @console_ns.expect(parser_tag_id)
+    @console_ns.expect(console_ns.models[TagBasePayload.__name__])
     @setup_required
     @login_required
     @account_initialization_required

@@ -79,8 +83,8 @@ class TagUpdateDeleteApi(Resource):
         if not (current_user.has_edit_permission or current_user.is_dataset_editor):
             raise Forbidden()

-        args = parser_tag_id.parse_args()
-        tag = TagService.update_tags(args, tag_id)
+        payload = TagBasePayload.model_validate(console_ns.payload or {})
+        tag = TagService.update_tags(payload.model_dump(), tag_id)

         binding_count = TagService.get_tag_binding_count(tag_id)

@@ -100,17 +104,9 @@ class TagUpdateDeleteApi(Resource):
         return 204


-parser_create = (
-    reqparse.RequestParser()
-    .add_argument("tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required.")
-    .add_argument("target_id", type=str, nullable=False, required=True, location="json", help="Target ID is required.")
-    .add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.")
-)
-
-
 @console_ns.route("/tag-bindings/create")
 class TagBindingCreateApi(Resource):
-    @console_ns.expect(parser_create)
+    @console_ns.expect(console_ns.models[TagBindingPayload.__name__])
     @setup_required
     @login_required
     @account_initialization_required

@@ -120,23 +116,15 @@ class TagBindingCreateApi(Resource):
         if not (current_user.has_edit_permission or current_user.is_dataset_editor):
             raise Forbidden()

-        args = parser_create.parse_args()
-        TagService.save_tag_binding(args)
+        payload = TagBindingPayload.model_validate(console_ns.payload or {})
+        TagService.save_tag_binding(payload.model_dump())

         return {"result": "success"}, 200


-parser_remove = (
-    reqparse.RequestParser()
-    .add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
-    .add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
-    .add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.")
-)
-
-
 @console_ns.route("/tag-bindings/remove")
 class TagBindingDeleteApi(Resource):
-    @console_ns.expect(parser_remove)
+    @console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__])
     @setup_required
     @login_required
     @account_initialization_required

@@ -146,7 +134,7 @@ class TagBindingDeleteApi(Resource):
         if not (current_user.has_edit_permission or current_user.is_dataset_editor):
             raise Forbidden()

-        args = parser_remove.parse_args()
-        TagService.delete_tag_binding(args)
+        payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
+        TagService.delete_tag_binding(payload.model_dump())

         return {"result": "success"}, 200

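The tag payloads above replace a custom `_validate_name` function and a `choices=` list with declarative constraints: `Field(min_length=1, max_length=50)` and a `Literal` union. A minimal sketch of how those two constraints behave together (illustrative class name, not the commit's):

```python
# Declarative field constraints standing in for hand-rolled validators.
from typing import Literal

from pydantic import BaseModel, Field, ValidationError


class TagPayloadSketch(BaseModel):
    name: str = Field(min_length=1, max_length=50)
    type: Literal["knowledge", "app"] | None = None


TagPayloadSketch(name="docs", type="knowledge")  # accepted

try:
    TagPayloadSketch(name="", type="dataset")  # both constraints violated
except ValidationError as e:
    # one error for the empty name, one for the unknown literal value
    print(len(e.errors()))
```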
@@ -18,6 +18,7 @@ from controllers.console.wraps import (
     setup_required,
 )
 from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
+from core.helper.tool_provider_cache import ToolProviderListCache
 from core.mcp.auth.auth_flow import auth, handle_callback
 from core.mcp.error import MCPAuthError, MCPError, MCPRefreshTokenError
 from core.mcp.mcp_client import MCPClient

@@ -944,7 +945,7 @@ class ToolProviderMCPApi(Resource):
         configuration = MCPConfiguration.model_validate(args["configuration"])
         authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None

-        # Create provider
+        # Create provider in transaction
         with Session(db.engine) as session, session.begin():
             service = MCPToolManageService(session=session)
             result = service.create_provider(

@@ -960,7 +961,11 @@ class ToolProviderMCPApi(Resource):
                 configuration=configuration,
                 authentication=authentication,
             )
-            return jsonable_encoder(result)
+
+        # Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
+        ToolProviderListCache.invalidate_cache(tenant_id)
+
+        return jsonable_encoder(result)

     @console_ns.expect(parser_mcp_put)
     @setup_required

@@ -972,17 +977,23 @@ class ToolProviderMCPApi(Resource):
         authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
         _, current_tenant_id = current_account_with_tenant()

-        # Step 1: Validate server URL change if needed (includes URL format validation and network operation)
-        validation_result = None
+        # Step 1: Get provider data for URL validation (short-lived session, no network I/O)
+        validation_data = None
         with Session(db.engine) as session:
             service = MCPToolManageService(session=session)
-            validation_result = service.validate_server_url_change(
-                tenant_id=current_tenant_id, provider_id=args["provider_id"], new_server_url=args["server_url"]
+            validation_data = service.get_provider_for_url_validation(
+                tenant_id=current_tenant_id, provider_id=args["provider_id"]
             )

-        # No need to check for errors here, exceptions will be raised directly
+        # Step 2: Perform URL validation with network I/O OUTSIDE of any database session
+        # This prevents holding database locks during potentially slow network operations
+        validation_result = MCPToolManageService.validate_server_url_standalone(
+            tenant_id=current_tenant_id,
+            new_server_url=args["server_url"],
+            validation_data=validation_data,
+        )

-        # Step 2: Perform database update in a transaction
+        # Step 3: Perform database update in a transaction
         with Session(db.engine) as session, session.begin():
             service = MCPToolManageService(session=session)
             service.update_provider(

@@ -999,7 +1010,11 @@ class ToolProviderMCPApi(Resource):
                 authentication=authentication,
                 validation_result=validation_result,
             )
-            return {"result": "success"}
+
+        # Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
+        ToolProviderListCache.invalidate_cache(current_tenant_id)
+
+        return {"result": "success"}

     @console_ns.expect(parser_mcp_delete)
     @setup_required

@@ -1012,7 +1027,11 @@ class ToolProviderMCPApi(Resource):
         with Session(db.engine) as session, session.begin():
             service = MCPToolManageService(session=session)
             service.delete_provider(tenant_id=current_tenant_id, provider_id=args["provider_id"])
-            return {"result": "success"}
+
+        # Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
+        ToolProviderListCache.invalidate_cache(current_tenant_id)
+
+        return {"result": "success"}


 parser_auth = (

@@ -1062,6 +1081,8 @@ class ToolMCPAuthApi(Resource):
                     credentials=provider_entity.credentials,
                     authed=True,
                 )
+                # Invalidate cache after updating credentials
+                ToolProviderListCache.invalidate_cache(tenant_id)
                 return {"result": "success"}
             except MCPAuthError as e:
                 try:

@@ -1075,16 +1096,22 @@ class ToolMCPAuthApi(Resource):
                     with Session(db.engine) as session, session.begin():
                         service = MCPToolManageService(session=session)
                         response = service.execute_auth_actions(auth_result)
+                    # Invalidate cache after auth actions may have updated provider state
+                    ToolProviderListCache.invalidate_cache(tenant_id)
                     return response
                 except MCPRefreshTokenError as e:
                     with Session(db.engine) as session, session.begin():
                         service = MCPToolManageService(session=session)
                         service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
+                    # Invalidate cache after clearing credentials
+                    ToolProviderListCache.invalidate_cache(tenant_id)
                     raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
             except (MCPError, ValueError) as e:
                 with Session(db.engine) as session, session.begin():
                     service = MCPToolManageService(session=session)
                     service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
+                # Invalidate cache after clearing credentials
+                ToolProviderListCache.invalidate_cache(tenant_id)
                 raise ValueError(f"Failed to connect to MCP server: {e}") from e

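The recurring shape in these hunks is: commit the transaction first, invalidate the Redis-backed cache second, so a slow or failing cache call can never extend the life of database locks. A hedged stand-in sketch of that ordering (the session/cache classes below are fakes, not Dify's):

```python
# Invalidate-after-commit ordering, with stand-in session and cache objects.
from contextlib import contextmanager


class FakeSession:
    def update(self, table: str, key: str) -> None:
        print(f"UPDATE {table} WHERE tenant_id={key}")

    def commit(self) -> None:
        print("COMMIT")


class FakeCache:
    def invalidate(self, key: str) -> None:
        print(f"DEL tool_providers:{key}")


@contextmanager
def transaction(session: FakeSession):
    yield session
    session.commit()  # row locks are released here


def update_provider_and_invalidate(session: FakeSession, cache: FakeCache, tenant_id: str) -> None:
    with transaction(session):
        session.update("tool_provider", tenant_id)
    # Only after commit does the cache call happen; a Redis stall no longer
    # blocks other writers to the same rows.
    cache.invalidate(tenant_id)


update_provider_and_invalidate(FakeSession(), FakeCache(), "tenant-1")
```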
@@ -49,7 +49,7 @@ class DatasetUpdatePayload(BaseModel):
     embedding_model: str | None = None
     embedding_model_provider: str | None = None
     retrieval_model: RetrievalModel | None = None
-    partial_member_list: list[str] | None = None
+    partial_member_list: list[dict[str, str]] | None = None
     external_retrieval_model: dict[str, Any] | None = None
     external_knowledge_id: str | None = None
     external_knowledge_api_id: str | None = None

@@ -1,7 +1,8 @@
 import logging

 from flask import request
-from flask_restx import fields, marshal_with, reqparse
+from flask_restx import fields, marshal_with
+from pydantic import BaseModel, field_validator
 from werkzeug.exceptions import InternalServerError

 import services

@@ -20,6 +21,7 @@ from controllers.web.error import (
 from controllers.web.wraps import WebApiResource
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.model_runtime.errors.invoke import InvokeError
+from libs.helper import uuid_value
 from models.model import App
 from services.audio_service import AudioService
 from services.errors.audio import (

@@ -29,6 +31,25 @@ from services.errors.audio import (
     UnsupportedAudioTypeServiceError,
 )

+from ..common.schema import register_schema_models
+
+
+class TextToAudioPayload(BaseModel):
+    message_id: str | None = None
+    voice: str | None = None
+    text: str | None = None
+    streaming: bool | None = None
+
+    @field_validator("message_id")
+    @classmethod
+    def validate_message_id(cls, value: str | None) -> str | None:
+        if value is None:
+            return value
+        return uuid_value(value)
+
+
+register_schema_models(web_ns, TextToAudioPayload)
+
 logger = logging.getLogger(__name__)

@@ -88,6 +109,7 @@ class AudioApi(WebApiResource):

 @web_ns.route("/text-to-audio")
 class TextApi(WebApiResource):
+    @web_ns.expect(web_ns.models[TextToAudioPayload.__name__])
     @web_ns.doc("Text to Audio")
     @web_ns.doc(description="Convert text to audio using text-to-speech service.")
     @web_ns.doc(

@@ -102,18 +124,11 @@ class TextApi(WebApiResource):
     def post(self, app_model: App, end_user):
         """Convert text to audio"""
         try:
-            parser = (
-                reqparse.RequestParser()
-                .add_argument("message_id", type=str, required=False, location="json")
-                .add_argument("voice", type=str, location="json")
-                .add_argument("text", type=str, location="json")
-                .add_argument("streaming", type=bool, location="json")
-            )
-            args = parser.parse_args()
+            payload = TextToAudioPayload.model_validate(web_ns.payload or {})

-            message_id = args.get("message_id", None)
-            text = args.get("text", None)
-            voice = args.get("voice", None)
+            message_id = payload.message_id
+            text = payload.text
+            voice = payload.voice
             response = AudioService.transcript_tts(
                 app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id
             )

@@ -1,9 +1,11 @@
 import logging
+from typing import Any, Literal

-from flask_restx import reqparse
+from pydantic import BaseModel, Field, field_validator
 from werkzeug.exceptions import InternalServerError, NotFound

 import services
+from controllers.common.schema import register_schema_models
 from controllers.web import web_ns
 from controllers.web.error import (
     AppUnavailableError,

@@ -34,25 +36,44 @@ from services.errors.llm import InvokeRateLimitError
 logger = logging.getLogger(__name__)


+class CompletionMessagePayload(BaseModel):
+    inputs: dict[str, Any] = Field(description="Input variables for the completion")
+    query: str = Field(default="", description="Query text for completion")
+    files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed")
+    response_mode: Literal["blocking", "streaming"] | None = Field(
+        default=None, description="Response mode: blocking or streaming"
+    )
+    retriever_from: str = Field(default="web_app", description="Source of retriever")
+
+
+class ChatMessagePayload(BaseModel):
+    inputs: dict[str, Any] = Field(description="Input variables for the chat")
+    query: str = Field(description="User query/message")
+    files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed")
+    response_mode: Literal["blocking", "streaming"] | None = Field(
+        default=None, description="Response mode: blocking or streaming"
+    )
+    conversation_id: str | None = Field(default=None, description="Conversation ID")
+    parent_message_id: str | None = Field(default=None, description="Parent message ID")
+    retriever_from: str = Field(default="web_app", description="Source of retriever")
+
+    @field_validator("conversation_id", "parent_message_id")
+    @classmethod
+    def validate_uuid(cls, value: str | None) -> str | None:
+        if value is None:
+            return value
+        return uuid_value(value)
+
+
+register_schema_models(web_ns, CompletionMessagePayload, ChatMessagePayload)


 # define completion api for user
 @web_ns.route("/completion-messages")
 class CompletionApi(WebApiResource):
     @web_ns.doc("Create Completion Message")
     @web_ns.doc(description="Create a completion message for text generation applications.")
-    @web_ns.doc(
-        params={
-            "inputs": {"description": "Input variables for the completion", "type": "object", "required": True},
-            "query": {"description": "Query text for completion", "type": "string", "required": False},
-            "files": {"description": "Files to be processed", "type": "array", "required": False},
-            "response_mode": {
-                "description": "Response mode: blocking or streaming",
-                "type": "string",
-                "enum": ["blocking", "streaming"],
-                "required": False,
-            },
-            "retriever_from": {"description": "Source of retriever", "type": "string", "required": False},
-        }
-    )
+    @web_ns.expect(web_ns.models[CompletionMessagePayload.__name__])
     @web_ns.doc(
         responses={
             200: "Success",

@@ -67,18 +88,10 @@ class CompletionApi(WebApiResource):
         if app_model.mode != AppMode.COMPLETION:
             raise NotCompletionAppError()

-        parser = (
-            reqparse.RequestParser()
-            .add_argument("inputs", type=dict, required=True, location="json")
-            .add_argument("query", type=str, location="json", default="")
-            .add_argument("files", type=list, required=False, location="json")
-            .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
-            .add_argument("retriever_from", type=str, required=False, default="web_app", location="json")
-        )
+        payload = CompletionMessagePayload.model_validate(web_ns.payload or {})
+        args = payload.model_dump(exclude_none=True)

-        args = parser.parse_args()
-
-        streaming = args["response_mode"] == "streaming"
+        streaming = payload.response_mode == "streaming"
         args["auto_generate_name"] = False

         try:

@@ -142,22 +155,7 @@ class CompletionStopApi(WebApiResource):
 class ChatApi(WebApiResource):
     @web_ns.doc("Create Chat Message")
     @web_ns.doc(description="Create a chat message for conversational applications.")
-    @web_ns.doc(
-        params={
-            "inputs": {"description": "Input variables for the chat", "type": "object", "required": True},
-            "query": {"description": "User query/message", "type": "string", "required": True},
-            "files": {"description": "Files to be processed", "type": "array", "required": False},
-            "response_mode": {
-                "description": "Response mode: blocking or streaming",
-                "type": "string",
-                "enum": ["blocking", "streaming"],
-                "required": False,
-            },
-            "conversation_id": {"description": "Conversation UUID", "type": "string", "required": False},
-            "parent_message_id": {"description": "Parent message UUID", "type": "string", "required": False},
-            "retriever_from": {"description": "Source of retriever", "type": "string", "required": False},
-        }
-    )
+    @web_ns.expect(web_ns.models[ChatMessagePayload.__name__])
     @web_ns.doc(
         responses={
             200: "Success",

@@ -173,20 +171,10 @@ class ChatApi(WebApiResource):
         if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()

-        parser = (
-            reqparse.RequestParser()
-            .add_argument("inputs", type=dict, required=True, location="json")
-            .add_argument("query", type=str, required=True, location="json")
-            .add_argument("files", type=list, required=False, location="json")
-            .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
-            .add_argument("conversation_id", type=uuid_value, location="json")
-            .add_argument("parent_message_id", type=uuid_value, required=False, location="json")
-            .add_argument("retriever_from", type=str, required=False, default="web_app", location="json")
-        )
+        payload = ChatMessagePayload.model_validate(web_ns.payload or {})
+        args = payload.model_dump(exclude_none=True)

-        args = parser.parse_args()
-
-        streaming = args["response_mode"] == "streaming"
+        streaming = payload.response_mode == "streaming"
         args["auto_generate_name"] = False

         try:

@@ -1,3 +1,4 @@
+import json
 from collections.abc import Sequence
 from enum import StrEnum, auto
 from typing import Any, Literal

@@ -120,7 +121,7 @@ class VariableEntity(BaseModel):
     allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
     allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
     allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
-    json_schema: dict[str, Any] | None = Field(default=None)
+    json_schema: str | None = Field(default=None)

     @field_validator("description", mode="before")
     @classmethod

@@ -134,11 +135,17 @@ class VariableEntity(BaseModel):

     @field_validator("json_schema")
     @classmethod
-    def validate_json_schema(cls, schema: dict[str, Any] | None) -> dict[str, Any] | None:
+    def validate_json_schema(cls, schema: str | None) -> str | None:
         if schema is None:
             return None

         try:
-            Draft7Validator.check_schema(schema)
+            json_schema = json.loads(schema)
+        except json.JSONDecodeError:
+            raise ValueError(f"invalid json_schema value {schema}")
+
+        try:
+            Draft7Validator.check_schema(json_schema)
         except SchemaError as e:
             raise ValueError(f"Invalid JSON schema: {e.message}")
         return schema

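Since `json_schema` is now carried as a string, validation is two-stage: the string must parse as JSON, and the parsed document must itself be a valid Draft 7 schema. A standalone version of that validator, runnable with the `jsonschema` package this commit adds to `pyproject.toml`:

```python
# Two-stage validation: JSON parse, then Draft 7 meta-schema check.
import json

from jsonschema import Draft7Validator
from jsonschema.exceptions import SchemaError


def validate_json_schema(schema: str | None) -> str | None:
    if schema is None:
        return None
    try:
        parsed = json.loads(schema)
    except json.JSONDecodeError:
        raise ValueError(f"invalid json_schema value {schema}")
    try:
        Draft7Validator.check_schema(parsed)
    except SchemaError as e:
        raise ValueError(f"Invalid JSON schema: {e.message}")
    # The original string is returned, not the parsed dict.
    return schema


validate_json_schema('{"type": "object"}')  # passes both stages
```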
@@ -1,3 +1,4 @@
+import json
 from collections.abc import Generator, Mapping, Sequence
 from typing import TYPE_CHECKING, Any, Union, final

@@ -175,6 +176,13 @@ class BaseAppGenerator:
                     value = True
                 elif value == 0:
                     value = False
+            case VariableEntityType.JSON_OBJECT:
+                if not isinstance(value, str):
+                    raise ValueError(f"{variable_entity.variable} in input form must be a string")
+                try:
+                    json.loads(value)
+                except json.JSONDecodeError:
+                    raise ValueError(f"{variable_entity.variable} in input form must be a valid JSON object")
             case _:
                 raise AssertionError("this statement should be unreachable.")

@@ -342,9 +342,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
                 self._task_state.llm_result.message.content = current_content

                 if isinstance(event, QueueLLMChunkEvent):
+                    event_type = self._message_cycle_manager.get_message_event_type(message_id=self._message_id)
                     yield self._message_cycle_manager.message_to_stream_response(
                         answer=cast(str, delta_text),
                         message_id=self._message_id,
+                        event_type=event_type,
                     )
                 else:
                     yield self._agent_message_to_stream_response(

@@ -5,7 +5,7 @@ from threading import Thread
 from typing import Union

 from flask import Flask, current_app
-from sqlalchemy import select
+from sqlalchemy import exists, select
 from sqlalchemy.orm import Session

 from configs import dify_config

@@ -54,6 +54,20 @@ class MessageCycleManager:
     ):
         self._application_generate_entity = application_generate_entity
         self._task_state = task_state
+        self._message_has_file: set[str] = set()
+
+    def get_message_event_type(self, message_id: str) -> StreamEvent:
+        if message_id in self._message_has_file:
+            return StreamEvent.MESSAGE_FILE
+
+        with Session(db.engine, expire_on_commit=False) as session:
+            has_file = session.query(exists().where(MessageFile.message_id == message_id)).scalar()
+
+        if has_file:
+            self._message_has_file.add(message_id)
+            return StreamEvent.MESSAGE_FILE
+
+        return StreamEvent.MESSAGE

     def generate_conversation_name(self, *, conversation_id: str, query: str) -> Thread | None:
         """

@@ -214,7 +228,11 @@ class MessageCycleManager:
         return None

     def message_to_stream_response(
-        self, answer: str, message_id: str, from_variable_selector: list[str] | None = None
+        self,
+        answer: str,
+        message_id: str,
+        from_variable_selector: list[str] | None = None,
+        event_type: StreamEvent | None = None,
     ) -> MessageStreamResponse:
         """
         Message to stream response.

@@ -222,16 +240,12 @@ class MessageCycleManager:
         :param message_id: message id
         :return:
         """
-        with Session(db.engine, expire_on_commit=False) as session:
-            message_file = session.scalar(select(MessageFile).where(MessageFile.id == message_id))
-            event_type = StreamEvent.MESSAGE_FILE if message_file else StreamEvent.MESSAGE
-
         return MessageStreamResponse(
             task_id=self._application_generate_entity.task_id,
             id=message_id,
             answer=answer,
             from_variable_selector=from_variable_selector,
-            event=event_type,
+            event=event_type or StreamEvent.MESSAGE,
         )

     def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:

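Two things change in the hunks above: the per-chunk lookup moves from loading a `MessageFile` row (and, per the removed line, matching `MessageFile.id` against a message id) to an `EXISTS` query on `message_id`, and the positive result is memoized so later chunks of the same message skip the database. A self-contained sketch of that lookup pattern under assumed table definitions:

```python
# EXISTS query plus per-message memoization (toy schema, in-memory SQLite).
from sqlalchemy import Column, String, create_engine, exists
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class MessageFile(Base):
    __tablename__ = "message_files"
    id = Column(String, primary_key=True)
    message_id = Column(String, index=True)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

_seen: set[str] = set()


def message_has_file(message_id: str) -> bool:
    if message_id in _seen:
        return True  # no query at all for already-confirmed messages
    with Session(engine) as session:
        found = session.query(exists().where(MessageFile.message_id == message_id)).scalar()
    if found:
        _seen.add(message_id)
    return bool(found)


print(message_has_file("m1"))  # False on an empty table
```

Only positive results are cached, since a message can gain a file mid-stream but, once it has one, that fact never reverts.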
@@ -47,7 +47,11 @@ def build_protected_resource_metadata_discovery_urls(
     """
     Build a list of URLs to try for Protected Resource Metadata discovery.

-    Per SEP-985, supports fallback when discovery fails at one URL.
+    Per RFC 9728 Section 5.1, supports fallback when discovery fails at one URL.
+    Priority order:
+    1. URL from WWW-Authenticate header (if provided)
+    2. Well-known URI with path: https://example.com/.well-known/oauth-protected-resource/public/mcp
+    3. Well-known URI at root: https://example.com/.well-known/oauth-protected-resource
     """
     urls = []

@@ -58,9 +62,18 @@ def build_protected_resource_metadata_discovery_urls(
     # Fallback: construct from server URL
     parsed = urlparse(server_url)
     base_url = f"{parsed.scheme}://{parsed.netloc}"
-    fallback_url = urljoin(base_url, "/.well-known/oauth-protected-resource")
-    if fallback_url not in urls:
-        urls.append(fallback_url)
+    path = parsed.path.rstrip("/")
+
+    # Priority 2: With path insertion (e.g., /.well-known/oauth-protected-resource/public/mcp)
+    if path:
+        path_url = f"{base_url}/.well-known/oauth-protected-resource{path}"
+        if path_url not in urls:
+            urls.append(path_url)
+
+    # Priority 3: At root (e.g., /.well-known/oauth-protected-resource)
+    root_url = f"{base_url}/.well-known/oauth-protected-resource"
+    if root_url not in urls:
+        urls.append(root_url)

     return urls

@@ -71,30 +84,34 @@ def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: st

     Supports both OAuth 2.0 (RFC 8414) and OpenID Connect discovery.

-    Per RFC 8414 section 3:
-    - If issuer has no path: https://example.com/.well-known/oauth-authorization-server
-    - If issuer has path: https://example.com/.well-known/oauth-authorization-server{path}
-
-    Example:
-    - issuer: https://example.com/oauth
-    - metadata: https://example.com/.well-known/oauth-authorization-server/oauth
+    Per RFC 8414 section 3.1 and section 5, try all possible endpoints:
+    - OAuth 2.0 with path insertion: https://example.com/.well-known/oauth-authorization-server/tenant1
+    - OpenID Connect with path insertion: https://example.com/.well-known/openid-configuration/tenant1
+    - OpenID Connect path appending: https://example.com/tenant1/.well-known/openid-configuration
+    - OAuth 2.0 at root: https://example.com/.well-known/oauth-authorization-server
+    - OpenID Connect at root: https://example.com/.well-known/openid-configuration
     """
     urls = []
     base_url = auth_server_url or server_url

     parsed = urlparse(base_url)
     base = f"{parsed.scheme}://{parsed.netloc}"
-    path = parsed.path.rstrip("/")  # Remove trailing slash
+    path = parsed.path.rstrip("/")
+    # OAuth 2.0 Authorization Server Metadata at root (MCP-03-26)
+    urls.append(f"{base}/.well-known/oauth-authorization-server")

-    # Try OpenID Connect discovery first (more common)
-    urls.append(urljoin(base + "/", ".well-known/openid-configuration"))
+    # OpenID Connect Discovery at root
+    urls.append(f"{base}/.well-known/openid-configuration")

-    # OAuth 2.0 Authorization Server Metadata (RFC 8414)
-    # Include the path component if present in the issuer URL
-    if path:
-        urls.append(urljoin(base, f".well-known/oauth-authorization-server{path}"))
-    else:
-        urls.append(urljoin(base, ".well-known/oauth-authorization-server"))
+    # OpenID Connect Discovery with path insertion
+    urls.append(f"{base}/.well-known/openid-configuration{path}")
+
+    # OpenID Connect Discovery path appending
+    urls.append(f"{base}{path}/.well-known/openid-configuration")
+
+    # OAuth 2.0 Authorization Server Metadata with path insertion
+    urls.append(f"{base}/.well-known/oauth-authorization-server{path}")

     return urls

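To make the new candidate ordering concrete, here is a trimmed, standalone re-implementation mirroring the hunk above, fed a sample tenant-scoped issuer (note that with an empty path the path-insertion and path-appending variants collapse into duplicates of the root URLs, which the new code, like this sketch, does not deduplicate):

```python
# Discovery URL ordering for an issuer with a path component.
from urllib.parse import urlparse


def discovery_urls(issuer: str) -> list[str]:
    parsed = urlparse(issuer)
    base = f"{parsed.scheme}://{parsed.netloc}"
    path = parsed.path.rstrip("/")
    return [
        f"{base}/.well-known/oauth-authorization-server",          # OAuth at root
        f"{base}/.well-known/openid-configuration",                # OIDC at root
        f"{base}/.well-known/openid-configuration{path}",          # OIDC, path insertion
        f"{base}{path}/.well-known/openid-configuration",          # OIDC, path appending
        f"{base}/.well-known/oauth-authorization-server{path}",    # OAuth, path insertion
    ]


for url in discovery_urls("https://example.com/tenant1"):
    print(url)
```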
@@ -59,7 +59,7 @@ class MCPClient:
             try:
                 logger.debug("Not supported method %s found in URL path, trying default 'mcp' method.", method_name)
                 self.connect_server(sse_client, "sse")
-            except MCPConnectionError:
+            except (MCPConnectionError, ValueError):
                 logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.")
                 self.connect_server(streamablehttp_client, "mcp")

@@ -18,34 +18,20 @@ This module provides the interface for invoking and authenticating various model

 - Model provider display

-  
-
-  Displays a list of all supported providers, including provider names, icons, supported model types list, predefined model list, configuration method, and credentials form rules, etc. For detailed rule design, see: [Schema](./docs/en_US/schema.md).
+  Displays a list of all supported providers, including provider names, icons, supported model types list, predefined model list, configuration method, and credentials form rules, etc.

 - Selectable model list display

-  
-
   After configuring provider/model credentials, the dropdown (application orchestration interface/default model) allows viewing of the available LLM list. Greyed out items represent predefined model lists from providers without configured credentials, facilitating user review of supported models.

-  In addition, this list also returns configurable parameter information and rules for LLM, as shown below:
-
-  
-
-  These parameters are all defined in the backend, allowing different settings for various parameters supported by different models, as detailed in: [Schema](./docs/en_US/schema.md#ParameterRule).
+  In addition, this list also returns configurable parameter information and rules for LLM. These parameters are all defined in the backend, allowing different settings for various parameters supported by different models.

 - Provider/model credential authentication

-  
-
-  
-
-  The provider list returns configuration information for the credentials form, which can be authenticated through Runtime's interface. The first image above is a provider credential DEMO, and the second is a model credential DEMO.
+  The provider list returns configuration information for the credentials form, which can be authenticated through Runtime's interface.

 ## Structure

 

 Model Runtime is divided into three layers:

 - The outermost layer is the factory method

@@ -60,9 +46,6 @@ Model Runtime is divided into three layers:

   It offers direct invocation of various model types, predefined model configuration information, getting predefined/remote model lists, model credential authentication methods. Different models provide additional special methods, like LLM's pre-computed tokens method, cost information obtaining method, etc., **allowing horizontal expansion** for different models under the same provider (within supported model types).

-## Next Steps
+## Documentation

-- Add new provider configuration: [Link](./docs/en_US/provider_scale_out.md)
-- Add new models for existing providers: [Link](./docs/en_US/provider_scale_out.md#AddModel)
-- View YAML configuration rules: [Link](./docs/en_US/schema.md)
-- Implement interface methods: [Link](./docs/en_US/interfaces.md)
+For detailed documentation on how to add new providers or models, please refer to the [Dify documentation](https://docs.dify.ai/).

@@ -18,34 +18,20 @@

 - 模型供应商展示

-  
-
-  展示所有已支持的供应商列表,除了返回供应商名称、图标之外,还提供了支持的模型类型列表,预定义模型列表、配置方式以及配置凭据的表单规则等等,规则设计详见:[Schema](./docs/zh_Hans/schema.md)。
+  展示所有已支持的供应商列表,除了返回供应商名称、图标之外,还提供了支持的模型类型列表,预定义模型列表、配置方式以及配置凭据的表单规则等等。

 - 可选择的模型列表展示

-  
-  配置供应商/模型凭据后,可在此下拉(应用编排界面/默认模型)查看可用的 LLM 列表,其中灰色的为未配置凭据供应商的预定义模型列表,方便用户查看已支持的模型。
+  配置供应商/模型凭据后,可在此下拉(应用编排界面/默认模型)查看可用的 LLM 列表,其中灰色的为未配置凭据供应商的预定义模型列表,方便用户查看已支持的模型。

-  除此之外,该列表还返回了 LLM 可配置的参数信息和规则,如下图:
-
-  
-
-  这里的参数均为后端定义,相比之前只有 5 种固定参数,这里可为不同模型设置所支持的各种参数,详见:[Schema](./docs/zh_Hans/schema.md#ParameterRule)。
+  除此之外,该列表还返回了 LLM 可配置的参数信息和规则。这里的参数均为后端定义,相比之前只有 5 种固定参数,这里可为不同模型设置所支持的各种参数。

 - 供应商/模型凭据鉴权

-  
-
-  
-
-  供应商列表返回了凭据表单的配置信息,可通过 Runtime 提供的接口对凭据进行鉴权,上图 1 为供应商凭据 DEMO,上图 2 为模型凭据 DEMO。
+  供应商列表返回了凭据表单的配置信息,可通过 Runtime 提供的接口对凭据进行鉴权。

 ## 结构

 

 Model Runtime 分三层:

 - 最外层为工厂方法

@@ -59,8 +45,7 @@ Model Runtime 分三层:
   对于供应商/模型凭据,有两种情况

   - 如 OpenAI 这类中心化供应商,需要定义如**api_key**这类的鉴权凭据
-  - 如[**Xinference**](https://github.com/xorbitsai/inference)这类本地部署的供应商,需要定义如**server_url**这类的地址凭据,有时候还需要定义**model_uid**之类的模型类型凭据,就像下面这样,当在供应商层定义了这些凭据后,就可以在前端页面上直接展示,无需修改前端逻辑。
-    
+  - 如[**Xinference**](https://github.com/xorbitsai/inference)这类本地部署的供应商,需要定义如**server_url**这类的地址凭据,有时候还需要定义**model_uid**之类的模型类型凭据。当在供应商层定义了这些凭据后,就可以在前端页面上直接展示,无需修改前端逻辑。

   当配置好凭据后,就可以通过 DifyRuntime 的外部接口直接获取到对应供应商所需要的**Schema**(凭据表单规则),从而在可以在不修改前端逻辑的情况下,提供新的供应商/模型的支持。

@@ -74,20 +59,6 @@ Model Runtime 分三层:

 - 模型凭据 (**在供应商层定义**):这是一类不经常变动,一般在配置好后就不会再变动的参数,如 **api_key**、**server_url** 等。在 DifyRuntime 中,他们的参数名一般为**credentials: dict[str, any]**,Provider 层的 credentials 会直接被传递到这一层,不需要再单独定义。

-## 下一步
+## 文档

-### [增加新的供应商配置 👈🏻](./docs/zh_Hans/provider_scale_out.md)
-
-当添加后,这里将会出现一个新的供应商
-
-
-
-### [为已存在的供应商新增模型 👈🏻](./docs/zh_Hans/provider_scale_out.md#%E5%A2%9E%E5%8A%A0%E6%A8%A1%E5%9E%8B)
-
-当添加后,对应供应商的模型列表中将会出现一个新的预定义模型供用户选择,如 GPT-3.5 GPT-4 ChatGLM3-6b 等,而对于支持自定义模型的供应商,则不需要新增模型。
-
-
-
-### [接口的具体实现 👈🏻](./docs/zh_Hans/interfaces.md)
-
-你可以在这里找到你想要查看的接口的具体实现,以及接口的参数和返回值的具体含义。
+有关如何添加新供应商或模型的详细文档,请参阅 [Dify 文档](https://docs.dify.ai/)。

@@ -39,7 +39,7 @@ from core.trigger.errors import (
 plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL))
 _plugin_daemon_timeout_config = cast(
     float | httpx.Timeout | None,
-    getattr(dify_config, "PLUGIN_DAEMON_TIMEOUT", 300.0),
+    getattr(dify_config, "PLUGIN_DAEMON_TIMEOUT", 600.0),
 )
 plugin_daemon_request_timeout: httpx.Timeout | None
 if _plugin_daemon_timeout_config is None:

@@ -83,6 +83,7 @@ class WordExtractor(BaseExtractor):
     def _extract_images_from_docx(self, doc):
         image_count = 0
         image_map = {}
+        base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL

         for r_id, rel in doc.part.rels.items():
             if "image" in rel.target_ref:

@@ -121,8 +122,7 @@ class WordExtractor(BaseExtractor):
                         used_at=naive_utc_now(),
                     )
                     db.session.add(upload_file)
-                    # Use r_id as key for external images since target_part is undefined
-                    image_map[r_id] = f"![image]({dify_config.FILES_URL}/files/{upload_file.id}/file-preview)"
+                    image_map[r_id] = f"![image]({base_url}/files/{upload_file.id}/file-preview)"
                 else:
                     image_ext = rel.target_ref.split(".")[-1]
                     if image_ext is None:

@@ -150,10 +150,7 @@ class WordExtractor(BaseExtractor):
                         used_at=naive_utc_now(),
                     )
                     db.session.add(upload_file)
-                    # Use target_part as key for internal images
-                    image_map[rel.target_part] = (
-                        f"![image]({dify_config.FILES_URL}/files/{upload_file.id}/file-preview)"
-                    )
+                    image_map[rel.target_part] = f"![image]({base_url}/files/{upload_file.id}/file-preview)"
                     db.session.commit()
         return image_map

@@ -2,6 +2,7 @@

 from __future__ import annotations

+import codecs
 import re
 from typing import Any

@@ -52,7 +53,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
     def __init__(self, fixed_separator: str = "\n\n", separators: list[str] | None = None, **kwargs: Any):
         """Create a new TextSplitter."""
         super().__init__(**kwargs)
-        self._fixed_separator = fixed_separator
+        self._fixed_separator = codecs.decode(fixed_separator, "unicode_escape")
         self._separators = separators or ["\n\n", "\n", "。", ". ", " ", ""]

     def split_text(self, text: str) -> list[str]:

@@ -94,7 +95,8 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
                 splits = re.split(r" +", text)
             else:
                 splits = text.split(separator)
-                splits = [item + separator if i < len(splits) else item for i, item in enumerate(splits)]
+                if self._keep_separator:
+                    splits = [s + separator for s in splits[:-1]] + splits[-1:]
         else:
             splits = list(text)
         if separator == "\n":

@@ -103,7 +105,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
             splits = [s for s in splits if (s not in {"", "\n"})]
         _good_splits = []
         _good_splits_lengths = []  # cache the lengths of the splits
-        _separator = separator if self._keep_separator else ""
+        _separator = "" if self._keep_separator else separator
         s_lens = self._length_function(splits)
         if separator != "":
             for s, s_len in zip(splits, s_lens):

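The separator-retention fix above is easy to verify in isolation: the removed comprehension compared a 0-based index against `len(splits)`, which is always true, so every piece — including the last — got a separator appended. The new slicing keeps the separator on all but the final piece:

```python
# Minimal demonstration of the keep-separator fix.
def split_keep_separator(text: str, separator: str) -> list[str]:
    splits = text.split(separator)
    return [s + separator for s in splits[:-1]] + splits[-1:]


print(split_keep_separator("a。b。c", "。"))  # ['a。', 'b。', 'c']
```

The companion `_separator` flip makes sense alongside it: once the separator is already baked into each split, the joiner used when merging splits back together must be empty, not the separator again.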
@@ -86,6 +86,11 @@ class Executor:
             node_data.authorization.config.api_key = variable_pool.convert_template(
                 node_data.authorization.config.api_key
             ).text
+            # Validate that API key is not empty after template conversion
+            if not node_data.authorization.config.api_key or not node_data.authorization.config.api_key.strip():
+                raise AuthorizationConfigError(
+                    "API key is required for authorization but was empty. Please provide a valid API key."
+                )

         self.url = node_data.url
         self.method = node_data.method

@@ -1,3 +1,4 @@
+import json
 from typing import Any

 from jsonschema import Draft7Validator, ValidationError

@@ -42,15 +43,25 @@ class StartNode(Node[StartNodeData]):
                 if value is None and variable.required:
                     raise ValueError(f"{key} is required in input form")

-                if not isinstance(value, dict):
-                    raise ValueError(f"{key} must be a JSON object")
-
                 schema = variable.json_schema
                 if not schema:
                     continue

+                if not value:
+                    continue
+
                 try:
-                    Draft7Validator(schema).validate(value)
+                    json_schema = json.loads(schema)
+                except json.JSONDecodeError as e:
+                    raise ValueError(f"{schema} must be a valid JSON object")
+
+                try:
+                    json_value = json.loads(value)
+                except json.JSONDecodeError as e:
+                    raise ValueError(f"{value} must be a valid JSON object")
+
+                try:
+                    Draft7Validator(json_schema).validate(json_value)
                 except ValidationError as e:
                     raise ValueError(f"JSON object for '{key}' does not match schema: {e.message}")
-                node_inputs[key] = value
+                node_inputs[key] = json_value

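The start-node hunk mirrors the `VariableEntity` change earlier: both the stored schema and the incoming `JSON_OBJECT` value now arrive as strings, so each must parse as JSON before the value is checked against the schema, and the parsed value (not the raw string) is what ends up in the node inputs. An end-to-end sketch of that flow:

```python
# Parse schema string, parse value string, then validate value against schema.
import json

from jsonschema import Draft7Validator, ValidationError


def validate_json_input(key: str, schema_str: str, value_str: str) -> dict:
    try:
        schema = json.loads(schema_str)
    except json.JSONDecodeError:
        raise ValueError(f"{schema_str} must be a valid JSON object")
    try:
        value = json.loads(value_str)
    except json.JSONDecodeError:
        raise ValueError(f"{value_str} must be a valid JSON object")
    try:
        Draft7Validator(schema).validate(value)
    except ValidationError as e:
        raise ValueError(f"JSON object for '{key}' does not match schema: {e.message}")
    return value


schema = '{"type": "object", "properties": {"n": {"type": "integer"}}}'
print(validate_json_input("config", schema, '{"n": 1}'))  # {'n': 1}
```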
@@ -34,10 +34,10 @@ if [[ "${MODE}" == "worker" ]]; then
   if [[ -z "${CELERY_QUEUES}" ]]; then
     if [[ "${EDITION}" == "CLOUD" ]]; then
       # Cloud edition: separate queues for dataset and trigger tasks
-      DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
+      DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention"
     else
       # Community edition (SELF_HOSTED): dataset, pipeline and workflow have separate queues
-      DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
+      DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention"
     fi
   else
     DEFAULT_QUEUES="${CELERY_QUEUES}"

@@ -69,6 +69,53 @@ if [[ "${MODE}" == "worker" ]]; then

 elif [[ "${MODE}" == "beat" ]]; then
   exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO}

+elif [[ "${MODE}" == "job" ]]; then
+  # Job mode: Run a one-time Flask command and exit
+  # Pass Flask command and arguments via container args
+  # Example K8s usage:
+  #   args:
+  #     - create-tenant
+  #     - --email
+  #     - admin@example.com
+  #
+  # Example Docker usage:
+  #   docker run -e MODE=job dify-api:latest create-tenant --email admin@example.com
+
+  if [[ $# -eq 0 ]]; then
+    echo "Error: No command specified for job mode."
+    echo ""
+    echo "Usage examples:"
+    echo "  Kubernetes:"
+    echo "    args: [create-tenant, --email, admin@example.com]"
+    echo ""
+    echo "  Docker:"
+    echo "    docker run -e MODE=job dify-api create-tenant --email admin@example.com"
+    echo ""
+    echo "Available commands:"
+    echo "  create-tenant, reset-password, reset-email, upgrade-db,"
+    echo "  vdb-migrate, install-plugins, and more..."
+    echo ""
+    echo "Run 'flask --help' to see all available commands."
+    exit 1
+  fi
+
+  echo "Running Flask job command: flask $*"
+
+  # Temporarily disable exit on error to capture exit code
+  set +e
+  flask "$@"
+  JOB_EXIT_CODE=$?
+  set -e
+
+  if [[ ${JOB_EXIT_CODE} -eq 0 ]]; then
+    echo "Job completed successfully."
+  else
+    echo "Job failed with exit code ${JOB_EXIT_CODE}."
+  fi
+
+  exit ${JOB_EXIT_CODE}
+
 else
   if [[ "${DEBUG}" == "true" ]]; then
     exec flask run --host=${DIFY_BIND_ADDRESS:-0.0.0.0} --port=${DIFY_PORT:-5001} --debug

@@ -87,15 +87,16 @@ class OpenDALStorage(BaseStorage):
         if not self.exists(path):
             raise FileNotFoundError("Path not found")

-        all_files = self.op.scan(path=path)
+        # Use the new OpenDAL 0.46.0+ API with recursive listing
+        lister = self.op.list(path, recursive=True)
         if files and directories:
             logger.debug("files and directories on %s scanned", path)
-            return [f.path for f in all_files]
+            return [entry.path for entry in lister]
         if files:
             logger.debug("files on %s scanned", path)
-            return [f.path for f in all_files if not f.path.endswith("/")]
+            return [entry.path for entry in lister if not entry.metadata.is_dir]
         elif directories:
             logger.debug("directories on %s scanned", path)
-            return [f.path for f in all_files if f.path.endswith("/")]
+            return [entry.path for entry in lister if entry.metadata.is_dir]
         else:
             raise ValueError("At least one of files or directories must be True")

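A hedged usage sketch of the new listing call. The `recursive=True` flag and the `entry.metadata.is_dir` attribute are taken from how the hunk above uses OpenDAL's 0.46+ Python binding; verify them against your installed version before relying on this:

```python
# Recursive listing with OpenDAL's Python binding (API as used in the diff).
import opendal

op = opendal.Operator("fs", root="/tmp")

# scan() walked the tree implicitly; list(..., recursive=True) is explicit,
# and entry metadata replaces the old trailing-slash heuristic for directories.
for entry in op.list("/", recursive=True):
    kind = "dir " if entry.metadata.is_dir else "file"
    print(kind, entry.path)
```

The metadata check is the more robust design: trailing-slash conventions vary by backend, while `is_dir` is reported by the service itself.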
@@ -11,6 +11,7 @@ from collections.abc import Generator, Mapping
 from datetime import datetime
 from hashlib import sha256
 from typing import TYPE_CHECKING, Annotated, Any, Optional, Union, cast
+from uuid import UUID
 from zoneinfo import available_timezones

 from flask import Response, stream_with_context

@@ -119,6 +120,19 @@ def uuid_value(value: Any) -> str:
         raise ValueError(error)


+def normalize_uuid(value: str | UUID) -> str:
+    if not value:
+        return ""
+
+    try:
+        return uuid_value(value)
+    except ValueError as exc:
+        raise ValueError("must be a valid UUID") from exc
+
+
+UUIDStrOrEmpty = Annotated[str, AfterValidator(normalize_uuid)]
+
+
 def alphanumeric(value: str):
     # check if the value is alphanumeric and underlined
     if re.match(r"^[a-zA-Z0-9_]+$", value):

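This is the `UUIDStrOrEmpty` type that the conversation controller earlier in the commit swaps in for a plain `UUID`: the `AfterValidator` runs on the already-string-validated value, so an empty string passes through (useful for "no cursor" pagination params) while malformed UUIDs still fail. A self-contained version using the standard library's UUID parser in place of Dify's `uuid_value` helper:

```python
# Annotated + AfterValidator: accept "" or a valid UUID, reject anything else.
import uuid
from typing import Annotated

from pydantic import AfterValidator, BaseModel, ValidationError


def normalize_uuid(value: str) -> str:
    if not value:
        return ""
    try:
        return str(uuid.UUID(value))
    except ValueError as exc:
        raise ValueError("must be a valid UUID") from exc


UUIDStrOrEmpty = Annotated[str, AfterValidator(normalize_uuid)]


class Query(BaseModel):
    last_id: UUIDStrOrEmpty | None = None


print(Query(last_id="").last_id)  # "" is accepted, unlike a plain UUID field
try:
    Query(last_id="not-a-uuid")
except ValidationError:
    print("rejected")
```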
@@ -12,7 +12,7 @@ dependencies = [
     "bs4~=0.0.1",
     "cachetools~=5.3.0",
     "celery~=5.5.2",
-    "chardet~=5.1.0",
+    "charset-normalizer>=3.4.4",
     "flask~=3.1.2",
     "flask-compress>=1.17,<1.18",
     "flask-cors~=6.0.0",

@@ -32,6 +32,7 @@ dependencies = [
     "httpx[socks]~=0.27.0",
     "jieba==0.42.1",
     "json-repair>=0.41.1",
+    "jsonschema>=4.25.1",
     "langfuse~=2.51.3",
     "langsmith~=0.1.77",
     "markdown~=3.5.1",

@@ -1,8 +1,12 @@
+import logging
 import os
+from collections.abc import Sequence
 from typing import Literal

 import httpx
+from pydantic import TypeAdapter
 from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed
+from typing_extensions import TypedDict
 from werkzeug.exceptions import InternalServerError

 from enums.cloud_plan import CloudPlan

@@ -11,6 +15,15 @@ from extensions.ext_redis import redis_client
 from libs.helper import RateLimiter
 from models import Account, TenantAccountJoin, TenantAccountRole

+logger = logging.getLogger(__name__)
+
+
+class SubscriptionPlan(TypedDict):
+    """Tenant subscription plan information."""
+
+    plan: str
+    expiration_date: int
+
+
 class BillingService:
     base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL")

@@ -239,3 +252,39 @@ class BillingService:
     def sync_partner_tenants_bindings(cls, account_id: str, partner_key: str, click_id: str):
         payload = {"account_id": account_id, "click_id": click_id}
         return cls._send_request("PUT", f"/partners/{partner_key}/tenants", json=payload)
+
+    @classmethod
+    def get_plan_bulk(cls, tenant_ids: Sequence[str]) -> dict[str, SubscriptionPlan]:
+        """
+        Bulk fetch billing subscription plan via billing API.
+        Payload: {"tenant_ids": ["t1", "t2", ...]} (max 200 per request)
+        Returns:
+            Mapping of tenant_id -> {plan: str, expiration_date: int}
+        """
+        results: dict[str, SubscriptionPlan] = {}
+        subscription_adapter = TypeAdapter(SubscriptionPlan)
+
+        chunk_size = 200
+        for i in range(0, len(tenant_ids), chunk_size):
+            chunk = tenant_ids[i : i + chunk_size]
+            try:
+                resp = cls._send_request("POST", "/subscription/plan/batch", json={"tenant_ids": chunk})
+                data = resp.get("data", {})
+
+                for tenant_id, plan in data.items():
+                    subscription_plan = subscription_adapter.validate_python(plan)
+                    results[tenant_id] = subscription_plan
+            except Exception:
+                logger.exception("Failed to fetch billing info batch for tenants: %s", chunk)
+                continue
+
+        return results
+
+    @classmethod
+    def get_expired_subscription_cleanup_whitelist(cls) -> Sequence[str]:
+        resp = cls._send_request("GET", "/subscription/cleanup/whitelist")
+        data = resp.get("data", [])
+        tenant_whitelist = []
+        for item in data:
+            tenant_whitelist.append(item["tenant_id"])
+        return tenant_whitelist

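The interesting part of `get_plan_bulk` is its failure isolation: tenant ids go out in batches of 200, and a failing batch is logged and skipped rather than aborting the whole sweep. That logic in isolation, with a stand-in fetch function instead of the billing HTTP client:

```python
# Chunked bulk fetch with per-batch failure isolation (fetch is a stand-in).
import logging
from collections.abc import Callable, Sequence

logger = logging.getLogger(__name__)


def fetch_in_chunks(
    tenant_ids: Sequence[str],
    fetch: Callable[[Sequence[str]], dict[str, dict]],
    chunk_size: int = 200,
) -> dict[str, dict]:
    results: dict[str, dict] = {}
    for i in range(0, len(tenant_ids), chunk_size):
        chunk = tenant_ids[i : i + chunk_size]
        try:
            results.update(fetch(chunk))
        except Exception:
            # One bad batch should not poison the other batches.
            logger.exception("Failed to fetch billing info batch for tenants: %s", chunk)
            continue
    return results


def fake_fetch(chunk: Sequence[str]) -> dict[str, dict]:
    return {t: {"plan": "sandbox", "expiration_date": 0} for t in chunk}


print(len(fetch_in_chunks([f"t{i}" for i in range(450)], fake_fetch)))  # 450
```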
@@ -23,7 +23,7 @@ class RagPipelineDatasetCreateEntity(BaseModel):
     description: str
     icon_info: IconInfo
     permission: str
-    partial_member_list: list[str] | None = None
+    partial_member_list: list[dict[str, str]] | None = None
     yaml_content: str | None = None

@ -15,7 +15,6 @@ from sqlalchemy.orm import Session
|
|||
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPProviderEntity
|
||||
from core.helper import encrypter
|
||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
from core.helper.tool_provider_cache import ToolProviderListCache
|
||||
from core.mcp.auth.auth_flow import auth
|
||||
from core.mcp.auth_client import MCPClientWithAuthRetry
|
||||
from core.mcp.error import MCPAuthError, MCPError
|
||||
|
|
@ -65,6 +64,15 @@ class ServerUrlValidationResult(BaseModel):
|
|||
return self.needs_validation and self.validation_passed and self.reconnect_result is not None
|
||||
|
||||
|
||||
class ProviderUrlValidationData(BaseModel):
|
||||
"""Data required for URL validation, extracted from database to perform network operations outside of session"""
|
||||
|
||||
current_server_url_hash: str
|
||||
headers: dict[str, str]
|
||||
timeout: float | None
|
||||
sse_read_timeout: float | None
|
||||
|
||||
|
||||
class MCPToolManageService:
|
||||
"""Service class for managing MCP tools and providers."""
|
||||
|
||||
|
|
@ -166,9 +174,6 @@ class MCPToolManageService:
|
|||
self._session.add(mcp_tool)
|
||||
self._session.flush()
|
||||
|
||||
# Invalidate tool providers cache
|
||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||
|
||||
mcp_providers = ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True)
|
||||
return mcp_providers
|
||||
|
||||
|
|
@ -192,7 +197,7 @@ class MCPToolManageService:
|
|||
Update an MCP provider.
|
||||
|
||||
Args:
|
||||
validation_result: Pre-validation result from validate_server_url_change.
|
||||
validation_result: Pre-validation result from validate_server_url_standalone.
|
||||
If provided and contains reconnect_result, it will be used
|
||||
instead of performing network operations.
|
||||
"""
|
||||
|
|
@ -251,8 +256,6 @@ class MCPToolManageService:
|
|||
# Flush changes to database
|
||||
self._session.flush()
|
||||
|
||||
# Invalidate tool providers cache
|
||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||
except IntegrityError as e:
|
||||
self._handle_integrity_error(e, name, server_url, server_identifier)
|
||||
|
||||
|
|
@ -261,9 +264,6 @@ class MCPToolManageService:
|
|||
mcp_tool = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
self._session.delete(mcp_tool)
|
||||
|
||||
# Invalidate tool providers cache
|
||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||
|
||||
def list_providers(
|
||||
self, *, tenant_id: str, for_list: bool = False, include_sensitive: bool = True
|
||||
) -> list[ToolProviderApiEntity]:
|
||||
|
|
@ -546,30 +546,39 @@ class MCPToolManageService:
        )
        return self.execute_auth_actions(auth_result)

    def _reconnect_provider(self, *, server_url: str, provider: MCPToolProvider) -> ReconnectResult:
        """Attempt to reconnect to MCP provider with new server URL."""
    def get_provider_for_url_validation(self, *, tenant_id: str, provider_id: str) -> ProviderUrlValidationData:
        """
        Get provider data required for URL validation.
        This method performs database read and should be called within a session.

        Returns:
            ProviderUrlValidationData: Data needed for standalone URL validation
        """
        provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
        provider_entity = provider.to_entity()
        headers = provider_entity.headers
        return ProviderUrlValidationData(
            current_server_url_hash=provider.server_url_hash,
            headers=provider_entity.headers,
            timeout=provider_entity.timeout,
            sse_read_timeout=provider_entity.sse_read_timeout,
        )

        try:
            tools = self._retrieve_remote_mcp_tools(server_url, headers, provider_entity)
            return ReconnectResult(
                authed=True,
                tools=json.dumps([tool.model_dump() for tool in tools]),
                encrypted_credentials=EMPTY_CREDENTIALS_JSON,
            )
        except MCPAuthError:
            return ReconnectResult(authed=False, tools=EMPTY_TOOLS_JSON, encrypted_credentials=EMPTY_CREDENTIALS_JSON)
        except MCPError as e:
            raise ValueError(f"Failed to re-connect MCP server: {e}") from e

    def validate_server_url_change(
        self, *, tenant_id: str, provider_id: str, new_server_url: str
    @staticmethod
    def validate_server_url_standalone(
        *,
        tenant_id: str,
        new_server_url: str,
        validation_data: ProviderUrlValidationData,
    ) -> ServerUrlValidationResult:
        """
        Validate server URL change by attempting to connect to the new server.
        This method should be called BEFORE update_provider to perform network operations
        outside of the database transaction.
        This method performs network operations and MUST be called OUTSIDE of any database session
        to avoid holding locks during network I/O.

        Args:
            tenant_id: Tenant ID for encryption
            new_server_url: The new server URL to validate
            validation_data: Provider data obtained from get_provider_for_url_validation

        Returns:
            ServerUrlValidationResult: Validation result with connection status and tools if successful

@ -579,25 +588,30 @@ class MCPToolManageService:
            return ServerUrlValidationResult(needs_validation=False)

        # Validate URL format
        if not self._is_valid_url(new_server_url):
        parsed = urlparse(new_server_url)
        if not all([parsed.scheme, parsed.netloc]) or parsed.scheme not in ["http", "https"]:
            raise ValueError("Server URL is not valid.")

        # Always encrypt and hash the URL
        encrypted_server_url = encrypter.encrypt_token(tenant_id, new_server_url)
        new_server_url_hash = hashlib.sha256(new_server_url.encode()).hexdigest()

        # Get current provider
        provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)

        # Check if URL is actually different
        if new_server_url_hash == provider.server_url_hash:
        if new_server_url_hash == validation_data.current_server_url_hash:
            # URL hasn't changed, but still return the encrypted data
            return ServerUrlValidationResult(
                needs_validation=False, encrypted_server_url=encrypted_server_url, server_url_hash=new_server_url_hash
                needs_validation=False,
                encrypted_server_url=encrypted_server_url,
                server_url_hash=new_server_url_hash,
            )

        # Perform validation by attempting to connect
        reconnect_result = self._reconnect_provider(server_url=new_server_url, provider=provider)
        # Perform network validation - this is the expensive operation that should be outside session
        reconnect_result = MCPToolManageService._reconnect_with_url(
            server_url=new_server_url,
            headers=validation_data.headers,
            timeout=validation_data.timeout,
            sse_read_timeout=validation_data.sse_read_timeout,
        )
        return ServerUrlValidationResult(
            needs_validation=True,
            validation_passed=True,

@ -606,6 +620,38 @@ class MCPToolManageService:
            server_url_hash=new_server_url_hash,
        )

    @staticmethod
    def _reconnect_with_url(
        *,
        server_url: str,
        headers: dict[str, str],
        timeout: float | None,
        sse_read_timeout: float | None,
    ) -> ReconnectResult:
        """
        Attempt to connect to MCP server with given URL.
        This is a static method that performs network I/O without database access.
        """
        from core.mcp.mcp_client import MCPClient

        try:
            with MCPClient(
                server_url=server_url,
                headers=headers,
                timeout=timeout,
                sse_read_timeout=sse_read_timeout,
            ) as mcp_client:
                tools = mcp_client.list_tools()
            return ReconnectResult(
                authed=True,
                tools=json.dumps([tool.model_dump() for tool in tools]),
                encrypted_credentials=EMPTY_CREDENTIALS_JSON,
            )
        except MCPAuthError:
            return ReconnectResult(authed=False, tools=EMPTY_TOOLS_JSON, encrypted_credentials=EMPTY_CREDENTIALS_JSON)
        except MCPError as e:
            raise ValueError(f"Failed to re-connect MCP server: {e}") from e

    def _build_tool_provider_response(
        self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list
    ) -> ToolProviderApiEntity:

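The refactor above splits a server URL change into a database read, a standalone network check, and a transactional write. A minimal sketch of the intended call order, assuming SQLAlchemy session wiring and an update call whose exact signature this diff does not show:

```python
from sqlalchemy.orm import Session

def change_server_url(engine, *, tenant_id: str, provider_id: str, new_url: str):
    # 1. Read: fetch validation data inside a short-lived session (DB only).
    with Session(engine) as session:
        data = MCPToolManageService(session).get_provider_for_url_validation(
            tenant_id=tenant_id, provider_id=provider_id
        )

    # 2. Validate: network I/O with no session (and no row locks) held open.
    result = MCPToolManageService.validate_server_url_standalone(
        tenant_id=tenant_id, new_server_url=new_url, validation_data=data
    )

    # 3. Write: apply the pre-validated result in a fresh transaction.
    #    The keyword arguments other than validation_result are assumptions.
    with Session(engine) as session:
        service = MCPToolManageService(session)
        service.update_provider(
            tenant_id=tenant_id, provider_id=provider_id, validation_result=result
        )
        session.commit()
```
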
@ -2,7 +2,6 @@ import logging
import time

import click
import sqlalchemy as sa
from celery import shared_task
from sqlalchemy import select

@ -12,7 +11,7 @@ from core.rag.index_processor.index_processor_factory import IndexProcessorFacto
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
from models.source import DataSourceOauthBinding
from services.datasource_provider_service import DatasourceProviderService

logger = logging.getLogger(__name__)

@ -48,27 +47,36 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
        page_id = data_source_info["notion_page_id"]
        page_type = data_source_info["type"]
        page_edited_time = data_source_info["last_edited_time"]
        credential_id = data_source_info.get("credential_id")

        data_source_binding = (
            db.session.query(DataSourceOauthBinding)
            .where(
                sa.and_(
                    DataSourceOauthBinding.tenant_id == document.tenant_id,
                    DataSourceOauthBinding.provider == "notion",
                    DataSourceOauthBinding.disabled == False,
                    DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
                )
            )
            .first()
        # Get credentials from datasource provider
        datasource_provider_service = DatasourceProviderService()
        credential = datasource_provider_service.get_datasource_credentials(
            tenant_id=document.tenant_id,
            credential_id=credential_id,
            provider="notion_datasource",
            plugin_id="langgenius/notion_datasource",
        )
        if not data_source_binding:
            raise ValueError("Data source binding not found.")

        if not credential:
            logger.error(
                "Datasource credential not found for document %s, tenant_id: %s, credential_id: %s",
                document_id,
                document.tenant_id,
                credential_id,
            )
            document.indexing_status = "error"
            document.error = "Datasource credential not found. Please reconnect your Notion workspace."
            document.stopped_at = naive_utc_now()
            db.session.commit()
            db.session.close()
            return

        loader = NotionExtractor(
            notion_workspace_id=workspace_id,
            notion_obj_id=page_id,
            notion_page_type=page_type,
            notion_access_token=data_source_binding.access_token,
            notion_access_token=credential.get("integration_secret"),
            tenant_id=document.tenant_id,
        )

@ -6,6 +6,7 @@ import pytest

from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import GraphInitParams
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.nodes.http_request.node import HttpRequestNode
from core.workflow.nodes.node_factory import DifyNodeFactory

@ -169,13 +170,14 @@ def test_custom_authorization_header(setup_http_mock):


@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_custom_auth_with_empty_api_key_does_not_set_header(setup_http_mock):
    """Test: In custom authentication mode, when the api_key is empty, no header should be set."""
def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock):
    """Test: In custom authentication mode, when the api_key is empty, AuthorizationConfigError should be raised."""
    from core.workflow.nodes.http_request.entities import (
        HttpRequestNodeAuthorization,
        HttpRequestNodeData,
        HttpRequestNodeTimeout,
    )
    from core.workflow.nodes.http_request.exc import AuthorizationConfigError
    from core.workflow.nodes.http_request.executor import Executor
    from core.workflow.runtime import VariablePool
    from core.workflow.system_variable import SystemVariable

@ -208,16 +210,13 @@ def test_custom_auth_with_empty_api_key_does_not_set_header(setup_http_mock):
        ssl_verify=True,
    )

    # Create executor
    executor = Executor(
        node_data=node_data, timeout=HttpRequestNodeTimeout(connect=10, read=30, write=10), variable_pool=variable_pool
    )

    # Get assembled headers
    headers = executor._assembling_headers()

    # When api_key is empty, the custom header should NOT be set
    assert "X-Custom-Auth" not in headers
    # Create executor should raise AuthorizationConfigError
    with pytest.raises(AuthorizationConfigError, match="API key is required"):
        Executor(
            node_data=node_data,
            timeout=HttpRequestNodeTimeout(connect=10, read=30, write=10),
            variable_pool=variable_pool,
        )


@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)

@ -305,9 +304,10 @@ def test_basic_authorization_with_custom_header_ignored(setup_http_mock):
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_custom_authorization_with_empty_api_key(setup_http_mock):
    """
    Test that custom authorization doesn't set header when api_key is empty.
    This test verifies the fix for issue #23554.
    Test that custom authorization raises error when api_key is empty.
    This test verifies the fix for issue #21830.
    """

    node = init_http_node(
        config={
            "id": "1",

@ -333,11 +333,10 @@ def test_custom_authorization_with_empty_api_key(setup_http_mock):
    )

    result = node._run()
    assert result.process_data is not None
    data = result.process_data.get("request", "")

    # Custom header should NOT be set when api_key is empty
    assert "X-Custom-Auth:" not in data
    # Should fail with AuthorizationConfigError
    assert result.status == WorkflowNodeExecutionStatus.FAILED
    assert "API key is required" in result.error
    assert result.error_type == "AuthorizationConfigError"


@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)

@ -2,7 +2,9 @@ from unittest.mock import patch

import pytest
from faker import Faker
from pydantic import TypeAdapter, ValidationError

from core.tools.entities.tool_entities import ApiProviderSchemaType
from models import Account, Tenant
from models.tools import ApiToolProvider
from services.tools.api_tools_manage_service import ApiToolManageService

@ -298,7 +300,7 @@ class TestApiToolManageService:
        provider_name = fake.company()
        icon = {"type": "emoji", "value": "🔧"}
        credentials = {"auth_type": "none", "api_key_header": "X-API-Key", "api_key_value": ""}
        schema_type = "openapi"
        schema_type = ApiProviderSchemaType.OPENAPI
        schema = self._create_test_openapi_schema()
        privacy_policy = "https://example.com/privacy"
        custom_disclaimer = "Custom disclaimer text"

@ -364,7 +366,7 @@ class TestApiToolManageService:
        provider_name = fake.company()
        icon = {"type": "emoji", "value": "🔧"}
        credentials = {"auth_type": "none"}
        schema_type = "openapi"
        schema_type = ApiProviderSchemaType.OPENAPI
        schema = self._create_test_openapi_schema()
        privacy_policy = "https://example.com/privacy"
        custom_disclaimer = "Custom disclaimer text"

@ -428,21 +430,10 @@ class TestApiToolManageService:
        labels = ["test"]

        # Act & Assert: Try to create provider with invalid schema type
        with pytest.raises(ValueError) as exc_info:
            ApiToolManageService.create_api_tool_provider(
                user_id=account.id,
                tenant_id=tenant.id,
                provider_name=provider_name,
                icon=icon,
                credentials=credentials,
                schema_type=schema_type,
                schema=schema,
                privacy_policy=privacy_policy,
                custom_disclaimer=custom_disclaimer,
                labels=labels,
            )
        with pytest.raises(ValidationError) as exc_info:
            TypeAdapter(ApiProviderSchemaType).validate_python(schema_type)

        assert "invalid schema type" in str(exc_info.value)
        assert "validation error" in str(exc_info.value)

    def test_create_api_tool_provider_missing_auth_type(
        self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies

@ -464,7 +455,7 @@ class TestApiToolManageService:
        provider_name = fake.company()
        icon = {"type": "emoji", "value": "🔧"}
        credentials = {}  # Missing auth_type
        schema_type = "openapi"
        schema_type = ApiProviderSchemaType.OPENAPI
        schema = self._create_test_openapi_schema()
        privacy_policy = "https://example.com/privacy"
        custom_disclaimer = "Custom disclaimer text"

@ -507,7 +498,7 @@ class TestApiToolManageService:
        provider_name = fake.company()
        icon = {"type": "emoji", "value": "🔑"}
        credentials = {"auth_type": "api_key", "api_key_header": "X-API-Key", "api_key_value": fake.uuid4()}
        schema_type = "openapi"
        schema_type = ApiProviderSchemaType.OPENAPI
        schema = self._create_test_openapi_schema()
        privacy_policy = "https://example.com/privacy"
        custom_disclaimer = "Custom disclaimer text"

@ -1308,18 +1308,17 @@ class TestMCPToolManageService:
            type("MockTool", (), {"model_dump": lambda self: {"name": "test_tool_2", "description": "Test tool 2"}})(),
        ]

        with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
        with patch("core.mcp.mcp_client.MCPClient") as mock_mcp_client:
            # Setup mock client
            mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
            mock_client_instance.list_tools.return_value = mock_tools

            # Act: Execute the method under test
            from extensions.ext_database import db

            service = MCPToolManageService(db.session())
            result = service._reconnect_provider(
            result = MCPToolManageService._reconnect_with_url(
                server_url="https://example.com/mcp",
                provider=mcp_provider,
                headers={"X-Test": "1"},
                timeout=mcp_provider.timeout,
                sse_read_timeout=mcp_provider.sse_read_timeout,
            )

            # Assert: Verify the expected outcomes

@ -1337,8 +1336,12 @@ class TestMCPToolManageService:
            assert tools_data[1]["name"] == "test_tool_2"

            # Verify mock interactions
            provider_entity = mcp_provider.to_entity()
            mock_mcp_client.assert_called_once()
            mock_mcp_client.assert_called_once_with(
                server_url="https://example.com/mcp",
                headers={"X-Test": "1"},
                timeout=mcp_provider.timeout,
                sse_read_timeout=mcp_provider.sse_read_timeout,
            )

    def test_re_connect_mcp_provider_auth_error(self, db_session_with_containers, mock_external_service_dependencies):
        """

@ -1361,19 +1364,18 @@ class TestMCPToolManageService:
        )

        # Mock MCPClient to raise authentication error
        with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
        with patch("core.mcp.mcp_client.MCPClient") as mock_mcp_client:
            from core.mcp.error import MCPAuthError

            mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
            mock_client_instance.list_tools.side_effect = MCPAuthError("Authentication required")

            # Act: Execute the method under test
            from extensions.ext_database import db

            service = MCPToolManageService(db.session())
            result = service._reconnect_provider(
            result = MCPToolManageService._reconnect_with_url(
                server_url="https://example.com/mcp",
                provider=mcp_provider,
                headers={},
                timeout=mcp_provider.timeout,
                sse_read_timeout=mcp_provider.sse_read_timeout,
            )

            # Assert: Verify the expected outcomes

@ -1404,18 +1406,17 @@ class TestMCPToolManageService:
        )

        # Mock MCPClient to raise connection error
        with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
        with patch("core.mcp.mcp_client.MCPClient") as mock_mcp_client:
            from core.mcp.error import MCPError

            mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
            mock_client_instance.list_tools.side_effect = MCPError("Connection failed")

            # Act & Assert: Verify proper error handling
            from extensions.ext_database import db

            service = MCPToolManageService(db.session())
            with pytest.raises(ValueError, match="Failed to re-connect MCP server: Connection failed"):
                service._reconnect_provider(
                MCPToolManageService._reconnect_with_url(
                    server_url="https://example.com/mcp",
                    provider=mcp_provider,
                    headers={"X-Test": "1"},
                    timeout=mcp_provider.timeout,
                    sse_read_timeout=mcp_provider.sse_read_timeout,
                )

@ -0,0 +1,420 @@
from types import SimpleNamespace
from unittest.mock import ANY, Mock, patch

import pytest

from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity
from core.app.entities.queue_entities import (
    QueueAgentMessageEvent,
    QueueErrorEvent,
    QueueLLMChunkEvent,
    QueueMessageEndEvent,
    QueueMessageFileEvent,
    QueuePingEvent,
)
from core.app.entities.task_entities import (
    EasyUITaskState,
    ErrorStreamResponse,
    MessageEndStreamResponse,
    MessageFileStreamResponse,
    MessageReplaceStreamResponse,
    MessageStreamResponse,
    PingStreamResponse,
    StreamEvent,
)
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
from core.base.tts import AppGeneratorTTSPublisher
from core.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult
from core.model_runtime.entities.message_entities import TextPromptMessageContent
from core.ops.ops_trace_manager import TraceQueueManager
from models.model import AppMode


class TestEasyUIBasedGenerateTaskPipelineProcessStreamResponse:
    """Test cases for EasyUIBasedGenerateTaskPipeline._process_stream_response method."""

    @pytest.fixture
    def mock_application_generate_entity(self):
        """Create a mock application generate entity."""
        entity = Mock(spec=ChatAppGenerateEntity)
        entity.task_id = "test-task-id"
        entity.app_id = "test-app-id"
        # minimal app_config used by pipeline internals
        entity.app_config = SimpleNamespace(
            tenant_id="test-tenant-id",
            app_id="test-app-id",
            app_mode=AppMode.CHAT,
            app_model_config_dict={},
            additional_features=None,
            sensitive_word_avoidance=None,
        )
        # minimal model_conf for LLMResult init
        entity.model_conf = SimpleNamespace(
            model="test-model",
            provider_model_bundle=SimpleNamespace(model_type_instance=Mock()),
            credentials={},
        )
        return entity

    @pytest.fixture
    def mock_queue_manager(self):
        """Create a mock queue manager."""
        manager = Mock(spec=AppQueueManager)
        return manager

    @pytest.fixture
    def mock_message_cycle_manager(self):
        """Create a mock message cycle manager."""
        manager = Mock()
        manager.get_message_event_type.return_value = StreamEvent.MESSAGE
        manager.message_to_stream_response.return_value = Mock(spec=MessageStreamResponse)
        manager.message_file_to_stream_response.return_value = Mock(spec=MessageFileStreamResponse)
        manager.message_replace_to_stream_response.return_value = Mock(spec=MessageReplaceStreamResponse)
        manager.handle_retriever_resources = Mock()
        manager.handle_annotation_reply.return_value = None
        return manager

    @pytest.fixture
    def mock_conversation(self):
        """Create a mock conversation."""
        conversation = Mock()
        conversation.id = "test-conversation-id"
        conversation.mode = "chat"
        return conversation

    @pytest.fixture
    def mock_message(self):
        """Create a mock message."""
        message = Mock()
        message.id = "test-message-id"
        message.created_at = Mock()
        message.created_at.timestamp.return_value = 1234567890
        return message

    @pytest.fixture
    def mock_task_state(self):
        """Create a mock task state."""
        task_state = Mock(spec=EasyUITaskState)

        # Create LLM result mock
        llm_result = Mock(spec=RuntimeLLMResult)
        llm_result.prompt_messages = []
        llm_result.message = Mock()
        llm_result.message.content = ""

        task_state.llm_result = llm_result
        task_state.answer = ""

        return task_state

    @pytest.fixture
    def pipeline(
        self,
        mock_application_generate_entity,
        mock_queue_manager,
        mock_conversation,
        mock_message,
        mock_message_cycle_manager,
        mock_task_state,
    ):
        """Create an EasyUIBasedGenerateTaskPipeline instance with mocked dependencies."""
        with patch(
            "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.EasyUITaskState", return_value=mock_task_state
        ):
            pipeline = EasyUIBasedGenerateTaskPipeline(
                application_generate_entity=mock_application_generate_entity,
                queue_manager=mock_queue_manager,
                conversation=mock_conversation,
                message=mock_message,
                stream=True,
            )
        pipeline._message_cycle_manager = mock_message_cycle_manager
        pipeline._task_state = mock_task_state
        return pipeline

    def test_get_message_event_type_called_once_when_first_llm_chunk_arrives(
        self, pipeline, mock_message_cycle_manager
    ):
        """Expect get_message_event_type to be called when processing the first LLM chunk event."""
        # Setup a minimal LLM chunk event
        chunk = Mock()
        chunk.delta.message.content = "hi"
        chunk.prompt_messages = []
        llm_chunk_event = Mock(spec=QueueLLMChunkEvent)
        llm_chunk_event.chunk = chunk
        mock_queue_message = Mock()
        mock_queue_message.event = llm_chunk_event
        pipeline.queue_manager.listen.return_value = [mock_queue_message]

        # Execute
        list(pipeline._process_stream_response(publisher=None, trace_manager=None))

        # Assert
        mock_message_cycle_manager.get_message_event_type.assert_called_once_with(message_id="test-message-id")

    def test_llm_chunk_event_with_text_content(self, pipeline, mock_message_cycle_manager, mock_task_state):
        """Test handling of LLM chunk events with text content."""
        # Setup
        chunk = Mock()
        chunk.delta.message.content = "Hello, world!"
        chunk.prompt_messages = []

        llm_chunk_event = Mock(spec=QueueLLMChunkEvent)
        llm_chunk_event.chunk = chunk

        mock_queue_message = Mock()
        mock_queue_message.event = llm_chunk_event
        pipeline.queue_manager.listen.return_value = [mock_queue_message]

        mock_message_cycle_manager.get_message_event_type.return_value = StreamEvent.MESSAGE

        # Execute
        responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))

        # Assert
        assert len(responses) == 1
        mock_message_cycle_manager.message_to_stream_response.assert_called_once_with(
            answer="Hello, world!", message_id="test-message-id", event_type=StreamEvent.MESSAGE
        )
        assert mock_task_state.llm_result.message.content == "Hello, world!"

    def test_llm_chunk_event_with_list_content(self, pipeline, mock_message_cycle_manager, mock_task_state):
        """Test handling of LLM chunk events with list content."""
        # Setup
        text_content = Mock(spec=TextPromptMessageContent)
        text_content.data = "Hello"

        chunk = Mock()
        chunk.delta.message.content = [text_content, " world!"]
        chunk.prompt_messages = []

        llm_chunk_event = Mock(spec=QueueLLMChunkEvent)
        llm_chunk_event.chunk = chunk

        mock_queue_message = Mock()
        mock_queue_message.event = llm_chunk_event
        pipeline.queue_manager.listen.return_value = [mock_queue_message]

        mock_message_cycle_manager.get_message_event_type.return_value = StreamEvent.MESSAGE

        # Execute
        responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))

        # Assert
        assert len(responses) == 1
        mock_message_cycle_manager.message_to_stream_response.assert_called_once_with(
            answer="Hello world!", message_id="test-message-id", event_type=StreamEvent.MESSAGE
        )
        assert mock_task_state.llm_result.message.content == "Hello world!"

    def test_agent_message_event(self, pipeline, mock_message_cycle_manager, mock_task_state):
        """Test handling of agent message events."""
        # Setup
        chunk = Mock()
        chunk.delta.message.content = "Agent response"

        agent_message_event = Mock(spec=QueueAgentMessageEvent)
        agent_message_event.chunk = chunk

        mock_queue_message = Mock()
        mock_queue_message.event = agent_message_event
        pipeline.queue_manager.listen.return_value = [mock_queue_message]

        # Ensure method under assertion is a mock to track calls
        pipeline._agent_message_to_stream_response = Mock(return_value=Mock())

        # Execute
        responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))

        # Assert
        assert len(responses) == 1
        # Agent messages should use _agent_message_to_stream_response
        pipeline._agent_message_to_stream_response.assert_called_once_with(
            answer="Agent response", message_id="test-message-id"
        )

    def test_message_end_event(self, pipeline, mock_message_cycle_manager, mock_task_state):
        """Test handling of message end events."""
        # Setup
        llm_result = Mock(spec=RuntimeLLMResult)
        llm_result.message = Mock()
        llm_result.message.content = "Final response"

        message_end_event = Mock(spec=QueueMessageEndEvent)
        message_end_event.llm_result = llm_result

        mock_queue_message = Mock()
        mock_queue_message.event = message_end_event
        pipeline.queue_manager.listen.return_value = [mock_queue_message]

        pipeline._save_message = Mock()
        pipeline._message_end_to_stream_response = Mock(return_value=Mock(spec=MessageEndStreamResponse))

        # Patch db.engine used inside pipeline for session creation
        with patch(
            "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", new=SimpleNamespace(engine=Mock())
        ):
            # Execute
            responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))

        # Assert
        assert len(responses) == 1
        assert mock_task_state.llm_result == llm_result
        pipeline._save_message.assert_called_once()
        pipeline._message_end_to_stream_response.assert_called_once()

    def test_error_event(self, pipeline):
        """Test handling of error events."""
        # Setup
        error_event = Mock(spec=QueueErrorEvent)
        error_event.error = Exception("Test error")

        mock_queue_message = Mock()
        mock_queue_message.event = error_event
        pipeline.queue_manager.listen.return_value = [mock_queue_message]

        pipeline.handle_error = Mock(return_value=Exception("Test error"))
        pipeline.error_to_stream_response = Mock(return_value=Mock(spec=ErrorStreamResponse))

        # Patch db.engine used inside pipeline for session creation
        with patch(
            "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", new=SimpleNamespace(engine=Mock())
        ):
            # Execute
            responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))

        # Assert
        assert len(responses) == 1
        pipeline.handle_error.assert_called_once()
        pipeline.error_to_stream_response.assert_called_once()

    def test_ping_event(self, pipeline):
        """Test handling of ping events."""
        # Setup
        ping_event = Mock(spec=QueuePingEvent)

        mock_queue_message = Mock()
        mock_queue_message.event = ping_event
        pipeline.queue_manager.listen.return_value = [mock_queue_message]

        pipeline.ping_stream_response = Mock(return_value=Mock(spec=PingStreamResponse))

        # Execute
        responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))

        # Assert
        assert len(responses) == 1
        pipeline.ping_stream_response.assert_called_once()

    def test_file_event(self, pipeline, mock_message_cycle_manager):
        """Test handling of file events."""
        # Setup
        file_event = Mock(spec=QueueMessageFileEvent)
        file_event.message_file_id = "file-id"

        mock_queue_message = Mock()
        mock_queue_message.event = file_event
        pipeline.queue_manager.listen.return_value = [mock_queue_message]

        file_response = Mock(spec=MessageFileStreamResponse)
        mock_message_cycle_manager.message_file_to_stream_response.return_value = file_response

        # Execute
        responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))

        # Assert
        assert len(responses) == 1
        assert responses[0] == file_response
        mock_message_cycle_manager.message_file_to_stream_response.assert_called_once_with(file_event)

    def test_publisher_is_called_with_messages(self, pipeline):
        """Test that publisher publishes messages when provided."""
        # Setup
        publisher = Mock(spec=AppGeneratorTTSPublisher)

        ping_event = Mock(spec=QueuePingEvent)
        mock_queue_message = Mock()
        mock_queue_message.event = ping_event
        pipeline.queue_manager.listen.return_value = [mock_queue_message]

        pipeline.ping_stream_response = Mock(return_value=Mock(spec=PingStreamResponse))

        # Execute
        list(pipeline._process_stream_response(publisher=publisher, trace_manager=None))

        # Assert
        # Called once with message and once with None at the end
        assert publisher.publish.call_count == 2
        publisher.publish.assert_any_call(mock_queue_message)
        publisher.publish.assert_any_call(None)

    def test_trace_manager_passed_to_save_message(self, pipeline):
        """Test that trace manager is passed to _save_message."""
        # Setup
        trace_manager = Mock(spec=TraceQueueManager)

        message_end_event = Mock(spec=QueueMessageEndEvent)
        message_end_event.llm_result = None

        mock_queue_message = Mock()
        mock_queue_message.event = message_end_event
        pipeline.queue_manager.listen.return_value = [mock_queue_message]

        pipeline._save_message = Mock()
        pipeline._message_end_to_stream_response = Mock(return_value=Mock(spec=MessageEndStreamResponse))

        # Patch db.engine used inside pipeline for session creation
        with patch(
            "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", new=SimpleNamespace(engine=Mock())
        ):
            # Execute
            list(pipeline._process_stream_response(publisher=None, trace_manager=trace_manager))

        # Assert
        pipeline._save_message.assert_called_once_with(session=ANY, trace_manager=trace_manager)

    def test_multiple_events_sequence(self, pipeline, mock_message_cycle_manager, mock_task_state):
        """Test handling multiple events in sequence."""
        # Setup
        chunk1 = Mock()
        chunk1.delta.message.content = "Hello"
        chunk1.prompt_messages = []

        chunk2 = Mock()
        chunk2.delta.message.content = " world!"
        chunk2.prompt_messages = []

        llm_chunk_event1 = Mock(spec=QueueLLMChunkEvent)
        llm_chunk_event1.chunk = chunk1

        ping_event = Mock(spec=QueuePingEvent)

        llm_chunk_event2 = Mock(spec=QueueLLMChunkEvent)
        llm_chunk_event2.chunk = chunk2

        mock_queue_messages = [
            Mock(event=llm_chunk_event1),
            Mock(event=ping_event),
            Mock(event=llm_chunk_event2),
        ]
        pipeline.queue_manager.listen.return_value = mock_queue_messages

        mock_message_cycle_manager.get_message_event_type.return_value = StreamEvent.MESSAGE
        pipeline.ping_stream_response = Mock(return_value=Mock(spec=PingStreamResponse))

        # Execute
        responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))

        # Assert
        assert len(responses) == 3
        assert mock_task_state.llm_result.message.content == "Hello world!"

        # Verify calls to message_to_stream_response
        assert mock_message_cycle_manager.message_to_stream_response.call_count == 2
        mock_message_cycle_manager.message_to_stream_response.assert_any_call(
            answer="Hello", message_id="test-message-id", event_type=StreamEvent.MESSAGE
        )
        mock_message_cycle_manager.message_to_stream_response.assert_any_call(
            answer=" world!", message_id="test-message-id", event_type=StreamEvent.MESSAGE
        )

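`test_publisher_is_called_with_messages` pins down the TTS publisher contract: every queue message is forwarded unchanged, and a final `None` marks end of stream. Reduced to a sketch (not the pipeline source; the function name is an assumption):

```python
# Publishing behavior asserted by the test above.
def forward_to_publisher(publisher, queue_messages):
    for message in queue_messages:
        if publisher:
            publisher.publish(message)  # each queue message forwarded as-is
    if publisher:
        publisher.publish(None)  # end-of-stream sentinel
```
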
@ -0,0 +1,166 @@
"""Unit tests for the message cycle manager optimization."""

from types import SimpleNamespace
from unittest.mock import ANY, Mock, patch

import pytest
from flask import current_app

from core.app.entities.task_entities import MessageStreamResponse, StreamEvent
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager


class TestMessageCycleManagerOptimization:
    """Test cases for the message cycle manager optimization that prevents N+1 queries."""

    @pytest.fixture
    def mock_application_generate_entity(self):
        """Create a mock application generate entity."""
        entity = Mock()
        entity.task_id = "test-task-id"
        return entity

    @pytest.fixture
    def message_cycle_manager(self, mock_application_generate_entity):
        """Create a message cycle manager instance."""
        task_state = Mock()
        return MessageCycleManager(application_generate_entity=mock_application_generate_entity, task_state=task_state)

    def test_get_message_event_type_with_message_file(self, message_cycle_manager):
        """Test get_message_event_type returns MESSAGE_FILE when message has files."""
        with (
            patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
            patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
        ):
            # Setup mock session and message file
            mock_session = Mock()
            mock_session_class.return_value.__enter__.return_value = mock_session

            mock_message_file = Mock()
            # Current implementation uses session.query(...).scalar()
            mock_session.query.return_value.scalar.return_value = mock_message_file

            # Execute
            with current_app.app_context():
                result = message_cycle_manager.get_message_event_type("test-message-id")

            # Assert
            assert result == StreamEvent.MESSAGE_FILE
            mock_session.query.return_value.scalar.assert_called_once()

    def test_get_message_event_type_without_message_file(self, message_cycle_manager):
        """Test get_message_event_type returns MESSAGE when message has no files."""
        with (
            patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
            patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
        ):
            # Setup mock session and no message file
            mock_session = Mock()
            mock_session_class.return_value.__enter__.return_value = mock_session
            # Current implementation uses session.query(...).scalar()
            mock_session.query.return_value.scalar.return_value = None

            # Execute
            with current_app.app_context():
                result = message_cycle_manager.get_message_event_type("test-message-id")

            # Assert
            assert result == StreamEvent.MESSAGE
            mock_session.query.return_value.scalar.assert_called_once()

    def test_message_to_stream_response_with_precomputed_event_type(self, message_cycle_manager):
        """MessageCycleManager.message_to_stream_response expects a valid event_type; callers should precompute it."""
        with (
            patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
            patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
        ):
            # Setup mock session and message file
            mock_session = Mock()
            mock_session_class.return_value.__enter__.return_value = mock_session

            mock_message_file = Mock()
            # Current implementation uses session.query(...).scalar()
            mock_session.query.return_value.scalar.return_value = mock_message_file

            # Execute: compute event type once, then pass to message_to_stream_response
            with current_app.app_context():
                event_type = message_cycle_manager.get_message_event_type("test-message-id")
                result = message_cycle_manager.message_to_stream_response(
                    answer="Hello world", message_id="test-message-id", event_type=event_type
                )

            # Assert
            assert isinstance(result, MessageStreamResponse)
            assert result.answer == "Hello world"
            assert result.id == "test-message-id"
            assert result.event == StreamEvent.MESSAGE_FILE
            mock_session.query.return_value.scalar.assert_called_once()

    def test_message_to_stream_response_with_event_type_skips_query(self, message_cycle_manager):
        """Test that message_to_stream_response skips database query when event_type is provided."""
        with patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class:
            # Execute with event_type provided
            result = message_cycle_manager.message_to_stream_response(
                answer="Hello world", message_id="test-message-id", event_type=StreamEvent.MESSAGE
            )

            # Assert
            assert isinstance(result, MessageStreamResponse)
            assert result.answer == "Hello world"
            assert result.id == "test-message-id"
            assert result.event == StreamEvent.MESSAGE
            # Should not query database when event_type is provided
            mock_session_class.assert_not_called()

    def test_message_to_stream_response_with_from_variable_selector(self, message_cycle_manager):
        """Test message_to_stream_response with from_variable_selector parameter."""
        result = message_cycle_manager.message_to_stream_response(
            answer="Hello world",
            message_id="test-message-id",
            from_variable_selector=["var1", "var2"],
            event_type=StreamEvent.MESSAGE,
        )

        assert isinstance(result, MessageStreamResponse)
        assert result.answer == "Hello world"
        assert result.id == "test-message-id"
        assert result.from_variable_selector == ["var1", "var2"]
        assert result.event == StreamEvent.MESSAGE

    def test_optimization_usage_example(self, message_cycle_manager):
        """Test the optimization pattern that should be used by callers."""
        # Step 1: Get event type once (this queries database)
        with (
            patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
            patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
        ):
            mock_session = Mock()
            mock_session_class.return_value.__enter__.return_value = mock_session
            # Current implementation uses session.query(...).scalar()
            mock_session.query.return_value.scalar.return_value = None  # No files
            with current_app.app_context():
                event_type = message_cycle_manager.get_message_event_type("test-message-id")

            # Should query database once
            mock_session_class.assert_called_once_with(ANY, expire_on_commit=False)
            assert event_type == StreamEvent.MESSAGE

        # Step 2: Use event_type for multiple calls (no additional queries)
        with patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class:
            mock_session_class.return_value.__enter__.return_value = Mock()

            chunk1_response = message_cycle_manager.message_to_stream_response(
                answer="Chunk 1", message_id="test-message-id", event_type=event_type
            )

            chunk2_response = message_cycle_manager.message_to_stream_response(
                answer="Chunk 2", message_id="test-message-id", event_type=event_type
            )

            # Should not query database again
            mock_session_class.assert_not_called()

            assert chunk1_response.event == StreamEvent.MESSAGE
            assert chunk2_response.event == StreamEvent.MESSAGE
            assert chunk1_response.answer == "Chunk 1"
            assert chunk2_response.answer == "Chunk 2"

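The caller-side shape of the optimization these tests encode, with the loop and chunk source as placeholders (the two manager methods are the ones under test):

```python
# One query up front, zero queries per chunk.
def stream_chunks(manager, message_id, chunks):
    event_type = manager.get_message_event_type(message_id)  # single DB lookup
    for chunk in chunks:
        yield manager.message_to_stream_response(
            answer=chunk, message_id=message_id, event_type=event_type
        )
```
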
@ -132,3 +132,36 @@ def test_extract_images_from_docx(monkeypatch):
    # DB interactions should be recorded
    assert len(db_stub.session.added) == 2
    assert db_stub.session.committed is True


def test_extract_images_from_docx_uses_internal_files_url():
    """Test that INTERNAL_FILES_URL takes precedence over FILES_URL for plugin access."""
    # Test the URL generation logic directly
    from configs import dify_config

    # Mock the configuration values
    original_files_url = getattr(dify_config, "FILES_URL", None)
    original_internal_files_url = getattr(dify_config, "INTERNAL_FILES_URL", None)

    try:
        # Set both URLs - INTERNAL should take precedence
        dify_config.FILES_URL = "http://external.example.com"
        dify_config.INTERNAL_FILES_URL = "http://internal.docker:5001"

        # Test the URL generation logic (same as in word_extractor.py)
        upload_file_id = "test_file_id"

        # This is the pattern we fixed in the word extractor
        base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
        generated_url = f"{base_url}/files/{upload_file_id}/file-preview"

        # Verify that INTERNAL_FILES_URL is used instead of FILES_URL
        assert "http://internal.docker:5001" in generated_url, f"Expected internal URL, got: {generated_url}"
        assert "http://external.example.com" not in generated_url, f"Should not use external URL, got: {generated_url}"

    finally:
        # Restore original values
        if original_files_url is not None:
            dify_config.FILES_URL = original_files_url
        if original_internal_files_url is not None:
            dify_config.INTERNAL_FILES_URL = original_internal_files_url

@ -901,6 +901,13 @@ class TestFixedRecursiveCharacterTextSplitter:
        # Verify no empty chunks
        assert all(len(chunk) > 0 for chunk in result)

    def test_double_slash_n(self):
        data = "chunk 1\n\nsubchunk 1.\nsubchunk 2.\n\n---\n\nchunk 2\n\nsubchunk 1\nsubchunk 2."
        separator = "\\n\\n---\\n\\n"
        splitter = FixedRecursiveCharacterTextSplitter(fixed_separator=separator)
        chunks = splitter.split_text(data)
        assert chunks == ["chunk 1\n\nsubchunk 1.\nsubchunk 2.", "chunk 2\n\nsubchunk 1\nsubchunk 2."]


# ============================================================================
# Test Metadata Preservation

@ -1,3 +1,5 @@
import pytest

from core.workflow.nodes.http_request import (
    BodyData,
    HttpRequestNodeAuthorization,

@ -5,6 +7,7 @@ from core.workflow.nodes.http_request import (
    HttpRequestNodeData,
)
from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout
from core.workflow.nodes.http_request.exc import AuthorizationConfigError
from core.workflow.nodes.http_request.executor import Executor
from core.workflow.runtime import VariablePool
from core.workflow.system_variable import SystemVariable

@ -348,3 +351,127 @@ def test_init_params():
    executor = create_executor("key1:value1\n\nkey2:value2\n\n")
    executor._init_params()
    assert executor.params == [("key1", "value1"), ("key2", "value2")]


def test_empty_api_key_raises_error_bearer():
    """Test that empty API key raises AuthorizationConfigError for bearer auth."""
    variable_pool = VariablePool(system_variables=SystemVariable.empty())
    node_data = HttpRequestNodeData(
        title="test",
        method="get",
        url="http://example.com",
        headers="",
        params="",
        authorization=HttpRequestNodeAuthorization(
            type="api-key",
            config={"type": "bearer", "api_key": ""},
        ),
    )
    timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)

    with pytest.raises(AuthorizationConfigError, match="API key is required"):
        Executor(
            node_data=node_data,
            timeout=timeout,
            variable_pool=variable_pool,
        )


def test_empty_api_key_raises_error_basic():
    """Test that empty API key raises AuthorizationConfigError for basic auth."""
    variable_pool = VariablePool(system_variables=SystemVariable.empty())
    node_data = HttpRequestNodeData(
        title="test",
        method="get",
        url="http://example.com",
        headers="",
        params="",
        authorization=HttpRequestNodeAuthorization(
            type="api-key",
            config={"type": "basic", "api_key": ""},
        ),
    )
    timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)

    with pytest.raises(AuthorizationConfigError, match="API key is required"):
        Executor(
            node_data=node_data,
            timeout=timeout,
            variable_pool=variable_pool,
        )


def test_empty_api_key_raises_error_custom():
    """Test that empty API key raises AuthorizationConfigError for custom auth."""
    variable_pool = VariablePool(system_variables=SystemVariable.empty())
    node_data = HttpRequestNodeData(
        title="test",
        method="get",
        url="http://example.com",
        headers="",
        params="",
        authorization=HttpRequestNodeAuthorization(
            type="api-key",
            config={"type": "custom", "api_key": "", "header": "X-Custom-Auth"},
        ),
    )
    timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)

    with pytest.raises(AuthorizationConfigError, match="API key is required"):
        Executor(
            node_data=node_data,
            timeout=timeout,
            variable_pool=variable_pool,
        )


def test_whitespace_only_api_key_raises_error():
    """Test that whitespace-only API key raises AuthorizationConfigError."""
    variable_pool = VariablePool(system_variables=SystemVariable.empty())
    node_data = HttpRequestNodeData(
        title="test",
        method="get",
        url="http://example.com",
        headers="",
        params="",
        authorization=HttpRequestNodeAuthorization(
            type="api-key",
            config={"type": "bearer", "api_key": " "},
        ),
    )
    timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)

    with pytest.raises(AuthorizationConfigError, match="API key is required"):
        Executor(
            node_data=node_data,
            timeout=timeout,
            variable_pool=variable_pool,
        )


def test_valid_api_key_works():
    """Test that valid API key works correctly for bearer auth."""
    variable_pool = VariablePool(system_variables=SystemVariable.empty())
    node_data = HttpRequestNodeData(
        title="test",
        method="get",
        url="http://example.com",
        headers="",
        params="",
        authorization=HttpRequestNodeAuthorization(
            type="api-key",
            config={"type": "bearer", "api_key": "valid-api-key-123"},
        ),
    )
    timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)

    executor = Executor(
        node_data=node_data,
        timeout=timeout,
        variable_pool=variable_pool,
    )

    # Should not raise an error
    headers = executor._assembling_headers()
    assert "Authorization" in headers
    assert headers["Authorization"] == "Bearer valid-api-key-123"

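For orientation, a guard along these lines in the Executor's authorization setup would satisfy all five tests above. This is inferred from the assertions, and the attribute path is an assumption based on the test fixtures, not the actual implementation:

```python
# Empty and whitespace-only keys are rejected before any request is built.
api_key = (node_data.authorization.config.api_key or "").strip()
if not api_key:
    raise AuthorizationConfigError("API key is required")
```
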
@ -1,3 +1,4 @@
import json
import time

import pytest

@ -46,14 +47,16 @@ def make_start_node(user_inputs, variables):


def test_json_object_valid_schema():
    schema = {
        "type": "object",
        "properties": {
            "age": {"type": "number"},
            "name": {"type": "string"},
        },
        "required": ["age"],
    }
    schema = json.dumps(
        {
            "type": "object",
            "properties": {
                "age": {"type": "number"},
                "name": {"type": "string"},
            },
            "required": ["age"],
        }
    )

    variables = [
        VariableEntity(

@ -65,7 +68,7 @@ def test_json_object_valid_schema():
        )
    ]

    user_inputs = {"profile": {"age": 20, "name": "Tom"}}
    user_inputs = {"profile": json.dumps({"age": 20, "name": "Tom"})}

    node = make_start_node(user_inputs, variables)
    result = node._run()

@ -74,12 +77,23 @@ def test_json_object_valid_schema():


def test_json_object_invalid_json_string():
    schema = json.dumps(
        {
            "type": "object",
            "properties": {
                "age": {"type": "number"},
                "name": {"type": "string"},
            },
            "required": ["age", "name"],
        }
    )
    variables = [
        VariableEntity(
            variable="profile",
            label="profile",
            type=VariableEntityType.JSON_OBJECT,
            required=True,
            json_schema=schema,
        )
    ]

@ -88,38 +102,21 @@ def test_json_object_invalid_json_string():

    node = make_start_node(user_inputs, variables)

    with pytest.raises(ValueError, match="profile must be a JSON object"):
        node._run()


@pytest.mark.parametrize("value", ["[1, 2, 3]", "123"])
def test_json_object_valid_json_but_not_object(value):
    variables = [
        VariableEntity(
            variable="profile",
            label="profile",
            type=VariableEntityType.JSON_OBJECT,
            required=True,
        )
    ]

    user_inputs = {"profile": value}

    node = make_start_node(user_inputs, variables)

    with pytest.raises(ValueError, match="profile must be a JSON object"):
    with pytest.raises(ValueError, match='{"age": 20, "name": "Tom" must be a valid JSON object'):
        node._run()


def test_json_object_does_not_match_schema():
    schema = {
        "type": "object",
        "properties": {
            "age": {"type": "number"},
            "name": {"type": "string"},
        },
        "required": ["age", "name"],
    }
    schema = json.dumps(
        {
            "type": "object",
            "properties": {
                "age": {"type": "number"},
                "name": {"type": "string"},
            },
            "required": ["age", "name"],
        }
    )

    variables = [
        VariableEntity(

@ -132,7 +129,7 @@ def test_json_object_does_not_match_schema():
    ]

    # age is a string, which violates the schema (expects number)
    user_inputs = {"profile": {"age": "twenty", "name": "Tom"}}
    user_inputs = {"profile": json.dumps({"age": "twenty", "name": "Tom"})}

    node = make_start_node(user_inputs, variables)

@ -141,14 +138,16 @@ def test_json_object_does_not_match_schema():


def test_json_object_missing_required_schema_field():
    schema = {
        "type": "object",
        "properties": {
            "age": {"type": "number"},
            "name": {"type": "string"},
        },
        "required": ["age", "name"],
    }
    schema = json.dumps(
        {
            "type": "object",
            "properties": {
                "age": {"type": "number"},
                "name": {"type": "string"},
            },
            "required": ["age", "name"],
        }
    )

    variables = [
        VariableEntity(

@ -161,7 +160,7 @@ def test_json_object_missing_required_schema_field():
    ]

    # Missing required field "name"
    user_inputs = {"profile": {"age": 20}}
    user_inputs = {"profile": json.dumps({"age": 20})}

    node = make_start_node(user_inputs, variables)

@ -214,7 +213,7 @@ def test_json_object_optional_variable_not_provided():
            variable="profile",
            label="profile",
            type=VariableEntityType.JSON_OBJECT,
            required=False,
            required=True,
        )
    ]

@ -223,5 +222,5 @@ def test_json_object_optional_variable_not_provided():
    node = make_start_node(user_inputs, variables)

    # Current implementation raises a validation error even when the variable is optional
    with pytest.raises(ValueError, match="profile must be a JSON object"):
    with pytest.raises(ValueError, match="profile is required in input form"):
        node._run()

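The reworked tests change the input contract: JSON_OBJECT values are submitted as JSON strings, non-object or unparseable payloads fail with a "must be a valid JSON object" error that echoes the raw value, and a missing required variable fails with "is required in input form". A sketch of the validation this implies — the error wording comes from the tests, while the `jsonschema` call is a stand-in for whatever checker the node actually uses:

```python
import json

import jsonschema  # assumption: a jsonschema-style validator


def check_json_object_input(raw: str, schema_json: str | None) -> dict:
    try:
        value = json.loads(raw)
    except json.JSONDecodeError:
        raise ValueError(f"{raw} must be a valid JSON object") from None
    if not isinstance(value, dict):
        raise ValueError(f"{raw} must be a valid JSON object")
    if schema_json:
        jsonschema.validate(value, json.loads(schema_json))  # raises on mismatch
    return value
```
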
@ -1156,6 +1156,199 @@ class TestBillingServiceEdgeCases:
        assert "Only team owner or team admin can perform this action" in str(exc_info.value)


class TestBillingServiceSubscriptionOperations:
    """Unit tests for subscription operations in BillingService.

    Tests cover:
    - Bulk plan retrieval with chunking
    - Expired subscription cleanup whitelist retrieval
    """

    @pytest.fixture
    def mock_send_request(self):
        """Mock _send_request method."""
        with patch.object(BillingService, "_send_request") as mock:
            yield mock

    def test_get_plan_bulk_with_empty_list(self, mock_send_request):
        """Test bulk plan retrieval with empty tenant list."""
        # Arrange
        tenant_ids = []

        # Act
        result = BillingService.get_plan_bulk(tenant_ids)

        # Assert
        assert result == {}
        mock_send_request.assert_not_called()

    def test_get_plan_bulk_with_chunking(self, mock_send_request):
        """Test bulk plan retrieval with more than 200 tenants (chunking logic)."""
        # Arrange - 250 tenants to test chunking (chunk_size = 200)
        tenant_ids = [f"tenant-{i}" for i in range(250)]

        # First chunk: tenants 0-199
        first_chunk_response = {
            "data": {f"tenant-{i}": {"plan": "sandbox", "expiration_date": 1735689600} for i in range(200)}
        }

        # Second chunk: tenants 200-249
        second_chunk_response = {
            "data": {f"tenant-{i}": {"plan": "professional", "expiration_date": 1767225600} for i in range(200, 250)}
        }

        mock_send_request.side_effect = [first_chunk_response, second_chunk_response]

        # Act
        result = BillingService.get_plan_bulk(tenant_ids)

        # Assert
        assert len(result) == 250
        assert result["tenant-0"]["plan"] == "sandbox"
        assert result["tenant-199"]["plan"] == "sandbox"
        assert result["tenant-200"]["plan"] == "professional"
        assert result["tenant-249"]["plan"] == "professional"
        assert mock_send_request.call_count == 2

        # Verify first chunk call
        first_call = mock_send_request.call_args_list[0]
        assert first_call[0][0] == "POST"
        assert first_call[0][1] == "/subscription/plan/batch"
        assert len(first_call[1]["json"]["tenant_ids"]) == 200

        # Verify second chunk call
        second_call = mock_send_request.call_args_list[1]
        assert len(second_call[1]["json"]["tenant_ids"]) == 50
||||
def test_get_plan_bulk_with_partial_batch_failure(self, mock_send_request):
|
||||
"""Test bulk plan retrieval when one batch fails but others succeed."""
|
||||
# Arrange - 250 tenants, second batch will fail
|
||||
tenant_ids = [f"tenant-{i}" for i in range(250)]
|
||||
|
||||
# First chunk succeeds
|
||||
first_chunk_response = {
|
||||
"data": {f"tenant-{i}": {"plan": "sandbox", "expiration_date": 1735689600} for i in range(200)}
|
||||
}
|
||||
|
||||
# Second chunk fails - need to create a mock that raises when called
|
||||
def side_effect_func(*args, **kwargs):
|
||||
if mock_send_request.call_count == 1:
|
||||
return first_chunk_response
|
||||
else:
|
||||
raise ValueError("API error")
|
||||
|
||||
mock_send_request.side_effect = side_effect_func
|
||||
|
||||
# Act
|
||||
result = BillingService.get_plan_bulk(tenant_ids)
|
||||
|
||||
# Assert - should only have data from first batch
|
||||
assert len(result) == 200
|
||||
assert result["tenant-0"]["plan"] == "sandbox"
|
||||
assert result["tenant-199"]["plan"] == "sandbox"
|
||||
assert "tenant-200" not in result
|
||||
assert mock_send_request.call_count == 2
|
||||
|
||||
def test_get_plan_bulk_with_all_batches_failing(self, mock_send_request):
|
||||
"""Test bulk plan retrieval when all batches fail."""
|
||||
# Arrange
|
||||
tenant_ids = [f"tenant-{i}" for i in range(250)]
|
||||
|
||||
# All chunks fail
|
||||
def side_effect_func(*args, **kwargs):
|
||||
raise ValueError("API error")
|
||||
|
||||
mock_send_request.side_effect = side_effect_func
|
||||
|
||||
# Act
|
||||
result = BillingService.get_plan_bulk(tenant_ids)
|
||||
|
||||
# Assert - should return empty dict
|
||||
assert result == {}
|
||||
assert mock_send_request.call_count == 2
|
||||
|
||||
def test_get_plan_bulk_with_exactly_200_tenants(self, mock_send_request):
|
||||
"""Test bulk plan retrieval with exactly 200 tenants (boundary condition)."""
|
||||
# Arrange
|
||||
tenant_ids = [f"tenant-{i}" for i in range(200)]
|
||||
mock_send_request.return_value = {
|
||||
"data": {f"tenant-{i}": {"plan": "sandbox", "expiration_date": 1735689600} for i in range(200)}
|
||||
}
|
||||
|
||||
# Act
|
||||
result = BillingService.get_plan_bulk(tenant_ids)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 200
|
||||
assert mock_send_request.call_count == 1
|
||||
|
||||
def test_get_plan_bulk_with_empty_data_response(self, mock_send_request):
|
||||
"""Test bulk plan retrieval with empty data in response."""
|
||||
# Arrange
|
||||
tenant_ids = ["tenant-1", "tenant-2"]
|
||||
mock_send_request.return_value = {"data": {}}
|
||||
|
||||
# Act
|
||||
result = BillingService.get_plan_bulk(tenant_ids)
|
||||
|
||||
# Assert
|
||||
assert result == {}
|
||||
|
||||
def test_get_expired_subscription_cleanup_whitelist_success(self, mock_send_request):
|
||||
"""Test successful retrieval of expired subscription cleanup whitelist."""
|
||||
# Arrange
|
||||
api_response = [
|
||||
{
|
||||
"created_at": "2025-10-16T01:56:17",
|
||||
"tenant_id": "36bd55ec-2ea9-4d75-a9ea-1f26aeb4ffe6",
|
||||
"contact": "example@dify.ai",
|
||||
"id": "36bd55ec-2ea9-4d75-a9ea-1f26aeb4ffe5",
|
||||
"expired_at": "2026-01-01T01:56:17",
|
||||
"updated_at": "2025-10-16T01:56:17",
|
||||
},
|
||||
{
|
||||
"created_at": "2025-10-16T02:00:00",
|
||||
"tenant_id": "tenant-2",
|
||||
"contact": "test@example.com",
|
||||
"id": "whitelist-id-2",
|
||||
"expired_at": "2026-02-01T00:00:00",
|
||||
"updated_at": "2025-10-16T02:00:00",
|
||||
},
|
||||
{
|
||||
"created_at": "2025-10-16T03:00:00",
|
||||
"tenant_id": "tenant-3",
|
||||
"contact": "another@example.com",
|
||||
"id": "whitelist-id-3",
|
||||
"expired_at": "2026-03-01T00:00:00",
|
||||
"updated_at": "2025-10-16T03:00:00",
|
||||
},
|
||||
]
|
||||
mock_send_request.return_value = {"data": api_response}
|
||||
|
||||
# Act
|
||||
result = BillingService.get_expired_subscription_cleanup_whitelist()
|
||||
|
||||
# Assert - should return only tenant_ids
|
||||
assert result == ["36bd55ec-2ea9-4d75-a9ea-1f26aeb4ffe6", "tenant-2", "tenant-3"]
|
||||
assert len(result) == 3
|
||||
assert result[0] == "36bd55ec-2ea9-4d75-a9ea-1f26aeb4ffe6"
|
||||
assert result[1] == "tenant-2"
|
||||
assert result[2] == "tenant-3"
|
||||
mock_send_request.assert_called_once_with("GET", "/subscription/cleanup/whitelist")
|
||||
|
||||
def test_get_expired_subscription_cleanup_whitelist_empty_list(self, mock_send_request):
|
||||
"""Test retrieval of empty cleanup whitelist."""
|
||||
# Arrange
|
||||
mock_send_request.return_value = {"data": []}
|
||||
|
||||
# Act
|
||||
result = BillingService.get_expired_subscription_cleanup_whitelist()
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
assert len(result) == 0
|
||||
|
||||
|
||||
class TestBillingServiceIntegrationScenarios:
|
||||
"""Integration-style tests simulating real-world usage scenarios.
|
||||
|
||||
|
|
|
|||
|
|
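The contract the chunking tests above pin down — 200-id batches, failed batches skipped rather than failing the whole call — is worth seeing outside test assertions. A minimal TypeScript sketch of the same strategy; the `fetchBatch` callback and `PlanInfo` shape are illustrative stand-ins, not the service's real API:

```ts
// Sketch of the chunked bulk-fetch strategy the tests above verify.
// The real service POSTs each chunk to /subscription/plan/batch and
// merges the responses; names here are hypothetical.
type PlanInfo = { plan: string; expiration_date: number }

async function getPlanBulk(
  tenantIds: string[],
  fetchBatch: (ids: string[]) => Promise<Record<string, PlanInfo>>,
  chunkSize = 200,
): Promise<Record<string, PlanInfo>> {
  const result: Record<string, PlanInfo> = {}
  for (let i = 0; i < tenantIds.length; i += chunkSize) {
    const chunk = tenantIds.slice(i, i + chunkSize)
    try {
      Object.assign(result, await fetchBatch(chunk))
    }
    catch {
      // A failed batch is skipped so the remaining batches still resolve,
      // matching test_get_plan_bulk_with_partial_batch_failure.
    }
  }
  return result
}
```

An empty input never calls `fetchBatch` and returns `{}`, and 250 ids produce exactly two calls (200 + 50) — the same boundary behavior asserted above.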
@@ -0,0 +1,520 @@
"""
Unit tests for document indexing sync task.

This module tests the document indexing sync task functionality including:
- Syncing Notion documents when updated
- Validating document and data source existence
- Credential validation and retrieval
- Cleaning old segments before re-indexing
- Error handling and edge cases
"""

import uuid
from unittest.mock import MagicMock, Mock, patch

import pytest

from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from models.dataset import Dataset, Document, DocumentSegment
from tasks.document_indexing_sync_task import document_indexing_sync_task

# ============================================================================
# Fixtures
# ============================================================================


@pytest.fixture
def tenant_id():
    """Generate a unique tenant ID for testing."""
    return str(uuid.uuid4())


@pytest.fixture
def dataset_id():
    """Generate a unique dataset ID for testing."""
    return str(uuid.uuid4())


@pytest.fixture
def document_id():
    """Generate a unique document ID for testing."""
    return str(uuid.uuid4())


@pytest.fixture
def notion_workspace_id():
    """Generate a Notion workspace ID for testing."""
    return str(uuid.uuid4())


@pytest.fixture
def notion_page_id():
    """Generate a Notion page ID for testing."""
    return str(uuid.uuid4())


@pytest.fixture
def credential_id():
    """Generate a credential ID for testing."""
    return str(uuid.uuid4())


@pytest.fixture
def mock_dataset(dataset_id, tenant_id):
    """Create a mock Dataset object."""
    dataset = Mock(spec=Dataset)
    dataset.id = dataset_id
    dataset.tenant_id = tenant_id
    dataset.indexing_technique = "high_quality"
    dataset.embedding_model_provider = "openai"
    dataset.embedding_model = "text-embedding-ada-002"
    return dataset


@pytest.fixture
def mock_document(document_id, dataset_id, tenant_id, notion_workspace_id, notion_page_id, credential_id):
    """Create a mock Document object with Notion data source."""
    doc = Mock(spec=Document)
    doc.id = document_id
    doc.dataset_id = dataset_id
    doc.tenant_id = tenant_id
    doc.data_source_type = "notion_import"
    doc.indexing_status = "completed"
    doc.error = None
    doc.stopped_at = None
    doc.processing_started_at = None
    doc.doc_form = "text_model"
    doc.data_source_info_dict = {
        "notion_workspace_id": notion_workspace_id,
        "notion_page_id": notion_page_id,
        "type": "page",
        "last_edited_time": "2024-01-01T00:00:00Z",
        "credential_id": credential_id,
    }
    return doc


@pytest.fixture
def mock_document_segments(document_id):
    """Create mock DocumentSegment objects."""
    segments = []
    for i in range(3):
        segment = Mock(spec=DocumentSegment)
        segment.id = str(uuid.uuid4())
        segment.document_id = document_id
        segment.index_node_id = f"node-{document_id}-{i}"
        segments.append(segment)
    return segments


@pytest.fixture
def mock_db_session():
    """Mock database session."""
    with patch("tasks.document_indexing_sync_task.db.session") as mock_session:
        mock_query = MagicMock()
        mock_session.query.return_value = mock_query
        mock_query.where.return_value = mock_query
        mock_session.scalars.return_value = MagicMock()
        yield mock_session


@pytest.fixture
def mock_datasource_provider_service():
    """Mock DatasourceProviderService."""
    with patch("tasks.document_indexing_sync_task.DatasourceProviderService") as mock_service_class:
        mock_service = MagicMock()
        mock_service.get_datasource_credentials.return_value = {"integration_secret": "test_token"}
        mock_service_class.return_value = mock_service
        yield mock_service


@pytest.fixture
def mock_notion_extractor():
    """Mock NotionExtractor."""
    with patch("tasks.document_indexing_sync_task.NotionExtractor") as mock_extractor_class:
        mock_extractor = MagicMock()
        mock_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"  # Updated time
        mock_extractor_class.return_value = mock_extractor
        yield mock_extractor


@pytest.fixture
def mock_index_processor_factory():
    """Mock IndexProcessorFactory."""
    with patch("tasks.document_indexing_sync_task.IndexProcessorFactory") as mock_factory:
        mock_processor = MagicMock()
        mock_processor.clean = Mock()
        mock_factory.return_value.init_index_processor.return_value = mock_processor
        yield mock_factory


@pytest.fixture
def mock_indexing_runner():
    """Mock IndexingRunner."""
    with patch("tasks.document_indexing_sync_task.IndexingRunner") as mock_runner_class:
        mock_runner = MagicMock(spec=IndexingRunner)
        mock_runner.run = Mock()
        mock_runner_class.return_value = mock_runner
        yield mock_runner


# ============================================================================
# Tests for document_indexing_sync_task
# ============================================================================


class TestDocumentIndexingSyncTask:
    """Tests for the document_indexing_sync_task function."""

    def test_document_not_found(self, mock_db_session, dataset_id, document_id):
        """Test that task handles document not found gracefully."""
        # Arrange
        mock_db_session.query.return_value.where.return_value.first.return_value = None

        # Act
        document_indexing_sync_task(dataset_id, document_id)

        # Assert
        mock_db_session.close.assert_called_once()

    def test_missing_notion_workspace_id(self, mock_db_session, mock_document, dataset_id, document_id):
        """Test that task raises error when notion_workspace_id is missing."""
        # Arrange
        mock_document.data_source_info_dict = {"notion_page_id": "page123", "type": "page"}
        mock_db_session.query.return_value.where.return_value.first.return_value = mock_document

        # Act & Assert
        with pytest.raises(ValueError, match="no notion page found"):
            document_indexing_sync_task(dataset_id, document_id)

    def test_missing_notion_page_id(self, mock_db_session, mock_document, dataset_id, document_id):
        """Test that task raises error when notion_page_id is missing."""
        # Arrange
        mock_document.data_source_info_dict = {"notion_workspace_id": "ws123", "type": "page"}
        mock_db_session.query.return_value.where.return_value.first.return_value = mock_document

        # Act & Assert
        with pytest.raises(ValueError, match="no notion page found"):
            document_indexing_sync_task(dataset_id, document_id)

    def test_empty_data_source_info(self, mock_db_session, mock_document, dataset_id, document_id):
        """Test that task raises error when data_source_info is empty."""
        # Arrange
        mock_document.data_source_info_dict = None
        mock_db_session.query.return_value.where.return_value.first.return_value = mock_document

        # Act & Assert
        with pytest.raises(ValueError, match="no notion page found"):
            document_indexing_sync_task(dataset_id, document_id)

    def test_credential_not_found(
        self,
        mock_db_session,
        mock_datasource_provider_service,
        mock_document,
        dataset_id,
        document_id,
    ):
        """Test that task handles missing credentials by updating document status."""
        # Arrange
        mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
        mock_datasource_provider_service.get_datasource_credentials.return_value = None

        # Act
        document_indexing_sync_task(dataset_id, document_id)

        # Assert
        assert mock_document.indexing_status == "error"
        assert "Datasource credential not found" in mock_document.error
        assert mock_document.stopped_at is not None
        mock_db_session.commit.assert_called()
        mock_db_session.close.assert_called()

    def test_page_not_updated(
        self,
        mock_db_session,
        mock_datasource_provider_service,
        mock_notion_extractor,
        mock_document,
        dataset_id,
        document_id,
    ):
        """Test that task does nothing when page has not been updated."""
        # Arrange
        mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
        # Return same time as stored in document
        mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"

        # Act
        document_indexing_sync_task(dataset_id, document_id)

        # Assert
        # Document status should remain unchanged
        assert mock_document.indexing_status == "completed"
        # No session operations should be performed beyond the initial query
        mock_db_session.close.assert_not_called()

    def test_successful_sync_when_page_updated(
        self,
        mock_db_session,
        mock_datasource_provider_service,
        mock_notion_extractor,
        mock_index_processor_factory,
        mock_indexing_runner,
        mock_dataset,
        mock_document,
        mock_document_segments,
        dataset_id,
        document_id,
    ):
        """Test successful sync flow when Notion page has been updated."""
        # Arrange
        mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
        mock_db_session.scalars.return_value.all.return_value = mock_document_segments
        # NotionExtractor returns updated time
        mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"

        # Act
        document_indexing_sync_task(dataset_id, document_id)

        # Assert
        # Verify document status was updated to parsing
        assert mock_document.indexing_status == "parsing"
        assert mock_document.processing_started_at is not None

        # Verify segments were cleaned
        mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
        mock_processor.clean.assert_called_once()

        # Verify segments were deleted from database
        for segment in mock_document_segments:
            mock_db_session.delete.assert_any_call(segment)

        # Verify indexing runner was called
        mock_indexing_runner.run.assert_called_once_with([mock_document])

        # Verify session operations
        assert mock_db_session.commit.called
        mock_db_session.close.assert_called_once()

    def test_dataset_not_found_during_cleaning(
        self,
        mock_db_session,
        mock_datasource_provider_service,
        mock_notion_extractor,
        mock_document,
        dataset_id,
        document_id,
    ):
        """Test that task handles dataset not found during cleaning phase."""
        # Arrange
        mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, None]
        mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"

        # Act
        document_indexing_sync_task(dataset_id, document_id)

        # Assert
        # Document should still be set to parsing
        assert mock_document.indexing_status == "parsing"
        # Session should be closed after error
        mock_db_session.close.assert_called_once()

    def test_cleaning_error_continues_to_indexing(
        self,
        mock_db_session,
        mock_datasource_provider_service,
        mock_notion_extractor,
        mock_index_processor_factory,
        mock_indexing_runner,
        mock_dataset,
        mock_document,
        dataset_id,
        document_id,
    ):
        """Test that indexing continues even if cleaning fails."""
        # Arrange
        mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
        mock_db_session.scalars.return_value.all.side_effect = Exception("Cleaning error")
        mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"

        # Act
        document_indexing_sync_task(dataset_id, document_id)

        # Assert
        # Indexing should still be attempted despite cleaning error
        mock_indexing_runner.run.assert_called_once_with([mock_document])
        mock_db_session.close.assert_called_once()

    def test_indexing_runner_document_paused_error(
        self,
        mock_db_session,
        mock_datasource_provider_service,
        mock_notion_extractor,
        mock_index_processor_factory,
        mock_indexing_runner,
        mock_dataset,
        mock_document,
        mock_document_segments,
        dataset_id,
        document_id,
    ):
        """Test that DocumentIsPausedError is handled gracefully."""
        # Arrange
        mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
        mock_db_session.scalars.return_value.all.return_value = mock_document_segments
        mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
        mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document paused")

        # Act
        document_indexing_sync_task(dataset_id, document_id)

        # Assert
        # Session should be closed after handling error
        mock_db_session.close.assert_called_once()

    def test_indexing_runner_general_error(
        self,
        mock_db_session,
        mock_datasource_provider_service,
        mock_notion_extractor,
        mock_index_processor_factory,
        mock_indexing_runner,
        mock_dataset,
        mock_document,
        mock_document_segments,
        dataset_id,
        document_id,
    ):
        """Test that general exceptions during indexing are handled."""
        # Arrange
        mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
        mock_db_session.scalars.return_value.all.return_value = mock_document_segments
        mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
        mock_indexing_runner.run.side_effect = Exception("Indexing error")

        # Act
        document_indexing_sync_task(dataset_id, document_id)

        # Assert
        # Session should be closed after error
        mock_db_session.close.assert_called_once()

    def test_notion_extractor_initialized_with_correct_params(
        self,
        mock_db_session,
        mock_datasource_provider_service,
        mock_notion_extractor,
        mock_document,
        dataset_id,
        document_id,
        notion_workspace_id,
        notion_page_id,
    ):
        """Test that NotionExtractor is initialized with correct parameters."""
        # Arrange
        mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
        mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"  # No update

        # Act
        with patch("tasks.document_indexing_sync_task.NotionExtractor") as mock_extractor_class:
            mock_extractor = MagicMock()
            mock_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
            mock_extractor_class.return_value = mock_extractor

            document_indexing_sync_task(dataset_id, document_id)

            # Assert
            mock_extractor_class.assert_called_once_with(
                notion_workspace_id=notion_workspace_id,
                notion_obj_id=notion_page_id,
                notion_page_type="page",
                notion_access_token="test_token",
                tenant_id=mock_document.tenant_id,
            )

    def test_datasource_credentials_requested_correctly(
        self,
        mock_db_session,
        mock_datasource_provider_service,
        mock_notion_extractor,
        mock_document,
        dataset_id,
        document_id,
        credential_id,
    ):
        """Test that datasource credentials are requested with correct parameters."""
        # Arrange
        mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
        mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"

        # Act
        document_indexing_sync_task(dataset_id, document_id)

        # Assert
        mock_datasource_provider_service.get_datasource_credentials.assert_called_once_with(
            tenant_id=mock_document.tenant_id,
            credential_id=credential_id,
            provider="notion_datasource",
            plugin_id="langgenius/notion_datasource",
        )

    def test_credential_id_missing_uses_none(
        self,
        mock_db_session,
        mock_datasource_provider_service,
        mock_notion_extractor,
        mock_document,
        dataset_id,
        document_id,
    ):
        """Test that task handles missing credential_id by passing None."""
        # Arrange
        mock_document.data_source_info_dict = {
            "notion_workspace_id": "ws123",
            "notion_page_id": "page123",
            "type": "page",
            "last_edited_time": "2024-01-01T00:00:00Z",
        }
        mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
        mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"

        # Act
        document_indexing_sync_task(dataset_id, document_id)

        # Assert
        mock_datasource_provider_service.get_datasource_credentials.assert_called_once_with(
            tenant_id=mock_document.tenant_id,
            credential_id=None,
            provider="notion_datasource",
            plugin_id="langgenius/notion_datasource",
        )

    def test_index_processor_clean_called_with_correct_params(
        self,
        mock_db_session,
        mock_datasource_provider_service,
        mock_notion_extractor,
        mock_index_processor_factory,
        mock_indexing_runner,
        mock_dataset,
        mock_document,
        mock_document_segments,
        dataset_id,
        document_id,
    ):
        """Test that index processor clean is called with correct parameters."""
        # Arrange
        mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
        mock_db_session.scalars.return_value.all.return_value = mock_document_segments
        mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"

        # Act
        document_indexing_sync_task(dataset_id, document_id)

        # Assert
        mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
        expected_node_ids = [seg.index_node_id for seg in mock_document_segments]
        mock_processor.clean.assert_called_once_with(
            mock_dataset, expected_node_ids, with_keywords=True, delete_child_chunks=True
        )
@@ -1380,7 +1380,7 @@ dependencies = [
    { name = "bs4" },
    { name = "cachetools" },
    { name = "celery" },
    { name = "chardet" },
    { name = "charset-normalizer" },
    { name = "croniter" },
    { name = "flask" },
    { name = "flask-compress" },

@@ -1403,6 +1403,7 @@ dependencies = [
    { name = "httpx-sse" },
    { name = "jieba" },
    { name = "json-repair" },
    { name = "jsonschema" },
    { name = "langfuse" },
    { name = "langsmith" },
    { name = "litellm" },

@@ -1577,7 +1578,7 @@ requires-dist = [
    { name = "bs4", specifier = "~=0.0.1" },
    { name = "cachetools", specifier = "~=5.3.0" },
    { name = "celery", specifier = "~=5.5.2" },
    { name = "chardet", specifier = "~=5.1.0" },
    { name = "charset-normalizer", specifier = ">=3.4.4" },
    { name = "croniter", specifier = ">=6.0.0" },
    { name = "flask", specifier = "~=3.1.2" },
    { name = "flask-compress", specifier = ">=1.17,<1.18" },

@@ -1600,6 +1601,7 @@ requires-dist = [
    { name = "httpx-sse", specifier = "~=0.4.0" },
    { name = "jieba", specifier = "==0.42.1" },
    { name = "json-repair", specifier = ">=0.41.1" },
    { name = "jsonschema", specifier = ">=4.25.1" },
    { name = "langfuse", specifier = "~=2.51.3" },
    { name = "langsmith", specifier = "~=0.1.77" },
    { name = "litellm", specifier = "==1.77.1" },
@@ -37,6 +37,7 @@ show_help() {
    echo "  pipeline - Standard pipeline tasks"
    echo "  triggered_workflow_dispatcher - Trigger dispatcher tasks"
    echo "  trigger_refresh_executor - Trigger refresh tasks"
    echo "  retention - Retention tasks"
}

# Parse command line arguments

@@ -105,10 +106,10 @@ if [[ -z "${QUEUES}" ]]; then
    # Configure queues based on edition
    if [[ "${EDITION}" == "CLOUD" ]]; then
        # Cloud edition: separate queues for dataset and trigger tasks
        QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
        QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention"
    else
        # Community edition (SELF_HOSTED): dataset and workflow have separate queues
        QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
        QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention"
    fi

    echo "No queues specified, using edition-based defaults: ${QUEUES}"
@@ -1369,7 +1369,10 @@ PLUGIN_STDIO_BUFFER_SIZE=1024
PLUGIN_STDIO_MAX_BUFFER_SIZE=5242880

PLUGIN_PYTHON_ENV_INIT_TIMEOUT=120
# Plugin Daemon side timeout (configure to match the API side below)
PLUGIN_MAX_EXECUTION_TIMEOUT=600
# API side timeout (configure to match the Plugin Daemon side above)
PLUGIN_DAEMON_TIMEOUT=600.0
# PIP_MIRROR_URL=https://pypi.tuna.tsinghua.edu.cn/simple
PIP_MIRROR_URL=

@@ -1479,4 +1482,9 @@ ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR=20
ANNOTATION_IMPORT_MAX_CONCURRENT=5

# The API key of amplitude
AMPLITUDE_API_KEY=
AMPLITUDE_API_KEY=

# Sandbox expired records clean configuration
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30
@@ -34,6 +34,7 @@ services:
      PLUGIN_REMOTE_INSTALL_HOST: ${EXPOSE_PLUGIN_DEBUGGING_HOST:-localhost}
      PLUGIN_REMOTE_INSTALL_PORT: ${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003}
      PLUGIN_MAX_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800}
      PLUGIN_DAEMON_TIMEOUT: ${PLUGIN_DAEMON_TIMEOUT:-600.0}
      INNER_API_KEY_FOR_PLUGIN: ${PLUGIN_DIFY_INNER_API_KEY:-QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1}
    depends_on:
      init_permissions:

@@ -591,6 +591,7 @@ x-shared-env: &shared-api-worker-env
  PLUGIN_STDIO_MAX_BUFFER_SIZE: ${PLUGIN_STDIO_MAX_BUFFER_SIZE:-5242880}
  PLUGIN_PYTHON_ENV_INIT_TIMEOUT: ${PLUGIN_PYTHON_ENV_INIT_TIMEOUT:-120}
  PLUGIN_MAX_EXECUTION_TIMEOUT: ${PLUGIN_MAX_EXECUTION_TIMEOUT:-600}
  PLUGIN_DAEMON_TIMEOUT: ${PLUGIN_DAEMON_TIMEOUT:-600.0}
  PIP_MIRROR_URL: ${PIP_MIRROR_URL:-}
  PLUGIN_STORAGE_TYPE: ${PLUGIN_STORAGE_TYPE:-local}
  PLUGIN_STORAGE_LOCAL_ROOT: ${PLUGIN_STORAGE_LOCAL_ROOT:-/app/storage}

@@ -663,6 +664,9 @@ x-shared-env: &shared-api-worker-env
  ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR: ${ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR:-20}
  ANNOTATION_IMPORT_MAX_CONCURRENT: ${ANNOTATION_IMPORT_MAX_CONCURRENT:-5}
  AMPLITUDE_API_KEY: ${AMPLITUDE_API_KEY:-}
  SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD: ${SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD:-21}
  SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE: ${SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE:-1000}
  SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS: ${SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS:-30}

services:
  # Init container to fix permissions

@@ -699,6 +703,7 @@ services:
      PLUGIN_REMOTE_INSTALL_HOST: ${EXPOSE_PLUGIN_DEBUGGING_HOST:-localhost}
      PLUGIN_REMOTE_INSTALL_PORT: ${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003}
      PLUGIN_MAX_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800}
      PLUGIN_DAEMON_TIMEOUT: ${PLUGIN_DAEMON_TIMEOUT:-600.0}
      INNER_API_KEY_FOR_PLUGIN: ${PLUGIN_DIFY_INNER_API_KEY:-QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1}
    depends_on:
      init_permissions:
@@ -61,14 +61,14 @@
<p align="center">
  <a href="https://trendshift.io/repositories/2152" target="_blank"><img src="https://trendshift.io/api/badge/repositories/2152" alt="langgenius%2Fdify | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
</p>
Dify est une plateforme de développement d'applications LLM open source. Son interface intuitive combine un flux de travail d'IA, un pipeline RAG, des capacités d'agent, une gestion de modèles, des fonctionnalités d'observabilité, et plus encore, vous permettant de passer rapidement du prototype à la production. Voici une liste des fonctionnalités principales:
Dify est une plateforme de développement d'applications LLM open source. Sa interface intuitive combine un flux de travail d'IA, un pipeline RAG, des capacités d'agent, une gestion de modèles, des fonctionnalités d'observabilité, et plus encore, vous permettant de passer rapidement du prototype à la production. Voici une liste des fonctionnalités principales:
</br> </br>

**1. Flux de travail** :
Construisez et testez des flux de travail d'IA puissants sur un canevas visuel, en utilisant toutes les fonctionnalités suivantes et plus encore.

**2. Prise en charge complète des modèles** :
Intégration transparente avec des centaines de LLM propriétaires / open source provenant de dizaines de fournisseurs d'inférence et de solutions auto-hébergées, couvrant GPT, Mistral, Llama3, et tous les modèles compatibles avec l'API OpenAI. Une liste complète des fournisseurs de modèles pris en charge se trouve [ici](https://docs.dify.ai/getting-started/readme/model-providers).
Intégration transparente avec des centaines de LLM propriétaires / open source offerts par dizaines de fournisseurs d'inférence et de solutions auto-hébergées, couvrant GPT, Mistral, Llama3, et tous les modèles compatibles avec l'API OpenAI. Une liste complète des fournisseurs de modèles pris en charge se trouve [ici](https://docs.dify.ai/getting-started/readme/model-providers).

@@ -79,7 +79,7 @@ Interface intuitive pour créer des prompts, comparer les performances des modè
Des capacités RAG étendues qui couvrent tout, de l'ingestion de documents à la récupération, avec un support prêt à l'emploi pour l'extraction de texte à partir de PDF, PPT et autres formats de document courants.

**5. Capacités d'agent** :
Vous pouvez définir des agents basés sur l'appel de fonction LLM ou ReAct, et ajouter des outils pré-construits ou personnalisés pour l'agent. Dify fournit plus de 50 outils intégrés pour les agents d'IA, tels que la recherche Google, DALL·E, Stable Diffusion et WolframAlpha.
Vous pouvez définir des agents basés sur l'appel de fonctions LLM ou ReAct, et ajouter des outils pré-construits ou personnalisés pour l'agent. Dify fournit plus de 50 outils intégrés pour les agents d'IA, tels que la recherche Google, DALL·E, Stable Diffusion et WolframAlpha.

**6. LLMOps** :
Surveillez et analysez les journaux d'application et les performances au fil du temps. Vous pouvez continuellement améliorer les prompts, les ensembles de données et les modèles en fonction des données de production et des annotations.
@@ -1,7 +1,6 @@
{
  "recommendations": [
    "bradlc.vscode-tailwindcss",
    "firsttris.vscode-jest-runner",
    "kisstkondoros.vscode-codemetrics"
  ]
}
@@ -99,14 +99,14 @@ If your IDE is VSCode, rename `web/.vscode/settings.example.json` to `web/.vscod

## Test

We use [Jest](https://jestjs.io/) and [React Testing Library](https://testing-library.com/docs/react-testing-library/intro/) for Unit Testing.
We use [Vitest](https://vitest.dev/) and [React Testing Library](https://testing-library.com/docs/react-testing-library/intro/) for Unit Testing.

**📖 Complete Testing Guide**: See [web/testing/testing.md](./testing/testing.md) for detailed testing specifications, best practices, and examples.

Run test:

```bash
pnpm run test
pnpm test
```

### Example Code
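For anyone following this guide after the migration, a minimal sketch of what such a spec looks like; the `Greeting` component is hypothetical and stands in for any component under test:

```tsx
// index.spec.tsx — minimal Vitest + RTL spec sketch (globals enabled).
import { render, screen } from '@testing-library/react'

const Greeting = ({ name }: { name: string }) => <p>Hello, {name}</p>

describe('Greeting', () => {
  beforeEach(() => {
    vi.clearAllMocks()
  })

  it('renders the provided name', () => {
    render(<Greeting name="Dify" />)
    expect(screen.getByText('Hello, Dify')).toBeInTheDocument()
  })
})
```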
@@ -1,9 +1,41 @@
import { merge, noop } from 'lodash-es'
import { defaultPlan } from '@/app/components/billing/config'
import { baseProviderContextValue } from '@/context/provider-context'
import type { ProviderContextState } from '@/context/provider-context'
import type { Plan, UsagePlanInfo } from '@/app/components/billing/type'

// Avoid being mocked in tests
export const baseProviderContextValue: ProviderContextState = {
  modelProviders: [],
  refreshModelProviders: noop,
  textGenerationModelList: [],
  supportRetrievalMethods: [],
  isAPIKeySet: true,
  plan: defaultPlan,
  isFetchedPlan: false,
  enableBilling: false,
  onPlanInfoChanged: noop,
  enableReplaceWebAppLogo: false,
  modelLoadBalancingEnabled: false,
  datasetOperatorEnabled: false,
  enableEducationPlan: false,
  isEducationWorkspace: false,
  isEducationAccount: false,
  allowRefreshEducationVerify: false,
  educationAccountExpireAt: null,
  isLoadingEducationAccountInfo: false,
  isFetchingEducationAccountInfo: false,
  webappCopyrightEnabled: false,
  licenseLimit: {
    workspace_members: {
      size: 0,
      limit: 0,
    },
  },
  refreshLicenseLimit: noop,
  isAllowTransferWorkspace: false,
  isAllowPublishAsCustomKnowledgePipelineTemplate: false,
}

export const createMockProviderContextValue = (overrides: Partial<ProviderContextState> = {}): ProviderContextState => {
  const merged = merge({}, baseProviderContextValue, overrides)
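Usage is a one-liner in a spec. Because the factory uses lodash `merge`, nested overrides deep-merge into the defaults instead of replacing whole objects (the values below are illustrative):

```ts
// Build a context value with only the fields under test overridden.
const value = createMockProviderContextValue({
  enableBilling: true,
  licenseLimit: { workspace_members: { size: 3, limit: 10 } },
})
```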
@@ -1,40 +0,0 @@
/**
 * Shared mock for react-i18next
 *
 * Jest automatically uses this mock when react-i18next is imported in tests.
 * The default behavior returns the translation key as-is, which is suitable
 * for most test scenarios.
 *
 * For tests that need custom translations, you can override with jest.mock():
 *
 * @example
 * jest.mock('react-i18next', () => ({
 *   useTranslation: () => ({
 *     t: (key: string) => {
 *       if (key === 'some.key') return 'Custom translation'
 *       return key
 *     },
 *   }),
 * }))
 */

export const useTranslation = () => ({
  t: (key: string, options?: Record<string, unknown>) => {
    if (options?.returnObjects)
      return [`${key}-feature-1`, `${key}-feature-2`]
    if (options)
      return `${key}:${JSON.stringify(options)}`
    return key
  },
  i18n: {
    language: 'en',
    changeLanguage: jest.fn(),
  },
})

export const Trans = ({ children }: { children?: React.ReactNode }) => children

export const initReactI18next = {
  type: '3rdParty',
  init: jest.fn(),
}
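With this shared Jest manual mock deleted, specs that need react-i18next presumably mock it inline per file. A minimal key-echo sketch of that inline replacement:

```ts
// Inline stand-in for the deleted shared mock: t() echoes the key.
vi.mock('react-i18next', () => ({
  useTranslation: () => ({
    t: (key: string) => key,
    i18n: { language: 'en', changeLanguage: vi.fn() },
  }),
}))
```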
@@ -1,3 +1,4 @@
import type { Mock } from 'vitest'
/**
 * Document Detail Navigation Fix Verification Test
 *

@@ -10,32 +11,32 @@ import { useRouter } from 'next/navigation'
import { useDocumentDetail, useDocumentMetadata } from '@/service/knowledge/use-document'

// Mock Next.js router
const mockPush = jest.fn()
jest.mock('next/navigation', () => ({
  useRouter: jest.fn(() => ({
const mockPush = vi.fn()
vi.mock('next/navigation', () => ({
  useRouter: vi.fn(() => ({
    push: mockPush,
  })),
}))

// Mock the document service hooks
jest.mock('@/service/knowledge/use-document', () => ({
  useDocumentDetail: jest.fn(),
  useDocumentMetadata: jest.fn(),
  useInvalidDocumentList: jest.fn(() => jest.fn()),
vi.mock('@/service/knowledge/use-document', () => ({
  useDocumentDetail: vi.fn(),
  useDocumentMetadata: vi.fn(),
  useInvalidDocumentList: vi.fn(() => vi.fn()),
}))

// Mock other dependencies
jest.mock('@/context/dataset-detail', () => ({
  useDatasetDetailContext: jest.fn(() => [null]),
vi.mock('@/context/dataset-detail', () => ({
  useDatasetDetailContext: vi.fn(() => [null]),
}))

jest.mock('@/service/use-base', () => ({
  useInvalid: jest.fn(() => jest.fn()),
vi.mock('@/service/use-base', () => ({
  useInvalid: vi.fn(() => vi.fn()),
}))

jest.mock('@/service/knowledge/use-segment', () => ({
  useSegmentListKey: jest.fn(),
  useChildSegmentListKey: jest.fn(),
vi.mock('@/service/knowledge/use-segment', () => ({
  useSegmentListKey: vi.fn(),
  useChildSegmentListKey: vi.fn(),
}))

// Create a minimal version of the DocumentDetail component that includes our fix

@@ -66,10 +67,10 @@ const DocumentDetailWithFix = ({ datasetId, documentId }: { datasetId: string; d

describe('Document Detail Navigation Fix Verification', () => {
  beforeEach(() => {
    jest.clearAllMocks()
    vi.clearAllMocks()

    // Mock successful API responses
    ;(useDocumentDetail as jest.Mock).mockReturnValue({
    ;(useDocumentDetail as Mock).mockReturnValue({
      data: {
        id: 'doc-123',
        name: 'Test Document',

@@ -80,7 +81,7 @@ describe('Document Detail Navigation Fix Verification', () => {
      error: null,
    })

    ;(useDocumentMetadata as jest.Mock).mockReturnValue({
    ;(useDocumentMetadata as Mock).mockReturnValue({
      data: null,
      error: null,
    })
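Aside: Vitest also offers `vi.mocked()` as a typed alternative to the loose `as Mock` cast used in these hunks. A sketch, not what this commit does — `useThing` and `@/service/thing` are hypothetical stand-ins for any mocked hook:

```ts
import type { Mock } from 'vitest'
import { useThing } from '@/service/thing' // hypothetical module

vi.mock('@/service/thing')

// Loose cast, as in this commit:
;(useThing as Mock).mockReturnValue({ data: null })
// Typed alternative: vi.mocked() keeps the hook's own signature.
vi.mocked(useThing).mockReturnValue({ data: null })
```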
@@ -4,16 +4,17 @@ import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import MailAndPasswordAuth from '@/app/(shareLayout)/webapp-signin/components/mail-and-password-auth'
import CheckCode from '@/app/(shareLayout)/webapp-signin/check-code/page'

const replaceMock = jest.fn()
const backMock = jest.fn()
const replaceMock = vi.fn()
const backMock = vi.fn()
const useSearchParamsMock = vi.fn(() => new URLSearchParams())

jest.mock('next/navigation', () => ({
  usePathname: jest.fn(() => '/chatbot/test-app'),
  useRouter: jest.fn(() => ({
vi.mock('next/navigation', () => ({
  usePathname: vi.fn(() => '/chatbot/test-app'),
  useRouter: vi.fn(() => ({
    replace: replaceMock,
    back: backMock,
  })),
  useSearchParams: jest.fn(),
  useSearchParams: () => useSearchParamsMock(),
}))

const mockStoreState = {

@@ -21,59 +22,55 @@ const mockStoreState = {
  shareCode: 'test-app',
}

const useWebAppStoreMock = jest.fn((selector?: (state: typeof mockStoreState) => any) => {
const useWebAppStoreMock = vi.fn((selector?: (state: typeof mockStoreState) => any) => {
  return selector ? selector(mockStoreState) : mockStoreState
})

jest.mock('@/context/web-app-context', () => ({
vi.mock('@/context/web-app-context', () => ({
  useWebAppStore: (selector?: (state: typeof mockStoreState) => any) => useWebAppStoreMock(selector),
}))

const webAppLoginMock = jest.fn()
const webAppEmailLoginWithCodeMock = jest.fn()
const sendWebAppEMailLoginCodeMock = jest.fn()
const webAppLoginMock = vi.fn()
const webAppEmailLoginWithCodeMock = vi.fn()
const sendWebAppEMailLoginCodeMock = vi.fn()

jest.mock('@/service/common', () => ({
vi.mock('@/service/common', () => ({
  webAppLogin: (...args: any[]) => webAppLoginMock(...args),
  webAppEmailLoginWithCode: (...args: any[]) => webAppEmailLoginWithCodeMock(...args),
  sendWebAppEMailLoginCode: (...args: any[]) => sendWebAppEMailLoginCodeMock(...args),
}))

const fetchAccessTokenMock = jest.fn()
const fetchAccessTokenMock = vi.fn()

jest.mock('@/service/share', () => ({
vi.mock('@/service/share', () => ({
  fetchAccessToken: (...args: any[]) => fetchAccessTokenMock(...args),
}))

const setWebAppAccessTokenMock = jest.fn()
const setWebAppPassportMock = jest.fn()
const setWebAppAccessTokenMock = vi.fn()
const setWebAppPassportMock = vi.fn()

jest.mock('@/service/webapp-auth', () => ({
vi.mock('@/service/webapp-auth', () => ({
  setWebAppAccessToken: (...args: any[]) => setWebAppAccessTokenMock(...args),
  setWebAppPassport: (...args: any[]) => setWebAppPassportMock(...args),
  webAppLogout: jest.fn(),
  webAppLogout: vi.fn(),
}))

jest.mock('@/app/components/signin/countdown', () => () => <div data-testid="countdown" />)
vi.mock('@/app/components/signin/countdown', () => ({ default: () => <div data-testid="countdown" /> }))

jest.mock('@remixicon/react', () => ({
vi.mock('@remixicon/react', () => ({
  RiMailSendFill: () => <div data-testid="mail-icon" />,
  RiArrowLeftLine: () => <div data-testid="arrow-icon" />,
}))

const { useSearchParams } = jest.requireMock('next/navigation') as {
  useSearchParams: jest.Mock
}

beforeEach(() => {
  jest.clearAllMocks()
  vi.clearAllMocks()
})

describe('embedded user id propagation in authentication flows', () => {
  it('passes embedded user id when logging in with email and password', async () => {
    const params = new URLSearchParams()
    params.set('redirect_url', encodeURIComponent('/chatbot/test-app'))
    useSearchParams.mockReturnValue(params)
    useSearchParamsMock.mockReturnValue(params)

    webAppLoginMock.mockResolvedValue({ result: 'success', data: { access_token: 'login-token' } })
    fetchAccessTokenMock.mockResolvedValue({ access_token: 'passport-token' })

@@ -100,7 +97,7 @@ describe('embedded user id propagation in authentication flows', () => {
    params.set('redirect_url', encodeURIComponent('/chatbot/test-app'))
    params.set('email', encodeURIComponent('user@example.com'))
    params.set('token', encodeURIComponent('token-abc'))
    useSearchParams.mockReturnValue(params)
    useSearchParamsMock.mockReturnValue(params)

    webAppEmailLoginWithCodeMock.mockResolvedValue({ result: 'success', data: { access_token: 'code-token' } })
    fetchAccessTokenMock.mockResolvedValue({ access_token: 'passport-token' })
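The pattern introduced here generalizes: instead of rewiring the mocked module via `jest.requireMock`, the `vi.mock` factory delegates to a module-scope `vi.fn()`, which individual tests can then retarget. Distilled, with names shortened:

```ts
// The factory calls the outer vi.fn() lazily at render time, so each test
// can swap the return value without touching the mocked module again.
const useSearchParamsMock = vi.fn(() => new URLSearchParams())

vi.mock('next/navigation', () => ({
  useSearchParams: () => useSearchParamsMock(),
}))

it('reads a redirect_url param', () => {
  const params = new URLSearchParams([['redirect_url', '/chatbot/test-app']])
  useSearchParamsMock.mockReturnValue(params)
  // ...render and assert...
})
```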
@@ -1,42 +1,42 @@
import React from 'react'
import { render, screen, waitFor } from '@testing-library/react'
import { AccessMode } from '@/models/access-control'

import WebAppStoreProvider, { useWebAppStore } from '@/context/web-app-context'

jest.mock('next/navigation', () => ({
  usePathname: jest.fn(() => '/chatbot/sample-app'),
  useSearchParams: jest.fn(() => {
vi.mock('next/navigation', () => ({
  usePathname: vi.fn(() => '/chatbot/sample-app'),
  useSearchParams: vi.fn(() => {
    const params = new URLSearchParams()
    return params
  }),
}))

jest.mock('@/service/use-share', () => {
  const { AccessMode } = jest.requireActual('@/models/access-control')
  return {
    useGetWebAppAccessModeByCode: jest.fn(() => ({
      isLoading: false,
      data: { accessMode: AccessMode.PUBLIC },
    })),
  }
})

jest.mock('@/app/components/base/chat/utils', () => ({
  getProcessedSystemVariablesFromUrlParams: jest.fn(),
vi.mock('@/service/use-share', () => ({
  useGetWebAppAccessModeByCode: vi.fn(() => ({
    isLoading: false,
    data: { accessMode: AccessMode.PUBLIC },
  })),
}))

const { getProcessedSystemVariablesFromUrlParams: mockGetProcessedSystemVariablesFromUrlParams }
  = jest.requireMock('@/app/components/base/chat/utils') as {
    getProcessedSystemVariablesFromUrlParams: jest.Mock
  }
// Store the mock implementation in a way that survives hoisting
const mockGetProcessedSystemVariablesFromUrlParams = vi.fn()

jest.mock('@/context/global-public-context', () => {
  const mockGlobalStoreState = {
vi.mock('@/app/components/base/chat/utils', () => ({
  getProcessedSystemVariablesFromUrlParams: (...args: any[]) => mockGetProcessedSystemVariablesFromUrlParams(...args),
}))

// Use vi.hoisted to define mock state before vi.mock hoisting
const { mockGlobalStoreState } = vi.hoisted(() => ({
  mockGlobalStoreState: {
    isGlobalPending: false,
    setIsGlobalPending: jest.fn(),
    setIsGlobalPending: vi.fn(),
    systemFeatures: {},
    setSystemFeatures: jest.fn(),
  }
    setSystemFeatures: vi.fn(),
  },
}))

vi.mock('@/context/global-public-context', () => {
  const useGlobalPublicStore = Object.assign(
    (selector?: (state: typeof mockGlobalStoreState) => any) =>
      selector ? selector(mockGlobalStoreState) : mockGlobalStoreState,

@@ -56,21 +56,6 @@ jest.mock('@/context/global-public-context', () => {
  }
})

const {
  useGlobalPublicStore: useGlobalPublicStoreMock,
} = jest.requireMock('@/context/global-public-context') as {
  useGlobalPublicStore: ((selector?: (state: any) => any) => any) & {
    setState: (updater: any) => void
    __mockState: {
      isGlobalPending: boolean
      setIsGlobalPending: jest.Mock
      systemFeatures: Record<string, unknown>
      setSystemFeatures: jest.Mock
    }
  }
}
const mockGlobalStoreState = useGlobalPublicStoreMock.__mockState

const TestConsumer = () => {
  const embeddedUserId = useWebAppStore(state => state.embeddedUserId)
  const embeddedConversationId = useWebAppStore(state => state.embeddedConversationId)
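`vi.hoisted` is the key move in this hunk: Vitest hoists `vi.mock` factories above ordinary `const` declarations, so any shared state a factory closes over must be hoisted too. Distilled to its essentials:

```ts
// vi.hoisted runs before the hoisted vi.mock factories, so the factory can
// safely reference this state at mock-creation time.
const { state } = vi.hoisted(() => ({
  state: { systemFeatures: {}, setSystemFeatures: vi.fn() },
}))

vi.mock('@/context/global-public-context', () => ({
  useGlobalPublicStore: (selector?: (s: typeof state) => unknown) =>
    selector ? selector(state) : state,
}))
```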
@@ -1,10 +1,9 @@
import React from 'react'
import { fireEvent, render, screen } from '@testing-library/react'
import '@testing-library/jest-dom'
import CommandSelector from '../../app/components/goto-anything/command-selector'
import type { ActionItem } from '../../app/components/goto-anything/actions/types'

jest.mock('cmdk', () => ({
vi.mock('cmdk', () => ({
  Command: {
    Group: ({ children, className }: any) => <div className={className}>{children}</div>,
    Item: ({ children, onSelect, value, className }: any) => (

@@ -27,36 +26,36 @@ describe('CommandSelector', () => {
      shortcut: '@app',
      title: 'Search Applications',
      description: 'Search apps',
      search: jest.fn(),
      search: vi.fn(),
    },
    knowledge: {
      key: '@knowledge',
      shortcut: '@kb',
      title: 'Search Knowledge',
      description: 'Search knowledge bases',
      search: jest.fn(),
      search: vi.fn(),
    },
    plugin: {
      key: '@plugin',
      shortcut: '@plugin',
      title: 'Search Plugins',
      description: 'Search plugins',
      search: jest.fn(),
      search: vi.fn(),
    },
    node: {
      key: '@node',
      shortcut: '@node',
      title: 'Search Nodes',
      description: 'Search workflow nodes',
      search: jest.fn(),
      search: vi.fn(),
    },
  }

  const mockOnCommandSelect = jest.fn()
  const mockOnCommandValueChange = jest.fn()
  const mockOnCommandSelect = vi.fn()
  const mockOnCommandValueChange = vi.fn()

  beforeEach(() => {
    jest.clearAllMocks()
    vi.clearAllMocks()
  })

  describe('Basic Rendering', () => {
@@ -1,11 +1,12 @@
import type { Mock } from 'vitest'
import type { ActionItem } from '../../app/components/goto-anything/actions/types'

// Mock the entire actions module to avoid import issues
jest.mock('../../app/components/goto-anything/actions', () => ({
  matchAction: jest.fn(),
vi.mock('../../app/components/goto-anything/actions', () => ({
  matchAction: vi.fn(),
}))

jest.mock('../../app/components/goto-anything/actions/commands/registry')
vi.mock('../../app/components/goto-anything/actions/commands/registry')

// Import after mocking to get mocked version
import { matchAction } from '../../app/components/goto-anything/actions'

@@ -39,7 +40,7 @@ const actualMatchAction = (query: string, actions: Record<string, ActionItem>) =
}

// Replace mock with actual implementation
;(matchAction as jest.Mock).mockImplementation(actualMatchAction)
;(matchAction as Mock).mockImplementation(actualMatchAction)

describe('matchAction Logic', () => {
  const mockActions: Record<string, ActionItem> = {

@@ -48,27 +49,27 @@ describe('matchAction Logic', () => {
      shortcut: '@a',
      title: 'Search Applications',
      description: 'Search apps',
      search: jest.fn(),
      search: vi.fn(),
    },
    knowledge: {
      key: '@knowledge',
      shortcut: '@kb',
      title: 'Search Knowledge',
      description: 'Search knowledge bases',
      search: jest.fn(),
      search: vi.fn(),
    },
    slash: {
      key: '/',
      shortcut: '/',
      title: 'Commands',
      description: 'Execute commands',
      search: jest.fn(),
      search: vi.fn(),
    },
  }

  beforeEach(() => {
    jest.clearAllMocks()
    ;(slashCommandRegistry.getAllCommands as jest.Mock).mockReturnValue([
    vi.clearAllMocks()
    ;(slashCommandRegistry.getAllCommands as Mock).mockReturnValue([
      { name: 'docs', mode: 'direct' },
      { name: 'community', mode: 'direct' },
      { name: 'feedback', mode: 'direct' },

@@ -188,7 +189,7 @@ describe('matchAction Logic', () => {

  describe('Mode-based Filtering', () => {
    it('should filter direct mode commands from matching', () => {
      ;(slashCommandRegistry.getAllCommands as jest.Mock).mockReturnValue([
      ;(slashCommandRegistry.getAllCommands as Mock).mockReturnValue([
        { name: 'test', mode: 'direct' },
      ])

@@ -197,7 +198,7 @@ describe('matchAction Logic', () => {
    })

    it('should allow submenu mode commands to match', () => {
      ;(slashCommandRegistry.getAllCommands as jest.Mock).mockReturnValue([
      ;(slashCommandRegistry.getAllCommands as Mock).mockReturnValue([
        { name: 'test', mode: 'submenu' },
      ])

@@ -206,7 +207,7 @@ describe('matchAction Logic', () => {
    })

    it('should treat undefined mode as submenu', () => {
      ;(slashCommandRegistry.getAllCommands as jest.Mock).mockReturnValue([
      ;(slashCommandRegistry.getAllCommands as Mock).mockReturnValue([
        { name: 'test' }, // No mode specified
      ])

@@ -227,7 +228,7 @@ describe('matchAction Logic', () => {
    })

    it('should handle empty command list', () => {
      ;(slashCommandRegistry.getAllCommands as jest.Mock).mockReturnValue([])
      ;(slashCommandRegistry.getAllCommands as Mock).mockReturnValue([])
      const result = matchAction('/anything', mockActions)
      expect(result).toBeUndefined()
    })
@@ -1,6 +1,5 @@
import React from 'react'
import { render, screen } from '@testing-library/react'
import '@testing-library/jest-dom'

// Type alias for search mode
type SearchMode = 'scopes' | 'commands' | null
@@ -1,3 +1,4 @@
import type { MockedFunction } from 'vitest'
/**
 * Test GotoAnything search error handling mechanisms
 *

@@ -14,33 +15,33 @@ import { fetchAppList } from '@/service/apps'
import { fetchDatasets } from '@/service/datasets'

// Mock API functions
jest.mock('@/service/base', () => ({
  postMarketplace: jest.fn(),
vi.mock('@/service/base', () => ({
  postMarketplace: vi.fn(),
}))

jest.mock('@/service/apps', () => ({
  fetchAppList: jest.fn(),
vi.mock('@/service/apps', () => ({
  fetchAppList: vi.fn(),
}))

jest.mock('@/service/datasets', () => ({
  fetchDatasets: jest.fn(),
vi.mock('@/service/datasets', () => ({
  fetchDatasets: vi.fn(),
}))

const mockPostMarketplace = postMarketplace as jest.MockedFunction<typeof postMarketplace>
const mockFetchAppList = fetchAppList as jest.MockedFunction<typeof fetchAppList>
const mockFetchDatasets = fetchDatasets as jest.MockedFunction<typeof fetchDatasets>
const mockPostMarketplace = postMarketplace as MockedFunction<typeof postMarketplace>
const mockFetchAppList = fetchAppList as MockedFunction<typeof fetchAppList>
const mockFetchDatasets = fetchDatasets as MockedFunction<typeof fetchDatasets>

describe('GotoAnything Search Error Handling', () => {
  beforeEach(() => {
    jest.clearAllMocks()
    vi.clearAllMocks()
    // Suppress console.warn for clean test output
    jest.spyOn(console, 'warn').mockImplementation(() => {
    vi.spyOn(console, 'warn').mockImplementation(() => {
      // Suppress console.warn for clean test output
    })
  })

  afterEach(() => {
    jest.restoreAllMocks()
    vi.restoreAllMocks()
  })

  describe('@plugin search error handling', () => {
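With the factories in place, the typed wrappers are what the tests interact with; `MockedFunction` keeps `mockResolvedValue`/`mockRejectedValue` type-checked against the real service signature. A self-contained sketch with an invented signature:

```ts
import type { MockedFunction } from 'vitest'
import { beforeEach, describe, expect, it, vi } from 'vitest'

// Invented signature for illustration; the real one comes from '@/service/apps'.
type FetchAppList = (params: { url: string }) => Promise<{ data: unknown[] }>
const mockFetchAppList = vi.fn() as MockedFunction<FetchAppList>

describe('typed service mocks', () => {
  beforeEach(() => {
    vi.clearAllMocks()
  })

  it('drives success and failure paths per test', async () => {
    mockFetchAppList.mockResolvedValueOnce({ data: [] })
    await expect(mockFetchAppList({ url: '/apps' })).resolves.toEqual({ data: [] })

    mockFetchAppList.mockRejectedValueOnce(new Error('network down'))
    await expect(mockFetchAppList({ url: '/apps' })).rejects.toThrow('network down')
  })
})
```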
@@ -1,17 +1,16 @@
import '@testing-library/jest-dom'
import { slashCommandRegistry } from '../../app/components/goto-anything/actions/commands/registry'
import type { SlashCommandHandler } from '../../app/components/goto-anything/actions/commands/types'

// Mock the registry
jest.mock('../../app/components/goto-anything/actions/commands/registry')
vi.mock('../../app/components/goto-anything/actions/commands/registry')

describe('Slash Command Dual-Mode System', () => {
  const mockDirectCommand: SlashCommandHandler = {
    name: 'docs',
    description: 'Open documentation',
    mode: 'direct',
    execute: jest.fn(),
    search: jest.fn().mockResolvedValue([
    execute: vi.fn(),
    search: vi.fn().mockResolvedValue([
      {
        id: 'docs',
        title: 'Documentation',

@@ -20,15 +19,15 @@ describe('Slash Command Dual-Mode System', () => {
        data: { command: 'navigation.docs', args: {} },
      },
    ]),
    register: jest.fn(),
    unregister: jest.fn(),
    register: vi.fn(),
    unregister: vi.fn(),
  }

  const mockSubmenuCommand: SlashCommandHandler = {
    name: 'theme',
    description: 'Change theme',
    mode: 'submenu',
    search: jest.fn().mockResolvedValue([
    search: vi.fn().mockResolvedValue([
      {
        id: 'theme-light',
        title: 'Light Theme',

@@ -44,18 +43,18 @@ describe('Slash Command Dual-Mode System', () => {
        data: { command: 'theme.set', args: { theme: 'dark' } },
      },
    ]),
    register: jest.fn(),
    unregister: jest.fn(),
    register: vi.fn(),
    unregister: vi.fn(),
  }

  beforeEach(() => {
    jest.clearAllMocks()
    ;(slashCommandRegistry as any).findCommand = jest.fn((name: string) => {
    vi.clearAllMocks()
    ;(slashCommandRegistry as any).findCommand = vi.fn((name: string) => {
      if (name === 'docs') return mockDirectCommand
      if (name === 'theme') return mockSubmenuCommand
      return null
    })
    ;(slashCommandRegistry as any).getAllCommands = jest.fn(() => [
    ;(slashCommandRegistry as any).getAllCommands = vi.fn(() => [
      mockDirectCommand,
      mockSubmenuCommand,
    ])
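Because the registry module is auto-mocked, `findCommand` and `getAllCommands` are reassigned fresh `vi.fn` implementations in each `beforeEach` so every test controls exactly what the registry returns. A stripped-down sketch with a hypothetical registry shape:

```ts
import { vi } from 'vitest'

// Hypothetical registry shape; the real singleton is auto-mocked above.
type Registry = {
  findCommand: (name: string) => { name: string } | null
  getAllCommands: () => { name: string }[]
}

const registry = {} as Registry
registry.findCommand = vi.fn((name: string) => (name === 'docs' ? { name: 'docs' } : null))
registry.getAllCommands = vi.fn(() => [{ name: 'docs' }])
```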
@@ -63,8 +62,8 @@ describe('Slash Command Dual-Mode System', () => {

  describe('Direct Mode Commands', () => {
    it('should execute immediately when selected', () => {
      const mockSetShow = jest.fn()
      const mockSetSearchQuery = jest.fn()
      const mockSetShow = vi.fn()
      const mockSetSearchQuery = vi.fn()

      // Simulate command selection
      const handler = slashCommandRegistry.findCommand('docs')

@@ -88,7 +87,7 @@ describe('Slash Command Dual-Mode System', () => {
    })

    it('should close modal after execution', () => {
      const mockModalClose = jest.fn()
      const mockModalClose = vi.fn()

      const handler = slashCommandRegistry.findCommand('docs')
      if (handler?.mode === 'direct' && handler.execute) {

@@ -118,7 +117,7 @@ describe('Slash Command Dual-Mode System', () => {
    })

    it('should keep modal open for selection', () => {
      const mockModalClose = jest.fn()
      const mockModalClose = vi.fn()

      const handler = slashCommandRegistry.findCommand('theme')
      // For submenu mode, modal should not close immediately

@@ -141,12 +140,12 @@ describe('Slash Command Dual-Mode System', () => {
      const commandWithoutMode: SlashCommandHandler = {
        name: 'test',
        description: 'Test command',
        search: jest.fn(),
        register: jest.fn(),
        unregister: jest.fn(),
        search: vi.fn(),
        register: vi.fn(),
        unregister: vi.fn(),
      }

      ;(slashCommandRegistry as any).findCommand = jest.fn(() => commandWithoutMode)
      ;(slashCommandRegistry as any).findCommand = vi.fn(() => commandWithoutMode)

      const handler = slashCommandRegistry.findCommand('test')
      // Default behavior should be submenu when mode is not specified

@@ -189,7 +188,7 @@ describe('Slash Command Dual-Mode System', () => {
  describe('Command Registration', () => {
    it('should register both direct and submenu commands', () => {
      mockDirectCommand.register?.({})
      mockSubmenuCommand.register?.({ setTheme: jest.fn() })
      mockSubmenuCommand.register?.({ setTheme: vi.fn() })

      expect(mockDirectCommand.register).toHaveBeenCalled()
      expect(mockSubmenuCommand.register).toHaveBeenCalled()
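The direct/submenu split these cases exercise reduces to a small dispatch rule. This sketch mirrors the behaviour the assertions describe, not Dify's actual handler:

```ts
// Direct commands run and close the modal; submenu commands, and commands
// with no mode at all, keep the modal open for a follow-up selection.
type Handler = { mode?: 'direct' | 'submenu'; execute?: () => void }

const onCommandSelected = (handler: Handler, closeModal: () => void): void => {
  if (handler.mode === 'direct' && handler.execute) {
    handler.execute()
    closeModal()
    return
  }
  // submenu / default: leave the modal open and render the command's results
}
```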
@@ -15,12 +15,12 @@ import {
} from '@/utils/navigation'

// Mock router for testing
const mockPush = jest.fn()
const mockPush = vi.fn()
const mockRouter = { push: mockPush }

describe('Navigation Utilities', () => {
  beforeEach(() => {
    jest.clearAllMocks()
    vi.clearAllMocks()
  })

  describe('createNavigationPath', () => {

@@ -63,7 +63,7 @@ describe('Navigation Utilities', () => {
      configurable: true,
    })

    const consoleSpy = jest.spyOn(console, 'warn').mockImplementation()
    const consoleSpy = vi.spyOn(console, 'warn').mockImplementation(() => { /* noop */ })
    const path = createNavigationPath('/datasets/123/documents')

    expect(path).toBe('/datasets/123/documents')

@@ -134,7 +134,7 @@ describe('Navigation Utilities', () => {
      configurable: true,
    })

    const consoleSpy = jest.spyOn(console, 'warn').mockImplementation()
    const consoleSpy = vi.spyOn(console, 'warn').mockImplementation(() => { /* noop */ })
    const params = extractQueryParams(['page', 'limit'])

    expect(params).toEqual({})

@@ -169,11 +169,11 @@ describe('Navigation Utilities', () => {
  test('handles errors gracefully', () => {
    // Mock URLSearchParams to throw an error
    const originalURLSearchParams = globalThis.URLSearchParams
    globalThis.URLSearchParams = jest.fn(() => {
    globalThis.URLSearchParams = vi.fn(() => {
      throw new Error('URLSearchParams error')
    }) as any

    const consoleSpy = jest.spyOn(console, 'warn').mockImplementation()
    const consoleSpy = vi.spyOn(console, 'warn').mockImplementation(() => { /* noop */ })
    const path = createNavigationPathWithParams('/datasets/123/documents', { page: 1 })

    expect(path).toBe('/datasets/123/documents')
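Note the migration detail in these hunks: Jest tolerated a bare `mockImplementation()`, while Vitest's signature requires an explicit function, hence the `() => { /* noop */ }`. A related pattern worth calling out: when a test swaps out a global such as `URLSearchParams`, the original must be restored even if the assertion throws. A sketch using try/finally:

```ts
import { expect, it, vi } from 'vitest'

it('restores a patched global even if the assertion throws', () => {
  const original = globalThis.URLSearchParams
  globalThis.URLSearchParams = vi.fn(() => {
    throw new Error('URLSearchParams error')
  }) as any

  try {
    expect(() => new URLSearchParams('a=1')).toThrow('URLSearchParams error')
  }
  finally {
    globalThis.URLSearchParams = original // never leak the stub into other tests
  }
})
```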
@@ -76,7 +76,7 @@ const setupMockEnvironment = (storedTheme: string | null, systemPrefersDark = fa
    return mediaQueryList
  }

  jest.spyOn(window, 'matchMedia').mockImplementation(mockMatchMedia)
  vi.spyOn(window, 'matchMedia').mockImplementation(mockMatchMedia)
}

// Helper function to create timing page component

@@ -240,8 +240,8 @@ const TestThemeProvider = ({ children }: { children: React.ReactNode }) => (

describe('Real Browser Environment Dark Mode Flicker Test', () => {
  beforeEach(() => {
    jest.restoreAllMocks()
    jest.clearAllMocks()
    vi.restoreAllMocks()
    vi.clearAllMocks()
    if (typeof window !== 'undefined') {
      try {
        window.localStorage.clear()

@@ -424,12 +424,12 @@ describe('Real Browser Environment Dark Mode Flicker Test', () => {
    setupMockEnvironment(null)

    const mockStorage = {
      getItem: jest.fn(() => {
      getItem: vi.fn(() => {
        throw new Error('LocalStorage access denied')
      }),
      setItem: jest.fn(),
      removeItem: jest.fn(),
      clear: jest.fn(),
      setItem: vi.fn(),
      removeItem: vi.fn(),
      clear: vi.fn(),
    }

    Object.defineProperty(window, 'localStorage', {
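`setupMockEnvironment` spies on `window.matchMedia` so the theme logic sees a controlled system preference. A condensed sketch of such a factory, assuming `matchMedia` exists in the test environment (jsdom needs it polyfilled or defined in a setup file; `vi.spyOn` would throw on a missing property):

```ts
import { vi } from 'vitest'

// Condensed, illustrative matchMedia factory.
const mockSystemTheme = (prefersDark: boolean) =>
  vi.spyOn(window, 'matchMedia').mockImplementation((query: string) => ({
    matches: prefersDark && query.includes('prefers-color-scheme: dark'),
    media: query,
    onchange: null,
    addListener: vi.fn(),
    removeListener: vi.fn(),
    addEventListener: vi.fn(),
    removeEventListener: vi.fn(),
    dispatchEvent: vi.fn(),
  } as unknown as MediaQueryList))
```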
@@ -1,15 +1,16 @@
import type { Mock } from 'vitest'
import { BlockEnum } from '@/app/components/workflow/types'
import { useWorkflowStore } from '@/app/components/workflow/store'

// Type for mocked store
type MockWorkflowStore = {
  showOnboarding: boolean
  setShowOnboarding: jest.Mock
  setShowOnboarding: Mock
  hasShownOnboarding: boolean
  setHasShownOnboarding: jest.Mock
  setHasShownOnboarding: Mock
  hasSelectedStartNode: boolean
  setHasSelectedStartNode: jest.Mock
  setShouldAutoOpenStartNodeSelector: jest.Mock
  setHasSelectedStartNode: Mock
  setShouldAutoOpenStartNodeSelector: Mock
  notInitialWorkflow: boolean
}

@@ -20,11 +21,11 @@ type MockNode = {
}

// Mock zustand store
jest.mock('@/app/components/workflow/store')
vi.mock('@/app/components/workflow/store')

// Mock ReactFlow store
const mockGetNodes = jest.fn()
jest.mock('reactflow', () => ({
const mockGetNodes = vi.fn()
vi.mock('reactflow', () => ({
  useStoreApi: () => ({
    getState: () => ({
      getNodes: mockGetNodes,

@@ -33,16 +34,16 @@ jest.mock('reactflow', () => ({
}))

describe('Workflow Onboarding Integration Logic', () => {
  const mockSetShowOnboarding = jest.fn()
  const mockSetHasSelectedStartNode = jest.fn()
  const mockSetHasShownOnboarding = jest.fn()
  const mockSetShouldAutoOpenStartNodeSelector = jest.fn()
  const mockSetShowOnboarding = vi.fn()
  const mockSetHasSelectedStartNode = vi.fn()
  const mockSetHasShownOnboarding = vi.fn()
  const mockSetShouldAutoOpenStartNodeSelector = vi.fn()

  beforeEach(() => {
    jest.clearAllMocks()
    vi.clearAllMocks()

    // Mock store implementation
    ;(useWorkflowStore as jest.Mock).mockReturnValue({
    ;(useWorkflowStore as Mock).mockReturnValue({
      showOnboarding: false,
      setShowOnboarding: mockSetShowOnboarding,
      hasSelectedStartNode: false,
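Typing the mocked store (rather than reaching for `any`) is what the new `Mock` import buys here: a renamed setter fails compilation instead of passing silently. A trimmed sketch of the idea, with the store shape cut down for illustration:

```ts
import type { Mock } from 'vitest'
import { vi } from 'vitest'

// Trimmed store shape; the real MockWorkflowStore has more fields.
type OnboardingState = {
  showOnboarding: boolean
  setShowOnboarding: Mock
}

const storeState: OnboardingState = {
  showOnboarding: false,
  setShowOnboarding: vi.fn(),
}

const useWorkflowStore = vi.fn()
useWorkflowStore.mockReturnValue(storeState)
```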
@@ -373,12 +374,12 @@ describe('Workflow Onboarding Integration Logic', () => {
  it('should trigger onboarding for new workflow when draft does not exist', () => {
    // Simulate the error handling logic from use-workflow-init.ts
    const error = {
      json: jest.fn().mockResolvedValue({ code: 'draft_workflow_not_exist' }),
      json: vi.fn().mockResolvedValue({ code: 'draft_workflow_not_exist' }),
      bodyUsed: false,
    }

    const mockWorkflowStore = {
      setState: jest.fn(),
      setState: vi.fn(),
    }

    // Simulate error handling

@@ -404,7 +405,7 @@ describe('Workflow Onboarding Integration Logic', () => {
  it('should not trigger onboarding for existing workflows', () => {
    // Simulate successful draft fetch
    const mockWorkflowStore = {
      setState: jest.fn(),
      setState: vi.fn(),
    }

    // Normal initialization path should not set showOnboarding: true

@@ -419,7 +420,7 @@ describe('Workflow Onboarding Integration Logic', () => {
  })

  it('should create empty draft with proper structure', () => {
    const mockSyncWorkflowDraft = jest.fn()
    const mockSyncWorkflowDraft = vi.fn()
    const appId = 'test-app-id'

    // Simulate the syncWorkflowDraft call from use-workflow-init.ts

@@ -467,7 +468,7 @@ describe('Workflow Onboarding Integration Logic', () => {
    mockGetNodes.mockReturnValue([])

    // Mock store with proper state for auto-detection
    ;(useWorkflowStore as jest.Mock).mockReturnValue({
    ;(useWorkflowStore as Mock).mockReturnValue({
      showOnboarding: false,
      hasShownOnboarding: false,
      notInitialWorkflow: false,

@@ -550,7 +551,7 @@ describe('Workflow Onboarding Integration Logic', () => {
    mockGetNodes.mockReturnValue([])

    // Mock store with hasShownOnboarding = true
    ;(useWorkflowStore as jest.Mock).mockReturnValue({
    ;(useWorkflowStore as Mock).mockReturnValue({
      showOnboarding: false,
      hasShownOnboarding: true, // Already shown in this session
      notInitialWorkflow: false,

@@ -584,7 +585,7 @@ describe('Workflow Onboarding Integration Logic', () => {
    mockGetNodes.mockReturnValue([])

    // Mock store with notInitialWorkflow = true (initial creation)
    ;(useWorkflowStore as jest.Mock).mockReturnValue({
    ;(useWorkflowStore as Mock).mockReturnValue({
      showOnboarding: false,
      hasShownOnboarding: false,
      notInitialWorkflow: true, // Initial workflow creation
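The "draft does not exist" cases simulate the error path in use-workflow-init.ts: a typed API error flips the store into onboarding mode instead of failing the page. An illustrative reduction of that flow, using the shapes from the test rather than the production types:

```ts
type DraftError = { json: () => Promise<{ code: string }>; bodyUsed: boolean }

const handleDraftFetchError = async (
  error: DraftError,
  setState: (patch: { showOnboarding: boolean }) => void,
) => {
  if (error.bodyUsed)
    return
  const body = await error.json()
  if (body.code === 'draft_workflow_not_exist')
    setState({ showOnboarding: true })
}
```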
@@ -19,7 +19,7 @@ function setupEnvironment(value?: string) {
    delete process.env.NEXT_PUBLIC_MAX_PARALLEL_LIMIT

  // Clear module cache to force re-evaluation
  jest.resetModules()
  vi.resetModules()
}

function restoreEnvironment() {

@@ -28,11 +28,11 @@ function restoreEnvironment() {
  else
    delete process.env.NEXT_PUBLIC_MAX_PARALLEL_LIMIT

  jest.resetModules()
  vi.resetModules()
}

// Mock i18next with proper implementation
jest.mock('react-i18next', () => ({
vi.mock('react-i18next', () => ({
  useTranslation: () => ({
    t: (key: string) => {
      if (key.includes('MaxParallelismTitle')) return 'Max Parallelism'

@@ -45,20 +45,20 @@ jest.mock('react-i18next', () => ({
  }),
  initReactI18next: {
    type: '3rdParty',
    init: jest.fn(),
    init: vi.fn(),
  },
}))

// Mock i18next module completely to prevent initialization issues
jest.mock('i18next', () => ({
  use: jest.fn().mockReturnThis(),
  init: jest.fn().mockReturnThis(),
  t: jest.fn(key => key),
vi.mock('i18next', () => ({
  use: vi.fn().mockReturnThis(),
  init: vi.fn().mockReturnThis(),
  t: vi.fn(key => key),
  isInitialized: true,
}))

// Mock the useConfig hook
jest.mock('@/app/components/workflow/nodes/iteration/use-config', () => ({
vi.mock('@/app/components/workflow/nodes/iteration/use-config', () => ({
  __esModule: true,
  default: () => ({
    inputs: {
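The `require` → `await import` change later in this file is the one structural consequence of the migration: Vitest runs ESM, so a fresh copy of an env-dependent module is obtained by resetting the module registry and dynamically importing, which in turn makes each test async. The pattern in isolation:

```ts
import { expect, it, vi } from 'vitest'

// Reset the module registry, adjust the environment, then pull a *fresh*
// evaluation of the module. `require` is unavailable under Vitest's ESM
// transform, so the test itself becomes async.
it('re-evaluates config for each environment value', async () => {
  vi.resetModules()
  process.env.NEXT_PUBLIC_MAX_PARALLEL_LIMIT = '25'
  const { MAX_PARALLEL_LIMIT } = await import('@/config')
  expect(MAX_PARALLEL_LIMIT).toBe(25)
})
```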
@@ -66,82 +66,39 @@ jest.mock('@/app/components/workflow/nodes/iteration/use-config', () => ({
      parallel_nums: 5,
      error_handle_mode: 'terminated',
    },
    changeParallel: jest.fn(),
    changeParallelNums: jest.fn(),
    changeErrorHandleMode: jest.fn(),
    changeParallel: vi.fn(),
    changeParallelNums: vi.fn(),
    changeErrorHandleMode: vi.fn(),
  }),
}))

// Mock other components
jest.mock('@/app/components/workflow/nodes/_base/components/variable/var-reference-picker', () => {
  return function MockVarReferencePicker() {
vi.mock('@/app/components/workflow/nodes/_base/components/variable/var-reference-picker', () => ({
  default: function MockVarReferencePicker() {
    return <div data-testid="var-reference-picker">VarReferencePicker</div>
  }
})
  },
}))

jest.mock('@/app/components/workflow/nodes/_base/components/split', () => {
  return function MockSplit() {
vi.mock('@/app/components/workflow/nodes/_base/components/split', () => ({
  default: function MockSplit() {
    return <div data-testid="split">Split</div>
  }
})
  },
}))

jest.mock('@/app/components/workflow/nodes/_base/components/field', () => {
  return function MockField({ title, children }: { title: string, children: React.ReactNode }) {
vi.mock('@/app/components/workflow/nodes/_base/components/field', () => ({
  default: function MockField({ title, children }: { title: string, children: React.ReactNode }) {
    return (
      <div data-testid="field">
        <label>{title}</label>
        {children}
      </div>
    )
  }
})
  },
}))

jest.mock('@/app/components/base/switch', () => {
  return function MockSwitch({ defaultValue }: { defaultValue: boolean }) {
    return <input type="checkbox" defaultChecked={defaultValue} data-testid="switch" />
  }
})

jest.mock('@/app/components/base/select', () => {
  return function MockSelect() {
    return <select data-testid="select">Select</select>
  }
})

// Use defaultValue to avoid controlled input warnings
jest.mock('@/app/components/base/slider', () => {
  return function MockSlider({ value, max, min }: { value: number, max: number, min: number }) {
    return (
      <input
        type="range"
        defaultValue={value}
        max={max}
        min={min}
        data-testid="slider"
        data-max={max}
        data-min={min}
        readOnly
      />
    )
  }
})

// Use defaultValue to avoid controlled input warnings
jest.mock('@/app/components/base/input', () => {
  return function MockInput({ type, max, min, value }: { type: string, max: number, min: number, value: number }) {
    return (
      <input
        type={type}
        defaultValue={value}
        max={max}
        min={min}
        data-testid="number-input"
        data-max={max}
        data-min={min}
        readOnly
      />
    )
  }
const getParallelControls = () => ({
  numberInput: screen.getByRole('spinbutton'),
  slider: screen.getByRole('slider'),
})
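The recurring shape change in this hunk: a Jest module factory could return the mock component itself, but a Vitest factory must return the whole module record, which is why every default-exported component (and the CSS-module mock in the XSS test further down) gains an explicit `default` key. A minimal before/after sketch with an illustrative relative path:

```tsx
import { vi } from 'vitest'

// Jest style (no longer valid under Vitest):
//   jest.mock('./split', () => function MockSplit() { return <div /> })
// Vitest style — the factory returns the module record:
vi.mock('./split', () => ({
  default: () => <div data-testid="split">Split</div>,
}))

// The same rule applies to CSS modules: return { default: { ... } }.
vi.mock('./style.module.css', () => ({
  default: { item: 'mock-item-class' },
}))
```

Note also that the bespoke slider/input mocks were deleted outright: the new `getParallelControls` queries the real components by ARIA role, so the assertions exercise actual rendered markup instead of mock-only `data-max` attributes.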
describe('MAX_PARALLEL_LIMIT Configuration Bug', () => {

@@ -160,7 +117,7 @@ describe('MAX_PARALLEL_LIMIT Configuration Bug', () => {
  }

  beforeEach(() => {
    jest.clearAllMocks()
    vi.clearAllMocks()
  })

  afterEach(() => {

@@ -172,115 +129,114 @@ describe('MAX_PARALLEL_LIMIT Configuration Bug', () => {
  })

  describe('Environment Variable Parsing', () => {
    it('should parse MAX_PARALLEL_LIMIT from NEXT_PUBLIC_MAX_PARALLEL_LIMIT environment variable', () => {
    it('should parse MAX_PARALLEL_LIMIT from NEXT_PUBLIC_MAX_PARALLEL_LIMIT environment variable', async () => {
      setupEnvironment('25')
      const { MAX_PARALLEL_LIMIT } = require('@/config')
      const { MAX_PARALLEL_LIMIT } = await import('@/config')
      expect(MAX_PARALLEL_LIMIT).toBe(25)
    })

    it('should fallback to default when environment variable is not set', () => {
    it('should fallback to default when environment variable is not set', async () => {
      setupEnvironment() // No environment variable
      const { MAX_PARALLEL_LIMIT } = require('@/config')
      const { MAX_PARALLEL_LIMIT } = await import('@/config')
      expect(MAX_PARALLEL_LIMIT).toBe(10)
    })

    it('should handle invalid environment variable values', () => {
    it('should handle invalid environment variable values', async () => {
      setupEnvironment('invalid')
      const { MAX_PARALLEL_LIMIT } = require('@/config')
      const { MAX_PARALLEL_LIMIT } = await import('@/config')

      // Should fall back to default when parsing fails
      expect(MAX_PARALLEL_LIMIT).toBe(10)
    })

    it('should handle empty environment variable', () => {
    it('should handle empty environment variable', async () => {
      setupEnvironment('')
      const { MAX_PARALLEL_LIMIT } = require('@/config')
      const { MAX_PARALLEL_LIMIT } = await import('@/config')

      // Should fall back to default when empty
      expect(MAX_PARALLEL_LIMIT).toBe(10)
    })

    // Edge cases for boundary values
    it('should clamp MAX_PARALLEL_LIMIT to MIN when env is 0 or negative', () => {
    it('should clamp MAX_PARALLEL_LIMIT to MIN when env is 0 or negative', async () => {
      setupEnvironment('0')
      let { MAX_PARALLEL_LIMIT } = require('@/config')
      let { MAX_PARALLEL_LIMIT } = await import('@/config')
      expect(MAX_PARALLEL_LIMIT).toBe(10) // Falls back to default

      setupEnvironment('-5')
      ;({ MAX_PARALLEL_LIMIT } = require('@/config'))
      ;({ MAX_PARALLEL_LIMIT } = await import('@/config'))
      expect(MAX_PARALLEL_LIMIT).toBe(10) // Falls back to default
    })

    it('should handle float numbers by parseInt behavior', () => {
    it('should handle float numbers by parseInt behavior', async () => {
      setupEnvironment('12.7')
      const { MAX_PARALLEL_LIMIT } = require('@/config')
      const { MAX_PARALLEL_LIMIT } = await import('@/config')
      // parseInt truncates to integer
      expect(MAX_PARALLEL_LIMIT).toBe(12)
    })
  })
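The parsing behaviour these assertions pin down fits in one function. This is an illustrative reimplementation consistent with the test expectations, not the actual `@/config` source:

```ts
// parseInt, then fall back to 10 for NaN or values below 1.
const parseMaxParallelLimit = (raw: string | undefined, fallback = 10): number => {
  const parsed = Number.parseInt(raw ?? '', 10)
  return Number.isNaN(parsed) || parsed < 1 ? fallback : parsed
}

parseMaxParallelLimit('25') // 25
parseMaxParallelLimit('12.7') // 12 — parseInt truncates
parseMaxParallelLimit('invalid') // 10
parseMaxParallelLimit('') // 10
parseMaxParallelLimit('0') // 10
parseMaxParallelLimit('-5') // 10
```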
  describe('UI Component Integration (Main Fix Verification)', () => {
    it('should render iteration panel with environment-configured max value', () => {
    it('should render iteration panel with environment-configured max value', async () => {
      // Set environment variable to a different value
      setupEnvironment('30')

      // Import Panel after setting environment
      const Panel = require('@/app/components/workflow/nodes/iteration/panel').default
      const { MAX_PARALLEL_LIMIT } = require('@/config')
      const Panel = await import('@/app/components/workflow/nodes/iteration/panel').then(mod => mod.default)
      const { MAX_PARALLEL_LIMIT } = await import('@/config')

      render(
        <Panel
          id="test-node"
          // @ts-expect-error key type mismatch
          data={mockNodeData.data}
        />,
      )

      // Behavior-focused assertion: UI max should equal MAX_PARALLEL_LIMIT
      const numberInput = screen.getByTestId('number-input')
      expect(numberInput).toHaveAttribute('data-max', String(MAX_PARALLEL_LIMIT))

      const slider = screen.getByTestId('slider')
      expect(slider).toHaveAttribute('data-max', String(MAX_PARALLEL_LIMIT))
      const { numberInput, slider } = getParallelControls()
      expect(numberInput).toHaveAttribute('max', String(MAX_PARALLEL_LIMIT))
      expect(slider).toHaveAttribute('aria-valuemax', String(MAX_PARALLEL_LIMIT))

      // Verify the actual values
      expect(MAX_PARALLEL_LIMIT).toBe(30)
      expect(numberInput.getAttribute('data-max')).toBe('30')
      expect(slider.getAttribute('data-max')).toBe('30')
      expect(numberInput.getAttribute('max')).toBe('30')
      expect(slider.getAttribute('aria-valuemax')).toBe('30')
    })

    it('should maintain UI consistency with different environment values', () => {
    it('should maintain UI consistency with different environment values', async () => {
      setupEnvironment('15')
      const Panel = require('@/app/components/workflow/nodes/iteration/panel').default
      const { MAX_PARALLEL_LIMIT } = require('@/config')
      const Panel = await import('@/app/components/workflow/nodes/iteration/panel').then(mod => mod.default)
      const { MAX_PARALLEL_LIMIT } = await import('@/config')

      render(
        <Panel
          id="test-node"
          // @ts-expect-error key type mismatch
          data={mockNodeData.data}
        />,
      )

      // Both input and slider should use the same max value from MAX_PARALLEL_LIMIT
      const numberInput = screen.getByTestId('number-input')
      const slider = screen.getByTestId('slider')
      const { numberInput, slider } = getParallelControls()

      expect(numberInput.getAttribute('data-max')).toBe(slider.getAttribute('data-max'))
      expect(numberInput.getAttribute('data-max')).toBe(String(MAX_PARALLEL_LIMIT))
      expect(numberInput.getAttribute('max')).toBe(slider.getAttribute('aria-valuemax'))
      expect(numberInput.getAttribute('max')).toBe(String(MAX_PARALLEL_LIMIT))
    })
  })

  describe('Legacy Constant Verification (For Transition Period)', () => {
    // Marked as transition/deprecation tests
    it('should maintain MAX_ITERATION_PARALLEL_NUM for backward compatibility', () => {
      const { MAX_ITERATION_PARALLEL_NUM } = require('@/app/components/workflow/constants')
    it('should maintain MAX_ITERATION_PARALLEL_NUM for backward compatibility', async () => {
      const { MAX_ITERATION_PARALLEL_NUM } = await import('@/app/components/workflow/constants')
      expect(typeof MAX_ITERATION_PARALLEL_NUM).toBe('number')
      expect(MAX_ITERATION_PARALLEL_NUM).toBe(10) // Hardcoded legacy value
    })

    it('should demonstrate MAX_PARALLEL_LIMIT vs legacy constant difference', () => {
    it('should demonstrate MAX_PARALLEL_LIMIT vs legacy constant difference', async () => {
      setupEnvironment('50')
      const { MAX_PARALLEL_LIMIT } = require('@/config')
      const { MAX_ITERATION_PARALLEL_NUM } = require('@/app/components/workflow/constants')
      const { MAX_PARALLEL_LIMIT } = await import('@/config')
      const { MAX_ITERATION_PARALLEL_NUM } = await import('@/app/components/workflow/constants')

      // MAX_PARALLEL_LIMIT is configurable, MAX_ITERATION_PARALLEL_NUM is not
      expect(MAX_PARALLEL_LIMIT).toBe(50)

@@ -290,9 +246,9 @@ describe('MAX_PARALLEL_LIMIT Configuration Bug', () => {
  })

  describe('Constants Validation', () => {
    it('should validate that required constants exist and have correct types', () => {
      const { MAX_PARALLEL_LIMIT } = require('@/config')
      const { MIN_ITERATION_PARALLEL_NUM } = require('@/app/components/workflow/constants')
    it('should validate that required constants exist and have correct types', async () => {
      const { MAX_PARALLEL_LIMIT } = await import('@/config')
      const { MIN_ITERATION_PARALLEL_NUM } = await import('@/app/components/workflow/constants')
      expect(typeof MAX_PARALLEL_LIMIT).toBe('number')
      expect(typeof MIN_ITERATION_PARALLEL_NUM).toBe('number')
      expect(MAX_PARALLEL_LIMIT).toBeGreaterThanOrEqual(MIN_ITERATION_PARALLEL_NUM)
@@ -7,13 +7,14 @@

import React from 'react'
import { cleanup, render } from '@testing-library/react'
import '@testing-library/jest-dom'
import BlockInput from '../app/components/base/block-input'
import SupportVarInput from '../app/components/workflow/nodes/_base/components/support-var-input'

// Mock styles
jest.mock('../app/components/app/configuration/base/var-highlight/style.module.css', () => ({
  item: 'mock-item-class',
vi.mock('../app/components/app/configuration/base/var-highlight/style.module.css', () => ({
  default: {
    item: 'mock-item-class',
  },
}))

describe('XSS Prevention - Block Input and Support Var Input Security', () => {
@@ -16,7 +16,7 @@ import {
import { useTranslation } from 'react-i18next'
import { useShallow } from 'zustand/react/shallow'
import s from './style.module.css'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import { useStore } from '@/app/components/app/store'
import AppSideBar from '@/app/components/app-sidebar'
import type { NavIcon } from '@/app/components/app-sidebar/navLink'

@@ -3,7 +3,7 @@ import { RiCalendarLine } from '@remixicon/react'
import type { Dayjs } from 'dayjs'
import type { FC } from 'react'
import React, { useCallback } from 'react'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import { formatToLocalTime } from '@/utils/format'
import { useI18N } from '@/context/i18n'
import Picker from '@/app/components/base/date-and-time-picker/date-picker'

@@ -6,7 +6,7 @@ import { SimpleSelect } from '@/app/components/base/select'
import type { Item } from '@/app/components/base/select'
import dayjs from 'dayjs'
import { RiArrowDownSLine, RiCheckLine } from '@remixicon/react'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import { useTranslation } from 'react-i18next'

const today = dayjs()
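All of the import hunks in this stretch (and the ones after the SVG test below) make the same one-line change: `cn` moves from a default export to a named export of `@/utils/classnames`, and the one stray import from the `classnames` package converges on the same utility. Assuming the usual classnames-style call signature, usage is unchanged apart from the braces:

```ts
import { cn } from '@/utils/classnames'

const isActive = true
// Falsy segments are dropped, as with the classnames package.
const className = cn('btn', isActive && 'btn-active') // 'btn btn-active'
```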
@@ -1,7 +1,8 @@
import React from 'react'
import { render } from '@testing-library/react'
import '@testing-library/jest-dom'
import { OpikIconBig } from '@/app/components/base/icons/src/public/tracing'
import { normalizeAttrs } from '@/app/components/base/icons/utils'
import iconData from '@/app/components/base/icons/src/public/tracing/OpikIconBig.json'

describe('SVG Attribute Error Reproduction', () => {
  // Capture console errors

@@ -10,7 +11,7 @@ describe('SVG Attribute Error Reproduction', () => {

  beforeEach(() => {
    errorMessages = []
    console.error = jest.fn((message) => {
    console.error = vi.fn((message) => {
      errorMessages.push(message)
      originalError(message)
    })

@@ -54,9 +55,6 @@ describe('SVG Attribute Error Reproduction', () => {
  it('should analyze the SVG structure causing the errors', () => {
    console.log('\n=== ANALYZING SVG STRUCTURE ===')

    // Import the JSON data directly
    const iconData = require('@/app/components/base/icons/src/public/tracing/OpikIconBig.json')

    console.log('Icon structure analysis:')
    console.log('- Root element:', iconData.icon.name)
    console.log('- Children count:', iconData.icon.children?.length || 0)

@@ -113,8 +111,6 @@ describe('SVG Attribute Error Reproduction', () => {
  it('should test the normalizeAttrs function behavior', () => {
    console.log('\n=== TESTING normalizeAttrs FUNCTION ===')

    const { normalizeAttrs } = require('@/app/components/base/icons/utils')

    const testAttributes = {
      'inkscape:showpageshadow': '2',
      'inkscape:pageopacity': '0.0',
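The Jest-era inline `require` calls inside individual tests are deleted here; under Vitest's ESM pipeline the JSON fixture and the utility are imported once at the top of the file instead. A sketch of the resulting shape, assuming `normalizeAttrs` accepts a raw SVG attribute record as the test data suggests:

```ts
import iconData from '@/app/components/base/icons/src/public/tracing/OpikIconBig.json'
import { normalizeAttrs } from '@/app/components/base/icons/utils'

// Namespaced editor attributes are normalized before rendering the icon.
const attrs = normalizeAttrs({ 'inkscape:showpageshadow': '2' })
console.log(iconData.icon.name, attrs)
```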
@@ -4,7 +4,7 @@ import React, { useCallback, useRef, useState } from 'react'

import type { PopupProps } from './config-popup'
import ConfigPopup from './config-popup'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import {
  PortalToFollowElem,
  PortalToFollowElemContent,

@@ -12,7 +12,7 @@ import Indicator from '@/app/components/header/indicator'
import Switch from '@/app/components/base/switch'
import Tooltip from '@/app/components/base/tooltip'
import Divider from '@/app/components/base/divider'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'

const I18N_PREFIX = 'app.tracing'

@@ -1,7 +1,7 @@
'use client'
import type { FC } from 'react'
import React from 'react'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import Input from '@/app/components/base/input'

type Props = {

@@ -12,7 +12,7 @@ import type { AliyunConfig, ArizeConfig, DatabricksConfig, LangFuseConfig, LangS
import { TracingProvider } from './type'
import TracingIcon from './tracing-icon'
import ConfigButton from './config-button'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import { AliyunIcon, ArizeIcon, DatabricksIcon, LangfuseIcon, LangsmithIcon, MlflowIcon, OpikIcon, PhoenixIcon, TencentIcon, WeaveIcon } from '@/app/components/base/icons/src/public/tracing'
import Indicator from '@/app/components/header/indicator'
import { fetchTracingConfig as doFetchTracingConfig, fetchTracingStatus, updateTracingStatus } from '@/service/apps'

@@ -6,7 +6,7 @@ import {
} from '@remixicon/react'
import { useTranslation } from 'react-i18next'
import { TracingProvider } from './type'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import { AliyunIconBig, ArizeIconBig, DatabricksIconBig, LangfuseIconBig, LangsmithIconBig, MlflowIconBig, OpikIconBig, PhoenixIconBig, TencentIconBig, WeaveIconBig } from '@/app/components/base/icons/src/public/tracing'
import { Eye as View } from '@/app/components/base/icons/src/vender/solid/general'

@@ -1,7 +1,7 @@
'use client'
import type { FC } from 'react'
import React from 'react'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import { TracingIcon as Icon } from '@/app/components/base/icons/src/public/tracing'

type Props = {

@@ -23,7 +23,7 @@ import { useDatasetDetail, useDatasetRelatedApps } from '@/service/knowledge/use
import useDocumentTitle from '@/hooks/use-document-title'
import ExtraInfo from '@/app/components/datasets/extra-info'
import { useEventEmitterContextContext } from '@/context/event-emitter'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'

export type IAppDetailLayoutProps = {
  children: React.ReactNode

@@ -1,7 +1,7 @@
'use client'
import Header from '@/app/signin/_header'

import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import { useGlobalPublicStore } from '@/context/global-public-context'

export default function SignInLayout({ children }: any) {

@@ -2,7 +2,7 @@
import { useCallback, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { useRouter, useSearchParams } from 'next/navigation'
import cn from 'classnames'
import { cn } from '@/utils/classnames'
import { RiCheckboxCircleFill } from '@remixicon/react'
import { useCountDown } from 'ahooks'
import Button from '@/app/components/base/button'

@@ -1,6 +1,6 @@
'use client'

import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import { useGlobalPublicStore } from '@/context/global-public-context'
import useDocumentTitle from '@/hooks/use-document-title'
import type { PropsWithChildren } from 'react'
Some files were not shown because too many files have changed in this diff.