diff --git a/.claude/skills/frontend-testing/SKILL.md b/.claude/skills/frontend-testing/SKILL.md index 06cb672141..7475513ba0 100644 --- a/.claude/skills/frontend-testing/SKILL.md +++ b/.claude/skills/frontend-testing/SKILL.md @@ -1,13 +1,13 @@ --- -name: Dify Frontend Testing -description: Generate Jest + React Testing Library tests for Dify frontend components, hooks, and utilities. Triggers on testing, spec files, coverage, Jest, RTL, unit tests, integration tests, or write/review test requests. +name: frontend-testing +description: Generate Vitest + React Testing Library tests for Dify frontend components, hooks, and utilities. Triggers on testing, spec files, coverage, Vitest, RTL, unit tests, integration tests, or write/review test requests. --- # Dify Frontend Testing Skill This skill enables Claude to generate high-quality, comprehensive frontend tests for the Dify project following established conventions and best practices. -> **⚠️ Authoritative Source**: This skill is derived from `web/testing/testing.md`. When in doubt, always refer to that document as the canonical specification. +> **⚠️ Authoritative Source**: This skill is derived from `web/testing/testing.md`. Use Vitest mock/timer APIs (`vi.*`). ## When to Apply This Skill @@ -15,7 +15,7 @@ Apply this skill when the user: - Asks to **write tests** for a component, hook, or utility - Asks to **review existing tests** for completeness -- Mentions **Jest**, **React Testing Library**, **RTL**, or **spec files** +- Mentions **Vitest**, **React Testing Library**, **RTL**, or **spec files** - Requests **test coverage** improvement - Uses `pnpm analyze-component` output as context - Mentions **testing**, **unit tests**, or **integration tests** for frontend code @@ -33,9 +33,9 @@ Apply this skill when the user: | Tool | Version | Purpose | |------|---------|---------| -| Jest | 29.7 | Test runner | +| Vitest | 4.0.16 | Test runner | | React Testing Library | 16.0 | Component testing | -| happy-dom | - | Test environment | +| jsdom | - | Test environment | | nock | 14.0 | HTTP mocking | | TypeScript | 5.x | Type safety | @@ -46,7 +46,7 @@ Apply this skill when the user: pnpm test # Watch mode -pnpm test -- --watch +pnpm test:watch # Run specific file pnpm test -- path/to/file.spec.tsx @@ -77,9 +77,9 @@ import Component from './index' // import { ChildComponent } from './child-component' // ✅ Mock external dependencies only -jest.mock('@/service/api') -jest.mock('next/navigation', () => ({ - useRouter: () => ({ push: jest.fn() }), +vi.mock('@/service/api') +vi.mock('next/navigation', () => ({ + useRouter: () => ({ push: vi.fn() }), usePathname: () => '/test', })) @@ -88,7 +88,7 @@ let mockSharedState = false describe('ComponentName', () => { beforeEach(() => { - jest.clearAllMocks() // ✅ Reset mocks BEFORE each test + vi.clearAllMocks() // ✅ Reset mocks BEFORE each test mockSharedState = false // ✅ Reset shared state }) @@ -117,7 +117,7 @@ describe('ComponentName', () => { // User Interactions describe('User Interactions', () => { it('should handle click events', () => { - const handleClick = jest.fn() + const handleClick = vi.fn() render() fireEvent.click(screen.getByRole('button')) @@ -178,7 +178,7 @@ Process in this order for multi-file testing: - **500+ lines**: Consider splitting before testing - **Many dependencies**: Extract logic into hooks first -> 📖 See `guides/workflow.md` for complete workflow details and todo list format. +> 📖 See `references/workflow.md` for complete workflow details and todo list format. ## Testing Strategy @@ -289,17 +289,18 @@ For each test file generated, aim for: - ✅ **>95%** branch coverage - ✅ **>95%** line coverage -> **Note**: For multi-file directories, process one file at a time with full coverage each. See `guides/workflow.md`. +> **Note**: For multi-file directories, process one file at a time with full coverage each. See `references/workflow.md`. ## Detailed Guides For more detailed information, refer to: -- `guides/workflow.md` - **Incremental testing workflow** (MUST READ for multi-file testing) -- `guides/mocking.md` - Mock patterns and best practices -- `guides/async-testing.md` - Async operations and API calls -- `guides/domain-components.md` - Workflow, Dataset, Configuration testing -- `guides/common-patterns.md` - Frequently used testing patterns +- `references/workflow.md` - **Incremental testing workflow** (MUST READ for multi-file testing) +- `references/mocking.md` - Mock patterns and best practices +- `references/async-testing.md` - Async operations and API calls +- `references/domain-components.md` - Workflow, Dataset, Configuration testing +- `references/common-patterns.md` - Frequently used testing patterns +- `references/checklist.md` - Test generation checklist and validation steps ## Authoritative References @@ -315,7 +316,7 @@ For more detailed information, refer to: ### Project Configuration -- `web/jest.config.ts` - Jest configuration -- `web/jest.setup.ts` - Test environment setup +- `web/vitest.config.ts` - Vitest configuration +- `web/vitest.setup.ts` - Test environment setup - `web/testing/analyze-component.js` - Component analysis tool -- `web/__mocks__/react-i18next.ts` - Shared i18n mock (auto-loaded by Jest, no explicit mock needed; override locally only for custom translations) +- Modules are not mocked automatically. Global mocks live in `web/vitest.setup.ts` (for example `react-i18next`, `next/image`); mock other modules like `ky` or `mime` locally in test files. diff --git a/.claude/skills/frontend-testing/templates/component-test.template.tsx b/.claude/skills/frontend-testing/assets/component-test.template.tsx similarity index 96% rename from .claude/skills/frontend-testing/templates/component-test.template.tsx rename to .claude/skills/frontend-testing/assets/component-test.template.tsx index f1ea71a3fd..92dd797c83 100644 --- a/.claude/skills/frontend-testing/templates/component-test.template.tsx +++ b/.claude/skills/frontend-testing/assets/component-test.template.tsx @@ -23,14 +23,14 @@ import userEvent from '@testing-library/user-event' // ============================================================================ // Mocks // ============================================================================ -// WHY: Mocks must be hoisted to top of file (Jest requirement). +// WHY: Mocks must be hoisted to top of file (Vitest requirement). // They run BEFORE imports, so keep them before component imports. // i18n (automatically mocked) -// WHY: Shared mock at web/__mocks__/react-i18next.ts is auto-loaded by Jest +// WHY: Global mock in web/vitest.setup.ts is auto-loaded by Vitest setup // No explicit mock needed - it returns translation keys as-is // Override only if custom translations are required: -// jest.mock('react-i18next', () => ({ +// vi.mock('react-i18next', () => ({ // useTranslation: () => ({ // t: (key: string) => { // const customTranslations: Record = { @@ -43,17 +43,17 @@ import userEvent from '@testing-library/user-event' // Router (if component uses useRouter, usePathname, useSearchParams) // WHY: Isolates tests from Next.js routing, enables testing navigation behavior -// const mockPush = jest.fn() -// jest.mock('next/navigation', () => ({ +// const mockPush = vi.fn() +// vi.mock('next/navigation', () => ({ // useRouter: () => ({ push: mockPush }), // usePathname: () => '/test-path', // })) // API services (if component fetches data) // WHY: Prevents real network calls, enables testing all states (loading/success/error) -// jest.mock('@/service/api') +// vi.mock('@/service/api') // import * as api from '@/service/api' -// const mockedApi = api as jest.Mocked +// const mockedApi = vi.mocked(api) // Shared mock state (for portal/dropdown components) // WHY: Portal components like PortalToFollowElem need shared state between @@ -98,7 +98,7 @@ describe('ComponentName', () => { // - Prevents mock call history from leaking between tests // - MUST be beforeEach (not afterEach) to reset BEFORE assertions like toHaveBeenCalledTimes beforeEach(() => { - jest.clearAllMocks() + vi.clearAllMocks() // Reset shared mock state if used (CRITICAL for portal/dropdown tests) // mockOpenState = false }) @@ -155,7 +155,7 @@ describe('ComponentName', () => { // - userEvent simulates real user behavior (focus, hover, then click) // - fireEvent is lower-level, doesn't trigger all browser events // const user = userEvent.setup() - // const handleClick = jest.fn() + // const handleClick = vi.fn() // render() // // await user.click(screen.getByRole('button')) @@ -165,7 +165,7 @@ describe('ComponentName', () => { it('should call onChange when value changes', async () => { // const user = userEvent.setup() - // const handleChange = jest.fn() + // const handleChange = vi.fn() // render() // // await user.type(screen.getByRole('textbox'), 'new value') diff --git a/.claude/skills/frontend-testing/templates/hook-test.template.ts b/.claude/skills/frontend-testing/assets/hook-test.template.ts similarity index 95% rename from .claude/skills/frontend-testing/templates/hook-test.template.ts rename to .claude/skills/frontend-testing/assets/hook-test.template.ts index 4fb7fd21ec..99161848a4 100644 --- a/.claude/skills/frontend-testing/templates/hook-test.template.ts +++ b/.claude/skills/frontend-testing/assets/hook-test.template.ts @@ -15,9 +15,9 @@ import { renderHook, act, waitFor } from '@testing-library/react' // ============================================================================ // API services (if hook fetches data) -// jest.mock('@/service/api') +// vi.mock('@/service/api') // import * as api from '@/service/api' -// const mockedApi = api as jest.Mocked +// const mockedApi = vi.mocked(api) // ============================================================================ // Test Helpers @@ -38,7 +38,7 @@ import { renderHook, act, waitFor } from '@testing-library/react' describe('useHookName', () => { beforeEach(() => { - jest.clearAllMocks() + vi.clearAllMocks() }) // -------------------------------------------------------------------------- @@ -145,7 +145,7 @@ describe('useHookName', () => { // -------------------------------------------------------------------------- describe('Side Effects', () => { it('should call callback when value changes', () => { - // const callback = jest.fn() + // const callback = vi.fn() // const { result } = renderHook(() => useHookName({ onChange: callback })) // // act(() => { @@ -156,9 +156,9 @@ describe('useHookName', () => { }) it('should cleanup on unmount', () => { - // const cleanup = jest.fn() - // jest.spyOn(window, 'addEventListener') - // jest.spyOn(window, 'removeEventListener') + // const cleanup = vi.fn() + // vi.spyOn(window, 'addEventListener') + // vi.spyOn(window, 'removeEventListener') // // const { unmount } = renderHook(() => useHookName()) // diff --git a/.claude/skills/frontend-testing/templates/utility-test.template.ts b/.claude/skills/frontend-testing/assets/utility-test.template.ts similarity index 100% rename from .claude/skills/frontend-testing/templates/utility-test.template.ts rename to .claude/skills/frontend-testing/assets/utility-test.template.ts diff --git a/.claude/skills/frontend-testing/guides/async-testing.md b/.claude/skills/frontend-testing/references/async-testing.md similarity index 92% rename from .claude/skills/frontend-testing/guides/async-testing.md rename to .claude/skills/frontend-testing/references/async-testing.md index f9912debbf..ae775a87a9 100644 --- a/.claude/skills/frontend-testing/guides/async-testing.md +++ b/.claude/skills/frontend-testing/references/async-testing.md @@ -49,7 +49,7 @@ import userEvent from '@testing-library/user-event' it('should submit form', async () => { const user = userEvent.setup() - const onSubmit = jest.fn() + const onSubmit = vi.fn() render(
) @@ -77,15 +77,15 @@ it('should submit form', async () => { ```typescript describe('Debounced Search', () => { beforeEach(() => { - jest.useFakeTimers() + vi.useFakeTimers() }) afterEach(() => { - jest.useRealTimers() + vi.useRealTimers() }) it('should debounce search input', async () => { - const onSearch = jest.fn() + const onSearch = vi.fn() render() // Type in the input @@ -95,7 +95,7 @@ describe('Debounced Search', () => { expect(onSearch).not.toHaveBeenCalled() // Advance timers - jest.advanceTimersByTime(300) + vi.advanceTimersByTime(300) // Now search is called expect(onSearch).toHaveBeenCalledWith('query') @@ -107,8 +107,8 @@ describe('Debounced Search', () => { ```typescript it('should retry on failure', async () => { - jest.useFakeTimers() - const fetchData = jest.fn() + vi.useFakeTimers() + const fetchData = vi.fn() .mockRejectedValueOnce(new Error('Network error')) .mockResolvedValueOnce({ data: 'success' }) @@ -120,7 +120,7 @@ it('should retry on failure', async () => { }) // Advance timer for retry - jest.advanceTimersByTime(1000) + vi.advanceTimersByTime(1000) // Second call succeeds await waitFor(() => { @@ -128,7 +128,7 @@ it('should retry on failure', async () => { expect(screen.getByText('success')).toBeInTheDocument() }) - jest.useRealTimers() + vi.useRealTimers() }) ``` @@ -136,19 +136,19 @@ it('should retry on failure', async () => { ```typescript // Run all pending timers -jest.runAllTimers() +vi.runAllTimers() // Run only pending timers (not new ones created during execution) -jest.runOnlyPendingTimers() +vi.runOnlyPendingTimers() // Advance by specific time -jest.advanceTimersByTime(1000) +vi.advanceTimersByTime(1000) // Get current fake time -jest.now() +Date.now() // Clear all timers -jest.clearAllTimers() +vi.clearAllTimers() ``` ## API Testing Patterns @@ -158,7 +158,7 @@ jest.clearAllTimers() ```typescript describe('DataFetcher', () => { beforeEach(() => { - jest.clearAllMocks() + vi.clearAllMocks() }) it('should show loading state', () => { @@ -241,7 +241,7 @@ it('should submit form and show success', async () => { ```typescript it('should fetch data on mount', async () => { - const fetchData = jest.fn().mockResolvedValue({ data: 'test' }) + const fetchData = vi.fn().mockResolvedValue({ data: 'test' }) render() @@ -255,7 +255,7 @@ it('should fetch data on mount', async () => { ```typescript it('should refetch when id changes', async () => { - const fetchData = jest.fn().mockResolvedValue({ data: 'test' }) + const fetchData = vi.fn().mockResolvedValue({ data: 'test' }) const { rerender } = render() @@ -276,8 +276,8 @@ it('should refetch when id changes', async () => { ```typescript it('should cleanup subscription on unmount', () => { - const subscribe = jest.fn() - const unsubscribe = jest.fn() + const subscribe = vi.fn() + const unsubscribe = vi.fn() subscribe.mockReturnValue(unsubscribe) const { unmount } = render() @@ -332,14 +332,14 @@ expect(description).toBeInTheDocument() ```typescript // Bad - fake timers don't work well with real Promises -jest.useFakeTimers() +vi.useFakeTimers() await waitFor(() => { expect(screen.getByText('Data')).toBeInTheDocument() }) // May timeout! // Good - use runAllTimers or advanceTimersByTime -jest.useFakeTimers() +vi.useFakeTimers() render() -jest.runAllTimers() +vi.runAllTimers() expect(screen.getByText('Data')).toBeInTheDocument() ``` diff --git a/.claude/skills/frontend-testing/CHECKLIST.md b/.claude/skills/frontend-testing/references/checklist.md similarity index 93% rename from .claude/skills/frontend-testing/CHECKLIST.md rename to .claude/skills/frontend-testing/references/checklist.md index b960067264..aad80b120e 100644 --- a/.claude/skills/frontend-testing/CHECKLIST.md +++ b/.claude/skills/frontend-testing/references/checklist.md @@ -74,9 +74,9 @@ Use this checklist when generating or reviewing tests for Dify frontend componen ### Mocks - [ ] **DO NOT mock base components** (`@/app/components/base/*`) -- [ ] `jest.clearAllMocks()` in `beforeEach` (not `afterEach`) +- [ ] `vi.clearAllMocks()` in `beforeEach` (not `afterEach`) - [ ] Shared mock state reset in `beforeEach` -- [ ] i18n uses shared mock (auto-loaded); only override locally for custom translations +- [ ] i18n uses global mock (auto-loaded in `web/vitest.setup.ts`); only override locally for custom translations - [ ] Router mocks match actual Next.js API - [ ] Mocks reflect actual component conditional behavior - [ ] Only mock: API services, complex context providers, third-party libs @@ -132,10 +132,10 @@ For the current file being tested: ```typescript // ❌ Mock doesn't match actual behavior -jest.mock('./Component', () => () =>
Mocked
) +vi.mock('./Component', () => () =>
Mocked
) // ✅ Mock matches actual conditional logic -jest.mock('./Component', () => ({ isOpen }: any) => +vi.mock('./Component', () => ({ isOpen }: any) => isOpen ?
Content
: null ) ``` @@ -145,7 +145,7 @@ jest.mock('./Component', () => ({ isOpen }: any) => ```typescript // ❌ Shared state not reset let mockState = false -jest.mock('./useHook', () => () => mockState) +vi.mock('./useHook', () => () => mockState) // ✅ Reset in beforeEach beforeEach(() => { @@ -192,7 +192,7 @@ pnpm test -- path/to/file.spec.tsx pnpm test -- --coverage path/to/file.spec.tsx # Watch mode -pnpm test -- --watch path/to/file.spec.tsx +pnpm test:watch -- path/to/file.spec.tsx # Update snapshots (use sparingly) pnpm test -- -u path/to/file.spec.tsx diff --git a/.claude/skills/frontend-testing/guides/common-patterns.md b/.claude/skills/frontend-testing/references/common-patterns.md similarity index 94% rename from .claude/skills/frontend-testing/guides/common-patterns.md rename to .claude/skills/frontend-testing/references/common-patterns.md index 84a6045b04..6eded5ceba 100644 --- a/.claude/skills/frontend-testing/guides/common-patterns.md +++ b/.claude/skills/frontend-testing/references/common-patterns.md @@ -126,7 +126,7 @@ describe('Counter', () => { describe('ControlledInput', () => { it('should call onChange with new value', async () => { const user = userEvent.setup() - const handleChange = jest.fn() + const handleChange = vi.fn() render() @@ -136,7 +136,7 @@ describe('ControlledInput', () => { }) it('should display controlled value', () => { - render() + render() expect(screen.getByRole('textbox')).toHaveValue('controlled') }) @@ -195,7 +195,7 @@ describe('ItemList', () => { it('should handle item selection', async () => { const user = userEvent.setup() - const onSelect = jest.fn() + const onSelect = vi.fn() render() @@ -217,20 +217,20 @@ describe('ItemList', () => { ```typescript describe('Modal', () => { it('should not render when closed', () => { - render() + render() expect(screen.queryByRole('dialog')).not.toBeInTheDocument() }) it('should render when open', () => { - render() + render() expect(screen.getByRole('dialog')).toBeInTheDocument() }) it('should call onClose when clicking overlay', async () => { const user = userEvent.setup() - const handleClose = jest.fn() + const handleClose = vi.fn() render() @@ -241,7 +241,7 @@ describe('Modal', () => { it('should call onClose when pressing Escape', async () => { const user = userEvent.setup() - const handleClose = jest.fn() + const handleClose = vi.fn() render() @@ -254,7 +254,7 @@ describe('Modal', () => { const user = userEvent.setup() render( - + @@ -279,7 +279,7 @@ describe('Modal', () => { describe('LoginForm', () => { it('should submit valid form', async () => { const user = userEvent.setup() - const onSubmit = jest.fn() + const onSubmit = vi.fn() render() @@ -296,7 +296,7 @@ describe('LoginForm', () => { it('should show validation errors', async () => { const user = userEvent.setup() - render() + render() // Submit empty form await user.click(screen.getByRole('button', { name: /sign in/i })) @@ -308,7 +308,7 @@ describe('LoginForm', () => { it('should validate email format', async () => { const user = userEvent.setup() - render() + render() await user.type(screen.getByLabelText(/email/i), 'invalid-email') await user.click(screen.getByRole('button', { name: /sign in/i })) @@ -318,7 +318,7 @@ describe('LoginForm', () => { it('should disable submit button while submitting', async () => { const user = userEvent.setup() - const onSubmit = jest.fn(() => new Promise(resolve => setTimeout(resolve, 100))) + const onSubmit = vi.fn(() => new Promise(resolve => setTimeout(resolve, 100))) render() @@ -407,7 +407,7 @@ it('test 1', () => { // Good - cleanup is automatic with RTL, but reset mocks beforeEach(() => { - jest.clearAllMocks() + vi.clearAllMocks() }) ``` diff --git a/.claude/skills/frontend-testing/guides/domain-components.md b/.claude/skills/frontend-testing/references/domain-components.md similarity index 95% rename from .claude/skills/frontend-testing/guides/domain-components.md rename to .claude/skills/frontend-testing/references/domain-components.md index ed2cc6eb8a..5535d28f3d 100644 --- a/.claude/skills/frontend-testing/guides/domain-components.md +++ b/.claude/skills/frontend-testing/references/domain-components.md @@ -23,7 +23,7 @@ import NodeConfigPanel from './node-config-panel' import { createMockNode, createMockWorkflowContext } from '@/__mocks__/workflow' // Mock workflow context -jest.mock('@/app/components/workflow/hooks', () => ({ +vi.mock('@/app/components/workflow/hooks', () => ({ useWorkflowStore: () => mockWorkflowStore, useNodesInteractions: () => mockNodesInteractions, })) @@ -31,21 +31,21 @@ jest.mock('@/app/components/workflow/hooks', () => ({ let mockWorkflowStore = { nodes: [], edges: [], - updateNode: jest.fn(), + updateNode: vi.fn(), } let mockNodesInteractions = { - handleNodeSelect: jest.fn(), - handleNodeDelete: jest.fn(), + handleNodeSelect: vi.fn(), + handleNodeDelete: vi.fn(), } describe('NodeConfigPanel', () => { beforeEach(() => { - jest.clearAllMocks() + vi.clearAllMocks() mockWorkflowStore = { nodes: [], edges: [], - updateNode: jest.fn(), + updateNode: vi.fn(), } }) @@ -161,23 +161,23 @@ import { render, screen, fireEvent, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' import DocumentUploader from './document-uploader' -jest.mock('@/service/datasets', () => ({ - uploadDocument: jest.fn(), - parseDocument: jest.fn(), +vi.mock('@/service/datasets', () => ({ + uploadDocument: vi.fn(), + parseDocument: vi.fn(), })) import * as datasetService from '@/service/datasets' -const mockedService = datasetService as jest.Mocked +const mockedService = vi.mocked(datasetService) describe('DocumentUploader', () => { beforeEach(() => { - jest.clearAllMocks() + vi.clearAllMocks() }) describe('File Upload', () => { it('should accept valid file types', async () => { const user = userEvent.setup() - const onUpload = jest.fn() + const onUpload = vi.fn() mockedService.uploadDocument.mockResolvedValue({ id: 'doc-1' }) render() @@ -326,14 +326,14 @@ describe('DocumentList', () => { describe('Search & Filtering', () => { it('should filter by search query', async () => { const user = userEvent.setup() - jest.useFakeTimers() + vi.useFakeTimers() render() await user.type(screen.getByPlaceholderText(/search/i), 'test query') // Debounce - jest.advanceTimersByTime(300) + vi.advanceTimersByTime(300) await waitFor(() => { expect(mockedService.getDocuments).toHaveBeenCalledWith( @@ -342,7 +342,7 @@ describe('DocumentList', () => { ) }) - jest.useRealTimers() + vi.useRealTimers() }) }) }) @@ -367,13 +367,13 @@ import { render, screen, fireEvent, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' import AppConfigForm from './app-config-form' -jest.mock('@/service/apps', () => ({ - updateAppConfig: jest.fn(), - getAppConfig: jest.fn(), +vi.mock('@/service/apps', () => ({ + updateAppConfig: vi.fn(), + getAppConfig: vi.fn(), })) import * as appService from '@/service/apps' -const mockedService = appService as jest.Mocked +const mockedService = vi.mocked(appService) describe('AppConfigForm', () => { const defaultConfig = { @@ -384,7 +384,7 @@ describe('AppConfigForm', () => { } beforeEach(() => { - jest.clearAllMocks() + vi.clearAllMocks() mockedService.getAppConfig.mockResolvedValue(defaultConfig) }) diff --git a/.claude/skills/frontend-testing/guides/mocking.md b/.claude/skills/frontend-testing/references/mocking.md similarity index 88% rename from .claude/skills/frontend-testing/guides/mocking.md rename to .claude/skills/frontend-testing/references/mocking.md index bf0bd79690..51920ebc64 100644 --- a/.claude/skills/frontend-testing/guides/mocking.md +++ b/.claude/skills/frontend-testing/references/mocking.md @@ -19,8 +19,8 @@ ```typescript // ❌ WRONG: Don't mock base components -jest.mock('@/app/components/base/loading', () => () =>
Loading
) -jest.mock('@/app/components/base/button', () => ({ children }: any) => ) +vi.mock('@/app/components/base/loading', () => () =>
Loading
) +vi.mock('@/app/components/base/button', () => ({ children }: any) => ) // ✅ CORRECT: Import and use real base components import Loading from '@/app/components/base/loading' @@ -41,20 +41,23 @@ Only mock these categories: | Location | Purpose | |----------|---------| -| `web/__mocks__/` | Reusable mocks shared across multiple test files | -| Test file | Test-specific mocks, inline with `jest.mock()` | +| `web/vitest.setup.ts` | Global mocks shared by all tests (for example `react-i18next`, `next/image`) | +| `web/__mocks__/` | Reusable mock factories shared across multiple test files | +| Test file | Test-specific mocks, inline with `vi.mock()` | + +Modules are not mocked automatically. Use `vi.mock` in test files, or add global mocks in `web/vitest.setup.ts`. ## Essential Mocks -### 1. i18n (Auto-loaded via Shared Mock) +### 1. i18n (Auto-loaded via Global Mock) -A shared mock is available at `web/__mocks__/react-i18next.ts` and is auto-loaded by Jest. +A global mock is defined in `web/vitest.setup.ts` and is auto-loaded by Vitest setup. **No explicit mock needed** for most tests - it returns translation keys as-is. For tests requiring custom translations, override the mock: ```typescript -jest.mock('react-i18next', () => ({ +vi.mock('react-i18next', () => ({ useTranslation: () => ({ t: (key: string) => { const translations: Record = { @@ -69,15 +72,15 @@ jest.mock('react-i18next', () => ({ ### 2. Next.js Router ```typescript -const mockPush = jest.fn() -const mockReplace = jest.fn() +const mockPush = vi.fn() +const mockReplace = vi.fn() -jest.mock('next/navigation', () => ({ +vi.mock('next/navigation', () => ({ useRouter: () => ({ push: mockPush, replace: mockReplace, - back: jest.fn(), - prefetch: jest.fn(), + back: vi.fn(), + prefetch: vi.fn(), }), usePathname: () => '/current-path', useSearchParams: () => new URLSearchParams('?key=value'), @@ -85,7 +88,7 @@ jest.mock('next/navigation', () => ({ describe('Component', () => { beforeEach(() => { - jest.clearAllMocks() + vi.clearAllMocks() }) it('should navigate on click', () => { @@ -102,7 +105,7 @@ describe('Component', () => { // ⚠️ Important: Use shared state for components that depend on each other let mockPortalOpenState = false -jest.mock('@/app/components/base/portal-to-follow-elem', () => ({ +vi.mock('@/app/components/base/portal-to-follow-elem', () => ({ PortalToFollowElem: ({ children, open, ...props }: any) => { mockPortalOpenState = open || false // Update shared state return
{children}
@@ -119,7 +122,7 @@ jest.mock('@/app/components/base/portal-to-follow-elem', () => ({ describe('Component', () => { beforeEach(() => { - jest.clearAllMocks() + vi.clearAllMocks() mockPortalOpenState = false // ✅ Reset shared state }) }) @@ -130,13 +133,13 @@ describe('Component', () => { ```typescript import * as api from '@/service/api' -jest.mock('@/service/api') +vi.mock('@/service/api') -const mockedApi = api as jest.Mocked +const mockedApi = vi.mocked(api) describe('Component', () => { beforeEach(() => { - jest.clearAllMocks() + vi.clearAllMocks() // Setup default mock implementation mockedApi.fetchData.mockResolvedValue({ data: [] }) @@ -243,13 +246,13 @@ describe('Component with Context', () => { ```typescript // SWR -jest.mock('swr', () => ({ +vi.mock('swr', () => ({ __esModule: true, - default: jest.fn(), + default: vi.fn(), })) import useSWR from 'swr' -const mockedUseSWR = useSWR as jest.Mock +const mockedUseSWR = vi.mocked(useSWR) describe('Component with SWR', () => { it('should show loading state', () => { diff --git a/.claude/skills/frontend-testing/guides/workflow.md b/.claude/skills/frontend-testing/references/workflow.md similarity index 100% rename from .claude/skills/frontend-testing/guides/workflow.md rename to .claude/skills/frontend-testing/references/workflow.md diff --git a/.codex/skills b/.codex/skills new file mode 120000 index 0000000000..454b8427cd --- /dev/null +++ b/.codex/skills @@ -0,0 +1 @@ +../.claude/skills \ No newline at end of file diff --git a/.devcontainer/post_create_command.sh b/.devcontainer/post_create_command.sh index a26fd076ed..ce9135476f 100755 --- a/.devcontainer/post_create_command.sh +++ b/.devcontainer/post_create_command.sh @@ -6,7 +6,7 @@ cd web && pnpm install pipx install uv echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc -echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor\"" >> ~/.bashrc +echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention\"" >> ~/.bashrc echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev\"" >> ~/.bashrc echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index d6f326d4dc..4bc4f085c2 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -6,6 +6,12 @@ * @crazywoola @laipz8200 @Yeuoly +# CODEOWNERS file +.github/CODEOWNERS @laipz8200 @crazywoola + +# Docs +docs/ @crazywoola + # Backend (default owner, more specific rules below will override) api/ @QuantumGhost @@ -116,11 +122,17 @@ api/controllers/console/feature.py @GarfieldDai @GareArc api/controllers/web/feature.py @GarfieldDai @GareArc # Backend - Database Migrations -api/migrations/ @snakevash @laipz8200 +api/migrations/ @snakevash @laipz8200 @MRZHUH + +# Backend - Vector DB Middleware +api/configs/middleware/vdb/* @JohnJyong # Frontend web/ @iamjoel +# Frontend - Web Tests +.github/workflows/web-tests.yml @iamjoel + # Frontend - App - Orchestration web/app/components/workflow/ @iamjoel @zxhlyh web/app/components/workflow-app/ @iamjoel @zxhlyh @@ -192,6 +204,7 @@ web/app/components/plugins/marketplace/ @iamjoel @Yessenia-d web/app/signin/ @douxc @iamjoel web/app/signup/ @douxc @iamjoel web/app/reset-password/ @douxc @iamjoel + web/app/install/ @douxc @iamjoel web/app/init/ @douxc @iamjoel web/app/forgot-password/ @douxc @iamjoel @@ -232,3 +245,6 @@ web/app/education-apply/ @iamjoel @zxhlyh # Frontend - Workspace web/app/components/header/account-dropdown/workplace-selector/ @iamjoel @zxhlyh + +# Docker +docker/* @laipz8200 diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml index d7a58ce93d..bafac7bd13 100644 --- a/.github/workflows/autofix.yml +++ b/.github/workflows/autofix.yml @@ -66,7 +66,7 @@ jobs: # mdformat breaks YAML front matter in markdown files. Add --exclude for directories containing YAML front matter. - name: mdformat run: | - uvx --python 3.13 mdformat . --exclude ".claude/skills/**" + uvx --python 3.13 mdformat . --exclude ".claude/skills/**/SKILL.md" - name: Install pnpm uses: pnpm/action-setup@v4 @@ -79,7 +79,7 @@ jobs: with: node-version: 22 cache: pnpm - cache-dependency-path: ./web/package.json + cache-dependency-path: ./web/pnpm-lock.yaml - name: Web dependencies working-directory: ./web diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index 5a8a34be79..2fb8121f74 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -90,7 +90,7 @@ jobs: with: node-version: 22 cache: pnpm - cache-dependency-path: ./web/package.json + cache-dependency-path: ./web/pnpm-lock.yaml - name: Web dependencies if: steps.changed-files.outputs.any_changed == 'true' diff --git a/.github/workflows/translate-i18n-base-on-english.yml b/.github/workflows/translate-i18n-base-on-english.yml index fe8e2ebc2b..8bb82d5d44 100644 --- a/.github/workflows/translate-i18n-base-on-english.yml +++ b/.github/workflows/translate-i18n-base-on-english.yml @@ -55,7 +55,7 @@ jobs: with: node-version: 'lts/*' cache: pnpm - cache-dependency-path: ./web/package.json + cache-dependency-path: ./web/pnpm-lock.yaml - name: Install dependencies if: env.FILES_CHANGED == 'true' diff --git a/.github/workflows/web-tests.yml b/.github/workflows/web-tests.yml index 3313e58614..8eba0f084b 100644 --- a/.github/workflows/web-tests.yml +++ b/.github/workflows/web-tests.yml @@ -13,6 +13,7 @@ jobs: runs-on: ubuntu-latest defaults: run: + shell: bash working-directory: ./web steps: @@ -21,14 +22,7 @@ jobs: with: persist-credentials: false - - name: Check changed files - id: changed-files - uses: tj-actions/changed-files@v46 - with: - files: web/** - - name: Install pnpm - if: steps.changed-files.outputs.any_changed == 'true' uses: pnpm/action-setup@v4 with: package_json_file: web/package.json @@ -36,23 +30,342 @@ jobs: - name: Setup Node.js uses: actions/setup-node@v4 - if: steps.changed-files.outputs.any_changed == 'true' with: node-version: 22 cache: pnpm - cache-dependency-path: ./web/package.json + cache-dependency-path: ./web/pnpm-lock.yaml - name: Install dependencies - if: steps.changed-files.outputs.any_changed == 'true' - working-directory: ./web run: pnpm install --frozen-lockfile - name: Check i18n types synchronization - if: steps.changed-files.outputs.any_changed == 'true' - working-directory: ./web run: pnpm run check:i18n-types - name: Run tests - if: steps.changed-files.outputs.any_changed == 'true' - working-directory: ./web - run: pnpm test + run: pnpm test --coverage + + - name: Coverage Summary + if: always() + id: coverage-summary + run: | + set -eo pipefail + + COVERAGE_FILE="coverage/coverage-final.json" + COVERAGE_SUMMARY_FILE="coverage/coverage-summary.json" + + if [ ! -f "$COVERAGE_FILE" ] && [ ! -f "$COVERAGE_SUMMARY_FILE" ]; then + echo "has_coverage=false" >> "$GITHUB_OUTPUT" + echo "### 🚨 Test Coverage Report :test_tube:" >> "$GITHUB_STEP_SUMMARY" + echo "Coverage data not found. Ensure Vitest runs with coverage enabled." >> "$GITHUB_STEP_SUMMARY" + exit 0 + fi + + echo "has_coverage=true" >> "$GITHUB_OUTPUT" + + node <<'NODE' >> "$GITHUB_STEP_SUMMARY" + const fs = require('fs'); + const path = require('path'); + let libCoverage = null; + + try { + libCoverage = require('istanbul-lib-coverage'); + } catch (error) { + libCoverage = null; + } + + const summaryPath = path.join('coverage', 'coverage-summary.json'); + const finalPath = path.join('coverage', 'coverage-final.json'); + + const hasSummary = fs.existsSync(summaryPath); + const hasFinal = fs.existsSync(finalPath); + + if (!hasSummary && !hasFinal) { + console.log('### Test Coverage Summary :test_tube:'); + console.log(''); + console.log('No coverage data found.'); + process.exit(0); + } + + const summary = hasSummary + ? JSON.parse(fs.readFileSync(summaryPath, 'utf8')) + : null; + const coverage = hasFinal + ? JSON.parse(fs.readFileSync(finalPath, 'utf8')) + : null; + + const getLineCoverageFromStatements = (statementMap, statementHits) => { + const lineHits = {}; + + if (!statementMap || !statementHits) { + return lineHits; + } + + Object.entries(statementMap).forEach(([key, statement]) => { + const line = statement?.start?.line; + if (!line) { + return; + } + const hits = statementHits[key] ?? 0; + const previous = lineHits[line]; + lineHits[line] = previous === undefined ? hits : Math.max(previous, hits); + }); + + return lineHits; + }; + + const getFileCoverage = (entry) => ( + libCoverage ? libCoverage.createFileCoverage(entry) : null + ); + + const getLineHits = (entry, fileCoverage) => { + const lineHits = entry.l ?? {}; + if (Object.keys(lineHits).length > 0) { + return lineHits; + } + if (fileCoverage) { + return fileCoverage.getLineCoverage(); + } + return getLineCoverageFromStatements(entry.statementMap ?? {}, entry.s ?? {}); + }; + + const getUncoveredLines = (entry, fileCoverage, lineHits) => { + if (lineHits && Object.keys(lineHits).length > 0) { + return Object.entries(lineHits) + .filter(([, count]) => count === 0) + .map(([line]) => Number(line)) + .sort((a, b) => a - b); + } + if (fileCoverage) { + return fileCoverage.getUncoveredLines(); + } + return []; + }; + + const totals = { + lines: { covered: 0, total: 0 }, + statements: { covered: 0, total: 0 }, + branches: { covered: 0, total: 0 }, + functions: { covered: 0, total: 0 }, + }; + const fileSummaries = []; + + if (summary) { + const totalEntry = summary.total ?? {}; + ['lines', 'statements', 'branches', 'functions'].forEach((key) => { + if (totalEntry[key]) { + totals[key].covered = totalEntry[key].covered ?? 0; + totals[key].total = totalEntry[key].total ?? 0; + } + }); + + Object.entries(summary) + .filter(([file]) => file !== 'total') + .forEach(([file, data]) => { + fileSummaries.push({ + file, + pct: data.lines?.pct ?? data.statements?.pct ?? 0, + lines: { + covered: data.lines?.covered ?? 0, + total: data.lines?.total ?? 0, + }, + }); + }); + } else if (coverage) { + Object.entries(coverage).forEach(([file, entry]) => { + const fileCoverage = getFileCoverage(entry); + const lineHits = getLineHits(entry, fileCoverage); + const statementHits = entry.s ?? {}; + const branchHits = entry.b ?? {}; + const functionHits = entry.f ?? {}; + + const lineTotal = Object.keys(lineHits).length; + const lineCovered = Object.values(lineHits).filter((n) => n > 0).length; + + const statementTotal = Object.keys(statementHits).length; + const statementCovered = Object.values(statementHits).filter((n) => n > 0).length; + + const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0); + const branchCovered = Object.values(branchHits).reduce( + (acc, branches) => acc + branches.filter((n) => n > 0).length, + 0, + ); + + const functionTotal = Object.keys(functionHits).length; + const functionCovered = Object.values(functionHits).filter((n) => n > 0).length; + + totals.lines.total += lineTotal; + totals.lines.covered += lineCovered; + totals.statements.total += statementTotal; + totals.statements.covered += statementCovered; + totals.branches.total += branchTotal; + totals.branches.covered += branchCovered; + totals.functions.total += functionTotal; + totals.functions.covered += functionCovered; + + const pct = (covered, tot) => (tot > 0 ? (covered / tot) * 100 : 0); + + fileSummaries.push({ + file, + pct: pct(lineCovered || statementCovered, lineTotal || statementTotal), + lines: { + covered: lineCovered || statementCovered, + total: lineTotal || statementTotal, + }, + }); + }); + } + + const pct = (covered, tot) => (tot > 0 ? ((covered / tot) * 100).toFixed(2) : '0.00'); + + console.log('### Test Coverage Summary :test_tube:'); + console.log(''); + console.log('| Metric | Coverage | Covered / Total |'); + console.log('|--------|----------|-----------------|'); + console.log(`| Lines | ${pct(totals.lines.covered, totals.lines.total)}% | ${totals.lines.covered} / ${totals.lines.total} |`); + console.log(`| Statements | ${pct(totals.statements.covered, totals.statements.total)}% | ${totals.statements.covered} / ${totals.statements.total} |`); + console.log(`| Branches | ${pct(totals.branches.covered, totals.branches.total)}% | ${totals.branches.covered} / ${totals.branches.total} |`); + console.log(`| Functions | ${pct(totals.functions.covered, totals.functions.total)}% | ${totals.functions.covered} / ${totals.functions.total} |`); + + console.log(''); + console.log('
File coverage (lowest lines first)'); + console.log(''); + console.log('```'); + fileSummaries + .sort((a, b) => (a.pct - b.pct) || (b.lines.total - a.lines.total)) + .slice(0, 25) + .forEach(({ file, pct, lines }) => { + console.log(`${pct.toFixed(2)}%\t${lines.covered}/${lines.total}\t${file}`); + }); + console.log('```'); + console.log('
'); + + if (coverage) { + const pctValue = (covered, tot) => { + if (tot === 0) { + return '0'; + } + return ((covered / tot) * 100) + .toFixed(2) + .replace(/\.?0+$/, ''); + }; + + const formatLineRanges = (lines) => { + if (lines.length === 0) { + return ''; + } + const ranges = []; + let start = lines[0]; + let end = lines[0]; + + for (let i = 1; i < lines.length; i += 1) { + const current = lines[i]; + if (current === end + 1) { + end = current; + continue; + } + ranges.push(start === end ? `${start}` : `${start}-${end}`); + start = current; + end = current; + } + ranges.push(start === end ? `${start}` : `${start}-${end}`); + return ranges.join(','); + }; + + const tableTotals = { + statements: { covered: 0, total: 0 }, + branches: { covered: 0, total: 0 }, + functions: { covered: 0, total: 0 }, + lines: { covered: 0, total: 0 }, + }; + const tableRows = Object.entries(coverage) + .map(([file, entry]) => { + const fileCoverage = getFileCoverage(entry); + const lineHits = getLineHits(entry, fileCoverage); + const statementHits = entry.s ?? {}; + const branchHits = entry.b ?? {}; + const functionHits = entry.f ?? {}; + + const lineTotal = Object.keys(lineHits).length; + const lineCovered = Object.values(lineHits).filter((n) => n > 0).length; + const statementTotal = Object.keys(statementHits).length; + const statementCovered = Object.values(statementHits).filter((n) => n > 0).length; + const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0); + const branchCovered = Object.values(branchHits).reduce( + (acc, branches) => acc + branches.filter((n) => n > 0).length, + 0, + ); + const functionTotal = Object.keys(functionHits).length; + const functionCovered = Object.values(functionHits).filter((n) => n > 0).length; + + tableTotals.lines.total += lineTotal; + tableTotals.lines.covered += lineCovered; + tableTotals.statements.total += statementTotal; + tableTotals.statements.covered += statementCovered; + tableTotals.branches.total += branchTotal; + tableTotals.branches.covered += branchCovered; + tableTotals.functions.total += functionTotal; + tableTotals.functions.covered += functionCovered; + + const uncoveredLines = getUncoveredLines(entry, fileCoverage, lineHits); + + const filePath = entry.path ?? file; + const relativePath = path.isAbsolute(filePath) + ? path.relative(process.cwd(), filePath) + : filePath; + + return { + file: relativePath || file, + statements: pctValue(statementCovered, statementTotal), + branches: pctValue(branchCovered, branchTotal), + functions: pctValue(functionCovered, functionTotal), + lines: pctValue(lineCovered, lineTotal), + uncovered: formatLineRanges(uncoveredLines), + }; + }) + .sort((a, b) => a.file.localeCompare(b.file)); + + const columns = [ + { key: 'file', header: 'File', align: 'left' }, + { key: 'statements', header: '% Stmts', align: 'right' }, + { key: 'branches', header: '% Branch', align: 'right' }, + { key: 'functions', header: '% Funcs', align: 'right' }, + { key: 'lines', header: '% Lines', align: 'right' }, + { key: 'uncovered', header: 'Uncovered Line #s', align: 'left' }, + ]; + + const allFilesRow = { + file: 'All files', + statements: pctValue(tableTotals.statements.covered, tableTotals.statements.total), + branches: pctValue(tableTotals.branches.covered, tableTotals.branches.total), + functions: pctValue(tableTotals.functions.covered, tableTotals.functions.total), + lines: pctValue(tableTotals.lines.covered, tableTotals.lines.total), + uncovered: '', + }; + + const rowsForOutput = [allFilesRow, ...tableRows]; + const formatRow = (row) => `| ${columns + .map(({ key }) => String(row[key] ?? '')) + .join(' | ')} |`; + const headerRow = `| ${columns.map(({ header }) => header).join(' | ')} |`; + const dividerRow = `| ${columns + .map(({ align }) => (align === 'right' ? '---:' : ':---')) + .join(' | ')} |`; + + console.log(''); + console.log('
Vitest coverage table'); + console.log(''); + console.log(headerRow); + console.log(dividerRow); + rowsForOutput.forEach((row) => console.log(formatRow(row))); + console.log('
'); + } + NODE + + - name: Upload Coverage Artifact + if: steps.coverage-summary.outputs.has_coverage == 'true' + uses: actions/upload-artifact@v4 + with: + name: web-coverage-report + path: web/coverage + retention-days: 30 + if-no-files-found: error diff --git a/.vscode/launch.json.template b/.vscode/launch.json.template index cb934d01b5..bdded1e73e 100644 --- a/.vscode/launch.json.template +++ b/.vscode/launch.json.template @@ -37,7 +37,7 @@ "-c", "1", "-Q", - "dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor", + "dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention", "--loglevel", "INFO" ], diff --git a/api/.env.example b/api/.env.example index 43fe76bb11..b87d9c7b02 100644 --- a/api/.env.example +++ b/api/.env.example @@ -690,3 +690,8 @@ ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE=5 ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR=20 # Maximum number of concurrent annotation import tasks per tenant ANNOTATION_IMPORT_MAX_CONCURRENT=5 + +# Sandbox expired records clean configuration +SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21 +SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000 +SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30 diff --git a/api/README.md b/api/README.md index 2dab2ec6e6..794b05d3af 100644 --- a/api/README.md +++ b/api/README.md @@ -84,7 +84,7 @@ 1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service. ```bash -uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor +uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention ``` Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service: diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index e16ca52f46..43dddbd011 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -218,7 +218,7 @@ class PluginConfig(BaseSettings): PLUGIN_DAEMON_TIMEOUT: PositiveFloat | None = Field( description="Timeout in seconds for requests to the plugin daemon (set to None to disable)", - default=300.0, + default=600.0, ) INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key") @@ -1270,6 +1270,21 @@ class TenantIsolatedTaskQueueConfig(BaseSettings): ) +class SandboxExpiredRecordsCleanConfig(BaseSettings): + SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD: NonNegativeInt = Field( + description="Graceful period in days for sandbox records clean after subscription expiration", + default=21, + ) + SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE: PositiveInt = Field( + description="Maximum number of records to process in each batch", + default=1000, + ) + SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS: PositiveInt = Field( + description="Retention days for sandbox expired workflow_run records and message records", + default=30, + ) + + class FeatureConfig( # place the configs in alphabet order AppExecutionConfig, @@ -1295,6 +1310,7 @@ class FeatureConfig( PositionConfig, RagEtlConfig, RepositoryConfig, + SandboxExpiredRecordsCleanConfig, SecurityConfig, TenantIsolatedTaskQueueConfig, ToolConfig, diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index ea21c4480d..8ceb896d4f 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -146,7 +146,7 @@ class DatasetUpdatePayload(BaseModel): embedding_model: str | None = None embedding_model_provider: str | None = None retrieval_model: dict[str, Any] | None = None - partial_member_list: list[str] | None = None + partial_member_list: list[dict[str, str]] | None = None external_retrieval_model: dict[str, Any] | None = None external_knowledge_id: str | None = None external_knowledge_api_id: str | None = None diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index 5901eca915..a6e5b2822a 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -40,7 +40,7 @@ from .. import console_ns logger = logging.getLogger(__name__) -class CompletionMessagePayload(BaseModel): +class CompletionMessageExplorePayload(BaseModel): inputs: dict[str, Any] query: str = "" files: list[dict[str, Any]] | None = None @@ -71,7 +71,7 @@ class ChatMessagePayload(BaseModel): raise ValueError("must be a valid UUID") from exc -register_schema_models(console_ns, CompletionMessagePayload, ChatMessagePayload) +register_schema_models(console_ns, CompletionMessageExplorePayload, ChatMessagePayload) # define completion api for user @@ -80,13 +80,13 @@ register_schema_models(console_ns, CompletionMessagePayload, ChatMessagePayload) endpoint="installed_app_completion", ) class CompletionApi(InstalledAppResource): - @console_ns.expect(console_ns.models[CompletionMessagePayload.__name__]) + @console_ns.expect(console_ns.models[CompletionMessageExplorePayload.__name__]) def post(self, installed_app): app_model = installed_app.app if app_model.mode != AppMode.COMPLETION: raise NotCompletionAppError() - payload = CompletionMessagePayload.model_validate(console_ns.payload or {}) + payload = CompletionMessageExplorePayload.model_validate(console_ns.payload or {}) args = payload.model_dump(exclude_none=True) streaming = payload.response_mode == "streaming" diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index 92da591ab4..51995b8b8a 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -1,5 +1,4 @@ from typing import Any -from uuid import UUID from flask import request from flask_restx import marshal_with @@ -13,6 +12,7 @@ from controllers.console.explore.wraps import InstalledAppResource from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields +from libs.helper import UUIDStrOrEmpty from libs.login import current_user from models import Account from models.model import AppMode @@ -24,7 +24,7 @@ from .. import console_ns class ConversationListQuery(BaseModel): - last_id: UUID | None = None + last_id: UUIDStrOrEmpty | None = None limit: int = Field(default=20, ge=1, le=100) pinned: bool | None = None diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index 3c95779475..e42db10ba6 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -2,7 +2,8 @@ import logging from typing import Any from flask import request -from flask_restx import Resource, inputs, marshal_with, reqparse +from flask_restx import Resource, marshal_with +from pydantic import BaseModel from sqlalchemy import and_, select from werkzeug.exceptions import BadRequest, Forbidden, NotFound @@ -18,6 +19,15 @@ from services.account_service import TenantService from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService + +class InstalledAppCreatePayload(BaseModel): + app_id: str + + +class InstalledAppUpdatePayload(BaseModel): + is_pinned: bool | None = None + + logger = logging.getLogger(__name__) @@ -105,26 +115,25 @@ class InstalledAppsListApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("apps") def post(self): - parser = reqparse.RequestParser().add_argument("app_id", type=str, required=True, help="Invalid app_id") - args = parser.parse_args() + payload = InstalledAppCreatePayload.model_validate(console_ns.payload or {}) - recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == args["app_id"]).first() + recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == payload.app_id).first() if recommended_app is None: - raise NotFound("App not found") + raise NotFound("Recommended app not found") _, current_tenant_id = current_account_with_tenant() - app = db.session.query(App).where(App.id == args["app_id"]).first() + app = db.session.query(App).where(App.id == payload.app_id).first() if app is None: - raise NotFound("App not found") + raise NotFound("App entity not found") if not app.is_public: raise Forbidden("You can't install a non-public app") installed_app = ( db.session.query(InstalledApp) - .where(and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id)) + .where(and_(InstalledApp.app_id == payload.app_id, InstalledApp.tenant_id == current_tenant_id)) .first() ) @@ -133,7 +142,7 @@ class InstalledAppsListApi(Resource): recommended_app.install_count += 1 new_installed_app = InstalledApp( - app_id=args["app_id"], + app_id=payload.app_id, tenant_id=current_tenant_id, app_owner_tenant_id=app.tenant_id, is_pinned=False, @@ -163,12 +172,11 @@ class InstalledAppApi(InstalledAppResource): return {"result": "success", "message": "App uninstalled successfully"}, 204 def patch(self, installed_app): - parser = reqparse.RequestParser().add_argument("is_pinned", type=inputs.boolean) - args = parser.parse_args() + payload = InstalledAppUpdatePayload.model_validate(console_ns.payload or {}) commit_args = False - if "is_pinned" in args: - installed_app.is_pinned = args["is_pinned"] + if payload.is_pinned is not None: + installed_app.is_pinned = payload.is_pinned commit_args = True if commit_args: diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index 17cfc3ff4b..e9fbb515e4 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -1,31 +1,40 @@ +from typing import Literal + from flask import request -from flask_restx import Resource, marshal_with, reqparse +from flask_restx import Resource, marshal_with +from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from fields.tag_fields import dataset_tag_fields from libs.login import current_account_with_tenant, login_required -from models.model import Tag from services.tag_service import TagService -def _validate_name(name): - if not name or len(name) < 1 or len(name) > 50: - raise ValueError("Name must be between 1 to 50 characters.") - return name +class TagBasePayload(BaseModel): + name: str = Field(description="Tag name", min_length=1, max_length=50) + type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type") -parser_tags = ( - reqparse.RequestParser() - .add_argument( - "name", - nullable=False, - required=True, - help="Name must be between 1 to 50 characters.", - type=_validate_name, - ) - .add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.") +class TagBindingPayload(BaseModel): + tag_ids: list[str] = Field(description="Tag IDs to bind") + target_id: str = Field(description="Target ID to bind tags to") + type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type") + + +class TagBindingRemovePayload(BaseModel): + tag_id: str = Field(description="Tag ID to remove") + target_id: str = Field(description="Target ID to unbind tag from") + type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type") + + +register_schema_models( + console_ns, + TagBasePayload, + TagBindingPayload, + TagBindingRemovePayload, ) @@ -43,7 +52,7 @@ class TagListApi(Resource): return tags, 200 - @console_ns.expect(parser_tags) + @console_ns.expect(console_ns.models[TagBasePayload.__name__]) @setup_required @login_required @account_initialization_required @@ -53,22 +62,17 @@ class TagListApi(Resource): if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() - args = parser_tags.parse_args() - tag = TagService.save_tags(args) + payload = TagBasePayload.model_validate(console_ns.payload or {}) + tag = TagService.save_tags(payload.model_dump()) response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} return response, 200 -parser_tag_id = reqparse.RequestParser().add_argument( - "name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name -) - - @console_ns.route("/tags/") class TagUpdateDeleteApi(Resource): - @console_ns.expect(parser_tag_id) + @console_ns.expect(console_ns.models[TagBasePayload.__name__]) @setup_required @login_required @account_initialization_required @@ -79,8 +83,8 @@ class TagUpdateDeleteApi(Resource): if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() - args = parser_tag_id.parse_args() - tag = TagService.update_tags(args, tag_id) + payload = TagBasePayload.model_validate(console_ns.payload or {}) + tag = TagService.update_tags(payload.model_dump(), tag_id) binding_count = TagService.get_tag_binding_count(tag_id) @@ -100,17 +104,9 @@ class TagUpdateDeleteApi(Resource): return 204 -parser_create = ( - reqparse.RequestParser() - .add_argument("tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required.") - .add_argument("target_id", type=str, nullable=False, required=True, location="json", help="Target ID is required.") - .add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.") -) - - @console_ns.route("/tag-bindings/create") class TagBindingCreateApi(Resource): - @console_ns.expect(parser_create) + @console_ns.expect(console_ns.models[TagBindingPayload.__name__]) @setup_required @login_required @account_initialization_required @@ -120,23 +116,15 @@ class TagBindingCreateApi(Resource): if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() - args = parser_create.parse_args() - TagService.save_tag_binding(args) + payload = TagBindingPayload.model_validate(console_ns.payload or {}) + TagService.save_tag_binding(payload.model_dump()) return {"result": "success"}, 200 -parser_remove = ( - reqparse.RequestParser() - .add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.") - .add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.") - .add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.") -) - - @console_ns.route("/tag-bindings/remove") class TagBindingDeleteApi(Resource): - @console_ns.expect(parser_remove) + @console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__]) @setup_required @login_required @account_initialization_required @@ -146,7 +134,7 @@ class TagBindingDeleteApi(Resource): if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() - args = parser_remove.parse_args() - TagService.delete_tag_binding(args) + payload = TagBindingRemovePayload.model_validate(console_ns.payload or {}) + TagService.delete_tag_binding(payload.model_dump()) return {"result": "success"}, 200 diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 2c54aa5a20..cb711d16e4 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -18,6 +18,7 @@ from controllers.console.wraps import ( setup_required, ) from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration +from core.helper.tool_provider_cache import ToolProviderListCache from core.mcp.auth.auth_flow import auth, handle_callback from core.mcp.error import MCPAuthError, MCPError, MCPRefreshTokenError from core.mcp.mcp_client import MCPClient @@ -944,7 +945,7 @@ class ToolProviderMCPApi(Resource): configuration = MCPConfiguration.model_validate(args["configuration"]) authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None - # Create provider + # Create provider in transaction with Session(db.engine) as session, session.begin(): service = MCPToolManageService(session=session) result = service.create_provider( @@ -960,7 +961,11 @@ class ToolProviderMCPApi(Resource): configuration=configuration, authentication=authentication, ) - return jsonable_encoder(result) + + # Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations + ToolProviderListCache.invalidate_cache(tenant_id) + + return jsonable_encoder(result) @console_ns.expect(parser_mcp_put) @setup_required @@ -972,17 +977,23 @@ class ToolProviderMCPApi(Resource): authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None _, current_tenant_id = current_account_with_tenant() - # Step 1: Validate server URL change if needed (includes URL format validation and network operation) - validation_result = None + # Step 1: Get provider data for URL validation (short-lived session, no network I/O) + validation_data = None with Session(db.engine) as session: service = MCPToolManageService(session=session) - validation_result = service.validate_server_url_change( - tenant_id=current_tenant_id, provider_id=args["provider_id"], new_server_url=args["server_url"] + validation_data = service.get_provider_for_url_validation( + tenant_id=current_tenant_id, provider_id=args["provider_id"] ) - # No need to check for errors here, exceptions will be raised directly + # Step 2: Perform URL validation with network I/O OUTSIDE of any database session + # This prevents holding database locks during potentially slow network operations + validation_result = MCPToolManageService.validate_server_url_standalone( + tenant_id=current_tenant_id, + new_server_url=args["server_url"], + validation_data=validation_data, + ) - # Step 2: Perform database update in a transaction + # Step 3: Perform database update in a transaction with Session(db.engine) as session, session.begin(): service = MCPToolManageService(session=session) service.update_provider( @@ -999,7 +1010,11 @@ class ToolProviderMCPApi(Resource): authentication=authentication, validation_result=validation_result, ) - return {"result": "success"} + + # Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations + ToolProviderListCache.invalidate_cache(current_tenant_id) + + return {"result": "success"} @console_ns.expect(parser_mcp_delete) @setup_required @@ -1012,7 +1027,11 @@ class ToolProviderMCPApi(Resource): with Session(db.engine) as session, session.begin(): service = MCPToolManageService(session=session) service.delete_provider(tenant_id=current_tenant_id, provider_id=args["provider_id"]) - return {"result": "success"} + + # Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations + ToolProviderListCache.invalidate_cache(current_tenant_id) + + return {"result": "success"} parser_auth = ( @@ -1062,6 +1081,8 @@ class ToolMCPAuthApi(Resource): credentials=provider_entity.credentials, authed=True, ) + # Invalidate cache after updating credentials + ToolProviderListCache.invalidate_cache(tenant_id) return {"result": "success"} except MCPAuthError as e: try: @@ -1075,16 +1096,22 @@ class ToolMCPAuthApi(Resource): with Session(db.engine) as session, session.begin(): service = MCPToolManageService(session=session) response = service.execute_auth_actions(auth_result) + # Invalidate cache after auth actions may have updated provider state + ToolProviderListCache.invalidate_cache(tenant_id) return response except MCPRefreshTokenError as e: with Session(db.engine) as session, session.begin(): service = MCPToolManageService(session=session) service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id) + # Invalidate cache after clearing credentials + ToolProviderListCache.invalidate_cache(tenant_id) raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e except (MCPError, ValueError) as e: with Session(db.engine) as session, session.begin(): service = MCPToolManageService(session=session) service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id) + # Invalidate cache after clearing credentials + ToolProviderListCache.invalidate_cache(tenant_id) raise ValueError(f"Failed to connect to MCP server: {e}") from e diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 7692aeed23..4f91f40c55 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -49,7 +49,7 @@ class DatasetUpdatePayload(BaseModel): embedding_model: str | None = None embedding_model_provider: str | None = None retrieval_model: RetrievalModel | None = None - partial_member_list: list[str] | None = None + partial_member_list: list[dict[str, str]] | None = None external_retrieval_model: dict[str, Any] | None = None external_knowledge_id: str | None = None external_knowledge_api_id: str | None = None diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index b9fef48c4d..15828cc208 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -1,7 +1,8 @@ import logging from flask import request -from flask_restx import fields, marshal_with, reqparse +from flask_restx import fields, marshal_with +from pydantic import BaseModel, field_validator from werkzeug.exceptions import InternalServerError import services @@ -20,6 +21,7 @@ from controllers.web.error import ( from controllers.web.wraps import WebApiResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError +from libs.helper import uuid_value from models.model import App from services.audio_service import AudioService from services.errors.audio import ( @@ -29,6 +31,25 @@ from services.errors.audio import ( UnsupportedAudioTypeServiceError, ) +from ..common.schema import register_schema_models + + +class TextToAudioPayload(BaseModel): + message_id: str | None = None + voice: str | None = None + text: str | None = None + streaming: bool | None = None + + @field_validator("message_id") + @classmethod + def validate_message_id(cls, value: str | None) -> str | None: + if value is None: + return value + return uuid_value(value) + + +register_schema_models(web_ns, TextToAudioPayload) + logger = logging.getLogger(__name__) @@ -88,6 +109,7 @@ class AudioApi(WebApiResource): @web_ns.route("/text-to-audio") class TextApi(WebApiResource): + @web_ns.expect(web_ns.models[TextToAudioPayload.__name__]) @web_ns.doc("Text to Audio") @web_ns.doc(description="Convert text to audio using text-to-speech service.") @web_ns.doc( @@ -102,18 +124,11 @@ class TextApi(WebApiResource): def post(self, app_model: App, end_user): """Convert text to audio""" try: - parser = ( - reqparse.RequestParser() - .add_argument("message_id", type=str, required=False, location="json") - .add_argument("voice", type=str, location="json") - .add_argument("text", type=str, location="json") - .add_argument("streaming", type=bool, location="json") - ) - args = parser.parse_args() + payload = TextToAudioPayload.model_validate(web_ns.payload or {}) - message_id = args.get("message_id", None) - text = args.get("text", None) - voice = args.get("voice", None) + message_id = payload.message_id + text = payload.text + voice = payload.voice response = AudioService.transcript_tts( app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id ) diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index e8a4698375..a97d745471 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -1,9 +1,11 @@ import logging +from typing import Any, Literal -from flask_restx import reqparse +from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import InternalServerError, NotFound import services +from controllers.common.schema import register_schema_models from controllers.web import web_ns from controllers.web.error import ( AppUnavailableError, @@ -34,25 +36,44 @@ from services.errors.llm import InvokeRateLimitError logger = logging.getLogger(__name__) +class CompletionMessagePayload(BaseModel): + inputs: dict[str, Any] = Field(description="Input variables for the completion") + query: str = Field(default="", description="Query text for completion") + files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed") + response_mode: Literal["blocking", "streaming"] | None = Field( + default=None, description="Response mode: blocking or streaming" + ) + retriever_from: str = Field(default="web_app", description="Source of retriever") + + +class ChatMessagePayload(BaseModel): + inputs: dict[str, Any] = Field(description="Input variables for the chat") + query: str = Field(description="User query/message") + files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed") + response_mode: Literal["blocking", "streaming"] | None = Field( + default=None, description="Response mode: blocking or streaming" + ) + conversation_id: str | None = Field(default=None, description="Conversation ID") + parent_message_id: str | None = Field(default=None, description="Parent message ID") + retriever_from: str = Field(default="web_app", description="Source of retriever") + + @field_validator("conversation_id", "parent_message_id") + @classmethod + def validate_uuid(cls, value: str | None) -> str | None: + if value is None: + return value + return uuid_value(value) + + +register_schema_models(web_ns, CompletionMessagePayload, ChatMessagePayload) + + # define completion api for user @web_ns.route("/completion-messages") class CompletionApi(WebApiResource): @web_ns.doc("Create Completion Message") @web_ns.doc(description="Create a completion message for text generation applications.") - @web_ns.doc( - params={ - "inputs": {"description": "Input variables for the completion", "type": "object", "required": True}, - "query": {"description": "Query text for completion", "type": "string", "required": False}, - "files": {"description": "Files to be processed", "type": "array", "required": False}, - "response_mode": { - "description": "Response mode: blocking or streaming", - "type": "string", - "enum": ["blocking", "streaming"], - "required": False, - }, - "retriever_from": {"description": "Source of retriever", "type": "string", "required": False}, - } - ) + @web_ns.expect(web_ns.models[CompletionMessagePayload.__name__]) @web_ns.doc( responses={ 200: "Success", @@ -67,18 +88,10 @@ class CompletionApi(WebApiResource): if app_model.mode != AppMode.COMPLETION: raise NotCompletionAppError() - parser = ( - reqparse.RequestParser() - .add_argument("inputs", type=dict, required=True, location="json") - .add_argument("query", type=str, location="json", default="") - .add_argument("files", type=list, required=False, location="json") - .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") - .add_argument("retriever_from", type=str, required=False, default="web_app", location="json") - ) + payload = CompletionMessagePayload.model_validate(web_ns.payload or {}) + args = payload.model_dump(exclude_none=True) - args = parser.parse_args() - - streaming = args["response_mode"] == "streaming" + streaming = payload.response_mode == "streaming" args["auto_generate_name"] = False try: @@ -142,22 +155,7 @@ class CompletionStopApi(WebApiResource): class ChatApi(WebApiResource): @web_ns.doc("Create Chat Message") @web_ns.doc(description="Create a chat message for conversational applications.") - @web_ns.doc( - params={ - "inputs": {"description": "Input variables for the chat", "type": "object", "required": True}, - "query": {"description": "User query/message", "type": "string", "required": True}, - "files": {"description": "Files to be processed", "type": "array", "required": False}, - "response_mode": { - "description": "Response mode: blocking or streaming", - "type": "string", - "enum": ["blocking", "streaming"], - "required": False, - }, - "conversation_id": {"description": "Conversation UUID", "type": "string", "required": False}, - "parent_message_id": {"description": "Parent message UUID", "type": "string", "required": False}, - "retriever_from": {"description": "Source of retriever", "type": "string", "required": False}, - } - ) + @web_ns.expect(web_ns.models[ChatMessagePayload.__name__]) @web_ns.doc( responses={ 200: "Success", @@ -173,20 +171,10 @@ class ChatApi(WebApiResource): if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - parser = ( - reqparse.RequestParser() - .add_argument("inputs", type=dict, required=True, location="json") - .add_argument("query", type=str, required=True, location="json") - .add_argument("files", type=list, required=False, location="json") - .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") - .add_argument("conversation_id", type=uuid_value, location="json") - .add_argument("parent_message_id", type=uuid_value, required=False, location="json") - .add_argument("retriever_from", type=str, required=False, default="web_app", location="json") - ) + payload = ChatMessagePayload.model_validate(web_ns.payload or {}) + args = payload.model_dump(exclude_none=True) - args = parser.parse_args() - - streaming = args["response_mode"] == "streaming" + streaming = payload.response_mode == "streaming" args["auto_generate_name"] = False try: diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 93f2742599..307af3747c 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -1,3 +1,4 @@ +import json from collections.abc import Sequence from enum import StrEnum, auto from typing import Any, Literal @@ -120,7 +121,7 @@ class VariableEntity(BaseModel): allowed_file_types: Sequence[FileType] | None = Field(default_factory=list) allowed_file_extensions: Sequence[str] | None = Field(default_factory=list) allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list) - json_schema: dict[str, Any] | None = Field(default=None) + json_schema: str | None = Field(default=None) @field_validator("description", mode="before") @classmethod @@ -134,11 +135,17 @@ class VariableEntity(BaseModel): @field_validator("json_schema") @classmethod - def validate_json_schema(cls, schema: dict[str, Any] | None) -> dict[str, Any] | None: + def validate_json_schema(cls, schema: str | None) -> str | None: if schema is None: return None + try: - Draft7Validator.check_schema(schema) + json_schema = json.loads(schema) + except json.JSONDecodeError: + raise ValueError(f"invalid json_schema value {schema}") + + try: + Draft7Validator.check_schema(json_schema) except SchemaError as e: raise ValueError(f"Invalid JSON schema: {e.message}") return schema diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index 1b0474142e..02d58a07d1 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -1,3 +1,4 @@ +import json from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any, Union, final @@ -175,6 +176,13 @@ class BaseAppGenerator: value = True elif value == 0: value = False + case VariableEntityType.JSON_OBJECT: + if not isinstance(value, str): + raise ValueError(f"{variable_entity.variable} in input form must be a string") + try: + json.loads(value) + except json.JSONDecodeError: + raise ValueError(f"{variable_entity.variable} in input form must be a valid JSON object") case _: raise AssertionError("this statement should be unreachable.") diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 5c169f4db1..5bb93fa44a 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -342,9 +342,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): self._task_state.llm_result.message.content = current_content if isinstance(event, QueueLLMChunkEvent): + event_type = self._message_cycle_manager.get_message_event_type(message_id=self._message_id) yield self._message_cycle_manager.message_to_stream_response( answer=cast(str, delta_text), message_id=self._message_id, + event_type=event_type, ) else: yield self._agent_message_to_stream_response( diff --git a/api/core/app/task_pipeline/message_cycle_manager.py b/api/core/app/task_pipeline/message_cycle_manager.py index 2e6f92efa5..0e7f300cee 100644 --- a/api/core/app/task_pipeline/message_cycle_manager.py +++ b/api/core/app/task_pipeline/message_cycle_manager.py @@ -5,7 +5,7 @@ from threading import Thread from typing import Union from flask import Flask, current_app -from sqlalchemy import select +from sqlalchemy import exists, select from sqlalchemy.orm import Session from configs import dify_config @@ -54,6 +54,20 @@ class MessageCycleManager: ): self._application_generate_entity = application_generate_entity self._task_state = task_state + self._message_has_file: set[str] = set() + + def get_message_event_type(self, message_id: str) -> StreamEvent: + if message_id in self._message_has_file: + return StreamEvent.MESSAGE_FILE + + with Session(db.engine, expire_on_commit=False) as session: + has_file = session.query(exists().where(MessageFile.message_id == message_id)).scalar() + + if has_file: + self._message_has_file.add(message_id) + return StreamEvent.MESSAGE_FILE + + return StreamEvent.MESSAGE def generate_conversation_name(self, *, conversation_id: str, query: str) -> Thread | None: """ @@ -214,7 +228,11 @@ class MessageCycleManager: return None def message_to_stream_response( - self, answer: str, message_id: str, from_variable_selector: list[str] | None = None + self, + answer: str, + message_id: str, + from_variable_selector: list[str] | None = None, + event_type: StreamEvent | None = None, ) -> MessageStreamResponse: """ Message to stream response. @@ -222,16 +240,12 @@ class MessageCycleManager: :param message_id: message id :return: """ - with Session(db.engine, expire_on_commit=False) as session: - message_file = session.scalar(select(MessageFile).where(MessageFile.id == message_id)) - event_type = StreamEvent.MESSAGE_FILE if message_file else StreamEvent.MESSAGE - return MessageStreamResponse( task_id=self._application_generate_entity.task_id, id=message_id, answer=answer, from_variable_selector=from_variable_selector, - event=event_type, + event=event_type or StreamEvent.MESSAGE, ) def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse: diff --git a/api/core/mcp/auth/auth_flow.py b/api/core/mcp/auth/auth_flow.py index 92787b39dd..aef1afb235 100644 --- a/api/core/mcp/auth/auth_flow.py +++ b/api/core/mcp/auth/auth_flow.py @@ -47,7 +47,11 @@ def build_protected_resource_metadata_discovery_urls( """ Build a list of URLs to try for Protected Resource Metadata discovery. - Per SEP-985, supports fallback when discovery fails at one URL. + Per RFC 9728 Section 5.1, supports fallback when discovery fails at one URL. + Priority order: + 1. URL from WWW-Authenticate header (if provided) + 2. Well-known URI with path: https://example.com/.well-known/oauth-protected-resource/public/mcp + 3. Well-known URI at root: https://example.com/.well-known/oauth-protected-resource """ urls = [] @@ -58,9 +62,18 @@ def build_protected_resource_metadata_discovery_urls( # Fallback: construct from server URL parsed = urlparse(server_url) base_url = f"{parsed.scheme}://{parsed.netloc}" - fallback_url = urljoin(base_url, "/.well-known/oauth-protected-resource") - if fallback_url not in urls: - urls.append(fallback_url) + path = parsed.path.rstrip("/") + + # Priority 2: With path insertion (e.g., /.well-known/oauth-protected-resource/public/mcp) + if path: + path_url = f"{base_url}/.well-known/oauth-protected-resource{path}" + if path_url not in urls: + urls.append(path_url) + + # Priority 3: At root (e.g., /.well-known/oauth-protected-resource) + root_url = f"{base_url}/.well-known/oauth-protected-resource" + if root_url not in urls: + urls.append(root_url) return urls @@ -71,30 +84,34 @@ def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: st Supports both OAuth 2.0 (RFC 8414) and OpenID Connect discovery. - Per RFC 8414 section 3: - - If issuer has no path: https://example.com/.well-known/oauth-authorization-server - - If issuer has path: https://example.com/.well-known/oauth-authorization-server{path} - - Example: - - issuer: https://example.com/oauth - - metadata: https://example.com/.well-known/oauth-authorization-server/oauth + Per RFC 8414 section 3.1 and section 5, try all possible endpoints: + - OAuth 2.0 with path insertion: https://example.com/.well-known/oauth-authorization-server/tenant1 + - OpenID Connect with path insertion: https://example.com/.well-known/openid-configuration/tenant1 + - OpenID Connect path appending: https://example.com/tenant1/.well-known/openid-configuration + - OAuth 2.0 at root: https://example.com/.well-known/oauth-authorization-server + - OpenID Connect at root: https://example.com/.well-known/openid-configuration """ urls = [] base_url = auth_server_url or server_url parsed = urlparse(base_url) base = f"{parsed.scheme}://{parsed.netloc}" - path = parsed.path.rstrip("/") # Remove trailing slash + path = parsed.path.rstrip("/") + # OAuth 2.0 Authorization Server Metadata at root (MCP-03-26) + urls.append(f"{base}/.well-known/oauth-authorization-server") - # Try OpenID Connect discovery first (more common) - urls.append(urljoin(base + "/", ".well-known/openid-configuration")) + # OpenID Connect Discovery at root + urls.append(f"{base}/.well-known/openid-configuration") - # OAuth 2.0 Authorization Server Metadata (RFC 8414) - # Include the path component if present in the issuer URL if path: - urls.append(urljoin(base, f".well-known/oauth-authorization-server{path}")) - else: - urls.append(urljoin(base, ".well-known/oauth-authorization-server")) + # OpenID Connect Discovery with path insertion + urls.append(f"{base}/.well-known/openid-configuration{path}") + + # OpenID Connect Discovery path appending + urls.append(f"{base}{path}/.well-known/openid-configuration") + + # OAuth 2.0 Authorization Server Metadata with path insertion + urls.append(f"{base}/.well-known/oauth-authorization-server{path}") return urls diff --git a/api/core/mcp/mcp_client.py b/api/core/mcp/mcp_client.py index b0e0dab9be..2b0645b558 100644 --- a/api/core/mcp/mcp_client.py +++ b/api/core/mcp/mcp_client.py @@ -59,7 +59,7 @@ class MCPClient: try: logger.debug("Not supported method %s found in URL path, trying default 'mcp' method.", method_name) self.connect_server(sse_client, "sse") - except MCPConnectionError: + except (MCPConnectionError, ValueError): logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.") self.connect_server(streamablehttp_client, "mcp") diff --git a/api/core/model_runtime/README.md b/api/core/model_runtime/README.md index a6caa7eb1e..b9d2c55210 100644 --- a/api/core/model_runtime/README.md +++ b/api/core/model_runtime/README.md @@ -18,34 +18,20 @@ This module provides the interface for invoking and authenticating various model - Model provider display - ![image-20231210143654461](./docs/en_US/images/index/image-20231210143654461.png) - - Displays a list of all supported providers, including provider names, icons, supported model types list, predefined model list, configuration method, and credentials form rules, etc. For detailed rule design, see: [Schema](./docs/en_US/schema.md). + Displays a list of all supported providers, including provider names, icons, supported model types list, predefined model list, configuration method, and credentials form rules, etc. - Selectable model list display - ![image-20231210144229650](./docs/en_US/images/index/image-20231210144229650.png) - After configuring provider/model credentials, the dropdown (application orchestration interface/default model) allows viewing of the available LLM list. Greyed out items represent predefined model lists from providers without configured credentials, facilitating user review of supported models. - In addition, this list also returns configurable parameter information and rules for LLM, as shown below: - - ![image-20231210144814617](./docs/en_US/images/index/image-20231210144814617.png) - - These parameters are all defined in the backend, allowing different settings for various parameters supported by different models, as detailed in: [Schema](./docs/en_US/schema.md#ParameterRule). + In addition, this list also returns configurable parameter information and rules for LLM. These parameters are all defined in the backend, allowing different settings for various parameters supported by different models. - Provider/model credential authentication - ![image-20231210151548521](./docs/en_US/images/index/image-20231210151548521.png) - - ![image-20231210151628992](./docs/en_US/images/index/image-20231210151628992.png) - - The provider list returns configuration information for the credentials form, which can be authenticated through Runtime's interface. The first image above is a provider credential DEMO, and the second is a model credential DEMO. + The provider list returns configuration information for the credentials form, which can be authenticated through Runtime's interface. ## Structure -![](./docs/en_US/images/index/image-20231210165243632.png) - Model Runtime is divided into three layers: - The outermost layer is the factory method @@ -60,9 +46,6 @@ Model Runtime is divided into three layers: It offers direct invocation of various model types, predefined model configuration information, getting predefined/remote model lists, model credential authentication methods. Different models provide additional special methods, like LLM's pre-computed tokens method, cost information obtaining method, etc., **allowing horizontal expansion** for different models under the same provider (within supported model types). -## Next Steps +## Documentation -- Add new provider configuration: [Link](./docs/en_US/provider_scale_out.md) -- Add new models for existing providers: [Link](./docs/en_US/provider_scale_out.md#AddModel) -- View YAML configuration rules: [Link](./docs/en_US/schema.md) -- Implement interface methods: [Link](./docs/en_US/interfaces.md) +For detailed documentation on how to add new providers or models, please refer to the [Dify documentation](https://docs.dify.ai/). diff --git a/api/core/model_runtime/README_CN.md b/api/core/model_runtime/README_CN.md index dfe614347a..0a8b56b3fe 100644 --- a/api/core/model_runtime/README_CN.md +++ b/api/core/model_runtime/README_CN.md @@ -18,34 +18,20 @@ - 模型供应商展示 - ![image-20231210143654461](./docs/zh_Hans/images/index/image-20231210143654461.png) - -​ 展示所有已支持的供应商列表,除了返回供应商名称、图标之外,还提供了支持的模型类型列表,预定义模型列表、配置方式以及配置凭据的表单规则等等,规则设计详见:[Schema](./docs/zh_Hans/schema.md)。 + 展示所有已支持的供应商列表,除了返回供应商名称、图标之外,还提供了支持的模型类型列表,预定义模型列表、配置方式以及配置凭据的表单规则等等。 - 可选择的模型列表展示 - ![image-20231210144229650](./docs/zh_Hans/images/index/image-20231210144229650.png) + 配置供应商/模型凭据后,可在此下拉(应用编排界面/默认模型)查看可用的 LLM 列表,其中灰色的为未配置凭据供应商的预定义模型列表,方便用户查看已支持的模型。 -​ 配置供应商/模型凭据后,可在此下拉(应用编排界面/默认模型)查看可用的 LLM 列表,其中灰色的为未配置凭据供应商的预定义模型列表,方便用户查看已支持的模型。 - -​ 除此之外,该列表还返回了 LLM 可配置的参数信息和规则,如下图: - -​ ![image-20231210144814617](./docs/zh_Hans/images/index/image-20231210144814617.png) - -​ 这里的参数均为后端定义,相比之前只有 5 种固定参数,这里可为不同模型设置所支持的各种参数,详见:[Schema](./docs/zh_Hans/schema.md#ParameterRule)。 + 除此之外,该列表还返回了 LLM 可配置的参数信息和规则。这里的参数均为后端定义,相比之前只有 5 种固定参数,这里可为不同模型设置所支持的各种参数。 - 供应商/模型凭据鉴权 - ![image-20231210151548521](./docs/zh_Hans/images/index/image-20231210151548521.png) - -![image-20231210151628992](./docs/zh_Hans/images/index/image-20231210151628992.png) - -​ 供应商列表返回了凭据表单的配置信息,可通过 Runtime 提供的接口对凭据进行鉴权,上图 1 为供应商凭据 DEMO,上图 2 为模型凭据 DEMO。 + 供应商列表返回了凭据表单的配置信息,可通过 Runtime 提供的接口对凭据进行鉴权。 ## 结构 -![](./docs/zh_Hans/images/index/image-20231210165243632.png) - Model Runtime 分三层: - 最外层为工厂方法 @@ -59,8 +45,7 @@ Model Runtime 分三层: 对于供应商/模型凭据,有两种情况 - 如 OpenAI 这类中心化供应商,需要定义如**api_key**这类的鉴权凭据 - - 如[**Xinference**](https://github.com/xorbitsai/inference)这类本地部署的供应商,需要定义如**server_url**这类的地址凭据,有时候还需要定义**model_uid**之类的模型类型凭据,就像下面这样,当在供应商层定义了这些凭据后,就可以在前端页面上直接展示,无需修改前端逻辑。 - ![Alt text](docs/zh_Hans/images/index/image.png) + - 如[**Xinference**](https://github.com/xorbitsai/inference)这类本地部署的供应商,需要定义如**server_url**这类的地址凭据,有时候还需要定义**model_uid**之类的模型类型凭据。当在供应商层定义了这些凭据后,就可以在前端页面上直接展示,无需修改前端逻辑。 当配置好凭据后,就可以通过 DifyRuntime 的外部接口直接获取到对应供应商所需要的**Schema**(凭据表单规则),从而在可以在不修改前端逻辑的情况下,提供新的供应商/模型的支持。 @@ -74,20 +59,6 @@ Model Runtime 分三层: - 模型凭据 (**在供应商层定义**):这是一类不经常变动,一般在配置好后就不会再变动的参数,如 **api_key**、**server_url** 等。在 DifyRuntime 中,他们的参数名一般为**credentials: dict[str, any]**,Provider 层的 credentials 会直接被传递到这一层,不需要再单独定义。 -## 下一步 +## 文档 -### [增加新的供应商配置 👈🏻](./docs/zh_Hans/provider_scale_out.md) - -当添加后,这里将会出现一个新的供应商 - -![Alt text](docs/zh_Hans/images/index/image-1.png) - -### [为已存在的供应商新增模型 👈🏻](./docs/zh_Hans/provider_scale_out.md#%E5%A2%9E%E5%8A%A0%E6%A8%A1%E5%9E%8B) - -当添加后,对应供应商的模型列表中将会出现一个新的预定义模型供用户选择,如 GPT-3.5 GPT-4 ChatGLM3-6b 等,而对于支持自定义模型的供应商,则不需要新增模型。 - -![Alt text](docs/zh_Hans/images/index/image-2.png) - -### [接口的具体实现 👈🏻](./docs/zh_Hans/interfaces.md) - -你可以在这里找到你想要查看的接口的具体实现,以及接口的参数和返回值的具体含义。 +有关如何添加新供应商或模型的详细文档,请参阅 [Dify 文档](https://docs.dify.ai/)。 diff --git a/api/core/plugin/impl/base.py b/api/core/plugin/impl/base.py index a1c84bd5d9..7bb2749afa 100644 --- a/api/core/plugin/impl/base.py +++ b/api/core/plugin/impl/base.py @@ -39,7 +39,7 @@ from core.trigger.errors import ( plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL)) _plugin_daemon_timeout_config = cast( float | httpx.Timeout | None, - getattr(dify_config, "PLUGIN_DAEMON_TIMEOUT", 300.0), + getattr(dify_config, "PLUGIN_DAEMON_TIMEOUT", 600.0), ) plugin_daemon_request_timeout: httpx.Timeout | None if _plugin_daemon_timeout_config is None: diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index 044b118635..f67f613e9d 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -83,6 +83,7 @@ class WordExtractor(BaseExtractor): def _extract_images_from_docx(self, doc): image_count = 0 image_map = {} + base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL for r_id, rel in doc.part.rels.items(): if "image" in rel.target_ref: @@ -121,8 +122,7 @@ class WordExtractor(BaseExtractor): used_at=naive_utc_now(), ) db.session.add(upload_file) - # Use r_id as key for external images since target_part is undefined - image_map[r_id] = f"![image]({dify_config.FILES_URL}/files/{upload_file.id}/file-preview)" + image_map[r_id] = f"![image]({base_url}/files/{upload_file.id}/file-preview)" else: image_ext = rel.target_ref.split(".")[-1] if image_ext is None: @@ -150,10 +150,7 @@ class WordExtractor(BaseExtractor): used_at=naive_utc_now(), ) db.session.add(upload_file) - # Use target_part as key for internal images - image_map[rel.target_part] = ( - f"![image]({dify_config.FILES_URL}/files/{upload_file.id}/file-preview)" - ) + image_map[rel.target_part] = f"![image]({base_url}/files/{upload_file.id}/file-preview)" db.session.commit() return image_map diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py index 801d2a2a52..b65cb14d8e 100644 --- a/api/core/rag/splitter/fixed_text_splitter.py +++ b/api/core/rag/splitter/fixed_text_splitter.py @@ -2,6 +2,7 @@ from __future__ import annotations +import codecs import re from typing import Any @@ -52,7 +53,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter) def __init__(self, fixed_separator: str = "\n\n", separators: list[str] | None = None, **kwargs: Any): """Create a new TextSplitter.""" super().__init__(**kwargs) - self._fixed_separator = fixed_separator + self._fixed_separator = codecs.decode(fixed_separator, "unicode_escape") self._separators = separators or ["\n\n", "\n", "。", ". ", " ", ""] def split_text(self, text: str) -> list[str]: @@ -94,7 +95,8 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter) splits = re.split(r" +", text) else: splits = text.split(separator) - splits = [item + separator if i < len(splits) else item for i, item in enumerate(splits)] + if self._keep_separator: + splits = [s + separator for s in splits[:-1]] + splits[-1:] else: splits = list(text) if separator == "\n": @@ -103,7 +105,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter) splits = [s for s in splits if (s not in {"", "\n"})] _good_splits = [] _good_splits_lengths = [] # cache the lengths of the splits - _separator = separator if self._keep_separator else "" + _separator = "" if self._keep_separator else separator s_lens = self._length_function(splits) if separator != "": for s, s_len in zip(splits, s_lens): diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py index f0c84872fb..931c6113a7 100644 --- a/api/core/workflow/nodes/http_request/executor.py +++ b/api/core/workflow/nodes/http_request/executor.py @@ -86,6 +86,11 @@ class Executor: node_data.authorization.config.api_key = variable_pool.convert_template( node_data.authorization.config.api_key ).text + # Validate that API key is not empty after template conversion + if not node_data.authorization.config.api_key or not node_data.authorization.config.api_key.strip(): + raise AuthorizationConfigError( + "API key is required for authorization but was empty. Please provide a valid API key." + ) self.url = node_data.url self.method = node_data.method diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 38effa79f7..36fc5078c5 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -1,3 +1,4 @@ +import json from typing import Any from jsonschema import Draft7Validator, ValidationError @@ -42,15 +43,25 @@ class StartNode(Node[StartNodeData]): if value is None and variable.required: raise ValueError(f"{key} is required in input form") - if not isinstance(value, dict): - raise ValueError(f"{key} must be a JSON object") - schema = variable.json_schema if not schema: continue + if not value: + continue + try: - Draft7Validator(schema).validate(value) + json_schema = json.loads(schema) + except json.JSONDecodeError as e: + raise ValueError(f"{schema} must be a valid JSON object") + + try: + json_value = json.loads(value) + except json.JSONDecodeError as e: + raise ValueError(f"{value} must be a valid JSON object") + + try: + Draft7Validator(json_schema).validate(json_value) except ValidationError as e: raise ValueError(f"JSON object for '{key}' does not match schema: {e.message}") - node_inputs[key] = value + node_inputs[key] = json_value diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh index 6313085e64..5a69eb15ac 100755 --- a/api/docker/entrypoint.sh +++ b/api/docker/entrypoint.sh @@ -34,10 +34,10 @@ if [[ "${MODE}" == "worker" ]]; then if [[ -z "${CELERY_QUEUES}" ]]; then if [[ "${EDITION}" == "CLOUD" ]]; then # Cloud edition: separate queues for dataset and trigger tasks - DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor" + DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention" else # Community edition (SELF_HOSTED): dataset, pipeline and workflow have separate queues - DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor" + DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention" fi else DEFAULT_QUEUES="${CELERY_QUEUES}" @@ -69,6 +69,53 @@ if [[ "${MODE}" == "worker" ]]; then elif [[ "${MODE}" == "beat" ]]; then exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO} + +elif [[ "${MODE}" == "job" ]]; then + # Job mode: Run a one-time Flask command and exit + # Pass Flask command and arguments via container args + # Example K8s usage: + # args: + # - create-tenant + # - --email + # - admin@example.com + # + # Example Docker usage: + # docker run -e MODE=job dify-api:latest create-tenant --email admin@example.com + + if [[ $# -eq 0 ]]; then + echo "Error: No command specified for job mode." + echo "" + echo "Usage examples:" + echo " Kubernetes:" + echo " args: [create-tenant, --email, admin@example.com]" + echo "" + echo " Docker:" + echo " docker run -e MODE=job dify-api create-tenant --email admin@example.com" + echo "" + echo "Available commands:" + echo " create-tenant, reset-password, reset-email, upgrade-db," + echo " vdb-migrate, install-plugins, and more..." + echo "" + echo "Run 'flask --help' to see all available commands." + exit 1 + fi + + echo "Running Flask job command: flask $*" + + # Temporarily disable exit on error to capture exit code + set +e + flask "$@" + JOB_EXIT_CODE=$? + set -e + + if [[ ${JOB_EXIT_CODE} -eq 0 ]]; then + echo "Job completed successfully." + else + echo "Job failed with exit code ${JOB_EXIT_CODE}." + fi + + exit ${JOB_EXIT_CODE} + else if [[ "${DEBUG}" == "true" ]]; then exec flask run --host=${DIFY_BIND_ADDRESS:-0.0.0.0} --port=${DIFY_PORT:-5001} --debug diff --git a/api/extensions/storage/opendal_storage.py b/api/extensions/storage/opendal_storage.py index a084844d72..83c5c2d12f 100644 --- a/api/extensions/storage/opendal_storage.py +++ b/api/extensions/storage/opendal_storage.py @@ -87,15 +87,16 @@ class OpenDALStorage(BaseStorage): if not self.exists(path): raise FileNotFoundError("Path not found") - all_files = self.op.scan(path=path) + # Use the new OpenDAL 0.46.0+ API with recursive listing + lister = self.op.list(path, recursive=True) if files and directories: logger.debug("files and directories on %s scanned", path) - return [f.path for f in all_files] + return [entry.path for entry in lister] if files: logger.debug("files on %s scanned", path) - return [f.path for f in all_files if not f.path.endswith("/")] + return [entry.path for entry in lister if not entry.metadata.is_dir] elif directories: logger.debug("directories on %s scanned", path) - return [f.path for f in all_files if f.path.endswith("/")] + return [entry.path for entry in lister if entry.metadata.is_dir] else: raise ValueError("At least one of files or directories must be True") diff --git a/api/libs/helper.py b/api/libs/helper.py index 4a7afe0bda..74e1808e49 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -11,6 +11,7 @@ from collections.abc import Generator, Mapping from datetime import datetime from hashlib import sha256 from typing import TYPE_CHECKING, Annotated, Any, Optional, Union, cast +from uuid import UUID from zoneinfo import available_timezones from flask import Response, stream_with_context @@ -119,6 +120,19 @@ def uuid_value(value: Any) -> str: raise ValueError(error) +def normalize_uuid(value: str | UUID) -> str: + if not value: + return "" + + try: + return uuid_value(value) + except ValueError as exc: + raise ValueError("must be a valid UUID") from exc + + +UUIDStrOrEmpty = Annotated[str, AfterValidator(normalize_uuid)] + + def alphanumeric(value: str): # check if the value is alphanumeric and underlined if re.match(r"^[a-zA-Z0-9_]+$", value): diff --git a/api/pyproject.toml b/api/pyproject.toml index 6fcbc0f25c..870de33f4b 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -12,7 +12,7 @@ dependencies = [ "bs4~=0.0.1", "cachetools~=5.3.0", "celery~=5.5.2", - "chardet~=5.1.0", + "charset-normalizer>=3.4.4", "flask~=3.1.2", "flask-compress>=1.17,<1.18", "flask-cors~=6.0.0", @@ -32,6 +32,7 @@ dependencies = [ "httpx[socks]~=0.27.0", "jieba==0.42.1", "json-repair>=0.41.1", + "jsonschema>=4.25.1", "langfuse~=2.51.3", "langsmith~=0.1.77", "markdown~=3.5.1", diff --git a/api/services/billing_service.py b/api/services/billing_service.py index 54e1c9d285..3d7cb6cc8d 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -1,8 +1,12 @@ +import logging import os +from collections.abc import Sequence from typing import Literal import httpx +from pydantic import TypeAdapter from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed +from typing_extensions import TypedDict from werkzeug.exceptions import InternalServerError from enums.cloud_plan import CloudPlan @@ -11,6 +15,15 @@ from extensions.ext_redis import redis_client from libs.helper import RateLimiter from models import Account, TenantAccountJoin, TenantAccountRole +logger = logging.getLogger(__name__) + + +class SubscriptionPlan(TypedDict): + """Tenant subscriptionplan information.""" + + plan: str + expiration_date: int + class BillingService: base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL") @@ -239,3 +252,39 @@ class BillingService: def sync_partner_tenants_bindings(cls, account_id: str, partner_key: str, click_id: str): payload = {"account_id": account_id, "click_id": click_id} return cls._send_request("PUT", f"/partners/{partner_key}/tenants", json=payload) + + @classmethod + def get_plan_bulk(cls, tenant_ids: Sequence[str]) -> dict[str, SubscriptionPlan]: + """ + Bulk fetch billing subscription plan via billing API. + Payload: {"tenant_ids": ["t1", "t2", ...]} (max 200 per request) + Returns: + Mapping of tenant_id -> {plan: str, expiration_date: int} + """ + results: dict[str, SubscriptionPlan] = {} + subscription_adapter = TypeAdapter(SubscriptionPlan) + + chunk_size = 200 + for i in range(0, len(tenant_ids), chunk_size): + chunk = tenant_ids[i : i + chunk_size] + try: + resp = cls._send_request("POST", "/subscription/plan/batch", json={"tenant_ids": chunk}) + data = resp.get("data", {}) + + for tenant_id, plan in data.items(): + subscription_plan = subscription_adapter.validate_python(plan) + results[tenant_id] = subscription_plan + except Exception: + logger.exception("Failed to fetch billing info batch for tenants: %s", chunk) + continue + + return results + + @classmethod + def get_expired_subscription_cleanup_whitelist(cls) -> Sequence[str]: + resp = cls._send_request("GET", "/subscription/cleanup/whitelist") + data = resp.get("data", []) + tenant_whitelist = [] + for item in data: + tenant_whitelist.append(item["tenant_id"]) + return tenant_whitelist diff --git a/api/services/entities/knowledge_entities/rag_pipeline_entities.py b/api/services/entities/knowledge_entities/rag_pipeline_entities.py index a97ccab914..cbb0efcc2a 100644 --- a/api/services/entities/knowledge_entities/rag_pipeline_entities.py +++ b/api/services/entities/knowledge_entities/rag_pipeline_entities.py @@ -23,7 +23,7 @@ class RagPipelineDatasetCreateEntity(BaseModel): description: str icon_info: IconInfo permission: str - partial_member_list: list[str] | None = None + partial_member_list: list[dict[str, str]] | None = None yaml_content: str | None = None diff --git a/api/services/tools/mcp_tools_manage_service.py b/api/services/tools/mcp_tools_manage_service.py index d641fe0315..252be77b27 100644 --- a/api/services/tools/mcp_tools_manage_service.py +++ b/api/services/tools/mcp_tools_manage_service.py @@ -15,7 +15,6 @@ from sqlalchemy.orm import Session from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPProviderEntity from core.helper import encrypter from core.helper.provider_cache import NoOpProviderCredentialCache -from core.helper.tool_provider_cache import ToolProviderListCache from core.mcp.auth.auth_flow import auth from core.mcp.auth_client import MCPClientWithAuthRetry from core.mcp.error import MCPAuthError, MCPError @@ -65,6 +64,15 @@ class ServerUrlValidationResult(BaseModel): return self.needs_validation and self.validation_passed and self.reconnect_result is not None +class ProviderUrlValidationData(BaseModel): + """Data required for URL validation, extracted from database to perform network operations outside of session""" + + current_server_url_hash: str + headers: dict[str, str] + timeout: float | None + sse_read_timeout: float | None + + class MCPToolManageService: """Service class for managing MCP tools and providers.""" @@ -166,9 +174,6 @@ class MCPToolManageService: self._session.add(mcp_tool) self._session.flush() - # Invalidate tool providers cache - ToolProviderListCache.invalidate_cache(tenant_id) - mcp_providers = ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True) return mcp_providers @@ -192,7 +197,7 @@ class MCPToolManageService: Update an MCP provider. Args: - validation_result: Pre-validation result from validate_server_url_change. + validation_result: Pre-validation result from validate_server_url_standalone. If provided and contains reconnect_result, it will be used instead of performing network operations. """ @@ -251,8 +256,6 @@ class MCPToolManageService: # Flush changes to database self._session.flush() - # Invalidate tool providers cache - ToolProviderListCache.invalidate_cache(tenant_id) except IntegrityError as e: self._handle_integrity_error(e, name, server_url, server_identifier) @@ -261,9 +264,6 @@ class MCPToolManageService: mcp_tool = self.get_provider(provider_id=provider_id, tenant_id=tenant_id) self._session.delete(mcp_tool) - # Invalidate tool providers cache - ToolProviderListCache.invalidate_cache(tenant_id) - def list_providers( self, *, tenant_id: str, for_list: bool = False, include_sensitive: bool = True ) -> list[ToolProviderApiEntity]: @@ -546,30 +546,39 @@ class MCPToolManageService: ) return self.execute_auth_actions(auth_result) - def _reconnect_provider(self, *, server_url: str, provider: MCPToolProvider) -> ReconnectResult: - """Attempt to reconnect to MCP provider with new server URL.""" + def get_provider_for_url_validation(self, *, tenant_id: str, provider_id: str) -> ProviderUrlValidationData: + """ + Get provider data required for URL validation. + This method performs database read and should be called within a session. + + Returns: + ProviderUrlValidationData: Data needed for standalone URL validation + """ + provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id) provider_entity = provider.to_entity() - headers = provider_entity.headers + return ProviderUrlValidationData( + current_server_url_hash=provider.server_url_hash, + headers=provider_entity.headers, + timeout=provider_entity.timeout, + sse_read_timeout=provider_entity.sse_read_timeout, + ) - try: - tools = self._retrieve_remote_mcp_tools(server_url, headers, provider_entity) - return ReconnectResult( - authed=True, - tools=json.dumps([tool.model_dump() for tool in tools]), - encrypted_credentials=EMPTY_CREDENTIALS_JSON, - ) - except MCPAuthError: - return ReconnectResult(authed=False, tools=EMPTY_TOOLS_JSON, encrypted_credentials=EMPTY_CREDENTIALS_JSON) - except MCPError as e: - raise ValueError(f"Failed to re-connect MCP server: {e}") from e - - def validate_server_url_change( - self, *, tenant_id: str, provider_id: str, new_server_url: str + @staticmethod + def validate_server_url_standalone( + *, + tenant_id: str, + new_server_url: str, + validation_data: ProviderUrlValidationData, ) -> ServerUrlValidationResult: """ Validate server URL change by attempting to connect to the new server. - This method should be called BEFORE update_provider to perform network operations - outside of the database transaction. + This method performs network operations and MUST be called OUTSIDE of any database session + to avoid holding locks during network I/O. + + Args: + tenant_id: Tenant ID for encryption + new_server_url: The new server URL to validate + validation_data: Provider data obtained from get_provider_for_url_validation Returns: ServerUrlValidationResult: Validation result with connection status and tools if successful @@ -579,25 +588,30 @@ class MCPToolManageService: return ServerUrlValidationResult(needs_validation=False) # Validate URL format - if not self._is_valid_url(new_server_url): + parsed = urlparse(new_server_url) + if not all([parsed.scheme, parsed.netloc]) or parsed.scheme not in ["http", "https"]: raise ValueError("Server URL is not valid.") # Always encrypt and hash the URL encrypted_server_url = encrypter.encrypt_token(tenant_id, new_server_url) new_server_url_hash = hashlib.sha256(new_server_url.encode()).hexdigest() - # Get current provider - provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id) - # Check if URL is actually different - if new_server_url_hash == provider.server_url_hash: + if new_server_url_hash == validation_data.current_server_url_hash: # URL hasn't changed, but still return the encrypted data return ServerUrlValidationResult( - needs_validation=False, encrypted_server_url=encrypted_server_url, server_url_hash=new_server_url_hash + needs_validation=False, + encrypted_server_url=encrypted_server_url, + server_url_hash=new_server_url_hash, ) - # Perform validation by attempting to connect - reconnect_result = self._reconnect_provider(server_url=new_server_url, provider=provider) + # Perform network validation - this is the expensive operation that should be outside session + reconnect_result = MCPToolManageService._reconnect_with_url( + server_url=new_server_url, + headers=validation_data.headers, + timeout=validation_data.timeout, + sse_read_timeout=validation_data.sse_read_timeout, + ) return ServerUrlValidationResult( needs_validation=True, validation_passed=True, @@ -606,6 +620,38 @@ class MCPToolManageService: server_url_hash=new_server_url_hash, ) + @staticmethod + def _reconnect_with_url( + *, + server_url: str, + headers: dict[str, str], + timeout: float | None, + sse_read_timeout: float | None, + ) -> ReconnectResult: + """ + Attempt to connect to MCP server with given URL. + This is a static method that performs network I/O without database access. + """ + from core.mcp.mcp_client import MCPClient + + try: + with MCPClient( + server_url=server_url, + headers=headers, + timeout=timeout, + sse_read_timeout=sse_read_timeout, + ) as mcp_client: + tools = mcp_client.list_tools() + return ReconnectResult( + authed=True, + tools=json.dumps([tool.model_dump() for tool in tools]), + encrypted_credentials=EMPTY_CREDENTIALS_JSON, + ) + except MCPAuthError: + return ReconnectResult(authed=False, tools=EMPTY_TOOLS_JSON, encrypted_credentials=EMPTY_CREDENTIALS_JSON) + except MCPError as e: + raise ValueError(f"Failed to re-connect MCP server: {e}") from e + def _build_tool_provider_response( self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list ) -> ToolProviderApiEntity: diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py index 4c1f38c3bb..5fc2597c92 100644 --- a/api/tasks/document_indexing_sync_task.py +++ b/api/tasks/document_indexing_sync_task.py @@ -2,7 +2,6 @@ import logging import time import click -import sqlalchemy as sa from celery import shared_task from sqlalchemy import select @@ -12,7 +11,7 @@ from core.rag.index_processor.index_processor_factory import IndexProcessorFacto from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document, DocumentSegment -from models.source import DataSourceOauthBinding +from services.datasource_provider_service import DatasourceProviderService logger = logging.getLogger(__name__) @@ -48,27 +47,36 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): page_id = data_source_info["notion_page_id"] page_type = data_source_info["type"] page_edited_time = data_source_info["last_edited_time"] + credential_id = data_source_info.get("credential_id") - data_source_binding = ( - db.session.query(DataSourceOauthBinding) - .where( - sa.and_( - DataSourceOauthBinding.tenant_id == document.tenant_id, - DataSourceOauthBinding.provider == "notion", - DataSourceOauthBinding.disabled == False, - DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', - ) - ) - .first() + # Get credentials from datasource provider + datasource_provider_service = DatasourceProviderService() + credential = datasource_provider_service.get_datasource_credentials( + tenant_id=document.tenant_id, + credential_id=credential_id, + provider="notion_datasource", + plugin_id="langgenius/notion_datasource", ) - if not data_source_binding: - raise ValueError("Data source binding not found.") + + if not credential: + logger.error( + "Datasource credential not found for document %s, tenant_id: %s, credential_id: %s", + document_id, + document.tenant_id, + credential_id, + ) + document.indexing_status = "error" + document.error = "Datasource credential not found. Please reconnect your Notion workspace." + document.stopped_at = naive_utc_now() + db.session.commit() + db.session.close() + return loader = NotionExtractor( notion_workspace_id=workspace_id, notion_obj_id=page_id, notion_page_type=page_type, - notion_access_token=data_source_binding.access_token, + notion_access_token=credential.get("integration_secret"), tenant_id=document.tenant_id, ) diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index e75258a2a2..d814da8ec7 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -6,6 +6,7 @@ import pytest from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities import GraphInitParams +from core.workflow.enums import WorkflowNodeExecutionStatus from core.workflow.graph import Graph from core.workflow.nodes.http_request.node import HttpRequestNode from core.workflow.nodes.node_factory import DifyNodeFactory @@ -169,13 +170,14 @@ def test_custom_authorization_header(setup_http_mock): @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) -def test_custom_auth_with_empty_api_key_does_not_set_header(setup_http_mock): - """Test: In custom authentication mode, when the api_key is empty, no header should be set.""" +def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock): + """Test: In custom authentication mode, when the api_key is empty, AuthorizationConfigError should be raised.""" from core.workflow.nodes.http_request.entities import ( HttpRequestNodeAuthorization, HttpRequestNodeData, HttpRequestNodeTimeout, ) + from core.workflow.nodes.http_request.exc import AuthorizationConfigError from core.workflow.nodes.http_request.executor import Executor from core.workflow.runtime import VariablePool from core.workflow.system_variable import SystemVariable @@ -208,16 +210,13 @@ def test_custom_auth_with_empty_api_key_does_not_set_header(setup_http_mock): ssl_verify=True, ) - # Create executor - executor = Executor( - node_data=node_data, timeout=HttpRequestNodeTimeout(connect=10, read=30, write=10), variable_pool=variable_pool - ) - - # Get assembled headers - headers = executor._assembling_headers() - - # When api_key is empty, the custom header should NOT be set - assert "X-Custom-Auth" not in headers + # Create executor should raise AuthorizationConfigError + with pytest.raises(AuthorizationConfigError, match="API key is required"): + Executor( + node_data=node_data, + timeout=HttpRequestNodeTimeout(connect=10, read=30, write=10), + variable_pool=variable_pool, + ) @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) @@ -305,9 +304,10 @@ def test_basic_authorization_with_custom_header_ignored(setup_http_mock): @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_custom_authorization_with_empty_api_key(setup_http_mock): """ - Test that custom authorization doesn't set header when api_key is empty. - This test verifies the fix for issue #23554. + Test that custom authorization raises error when api_key is empty. + This test verifies the fix for issue #21830. """ + node = init_http_node( config={ "id": "1", @@ -333,11 +333,10 @@ def test_custom_authorization_with_empty_api_key(setup_http_mock): ) result = node._run() - assert result.process_data is not None - data = result.process_data.get("request", "") - - # Custom header should NOT be set when api_key is empty - assert "X-Custom-Auth:" not in data + # Should fail with AuthorizationConfigError + assert result.status == WorkflowNodeExecutionStatus.FAILED + assert "API key is required" in result.error + assert result.error_type == "AuthorizationConfigError" @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) diff --git a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py index 0871467a05..2ff71ea6ea 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py @@ -2,7 +2,9 @@ from unittest.mock import patch import pytest from faker import Faker +from pydantic import TypeAdapter, ValidationError +from core.tools.entities.tool_entities import ApiProviderSchemaType from models import Account, Tenant from models.tools import ApiToolProvider from services.tools.api_tools_manage_service import ApiToolManageService @@ -298,7 +300,7 @@ class TestApiToolManageService: provider_name = fake.company() icon = {"type": "emoji", "value": "🔧"} credentials = {"auth_type": "none", "api_key_header": "X-API-Key", "api_key_value": ""} - schema_type = "openapi" + schema_type = ApiProviderSchemaType.OPENAPI schema = self._create_test_openapi_schema() privacy_policy = "https://example.com/privacy" custom_disclaimer = "Custom disclaimer text" @@ -364,7 +366,7 @@ class TestApiToolManageService: provider_name = fake.company() icon = {"type": "emoji", "value": "🔧"} credentials = {"auth_type": "none"} - schema_type = "openapi" + schema_type = ApiProviderSchemaType.OPENAPI schema = self._create_test_openapi_schema() privacy_policy = "https://example.com/privacy" custom_disclaimer = "Custom disclaimer text" @@ -428,21 +430,10 @@ class TestApiToolManageService: labels = ["test"] # Act & Assert: Try to create provider with invalid schema type - with pytest.raises(ValueError) as exc_info: - ApiToolManageService.create_api_tool_provider( - user_id=account.id, - tenant_id=tenant.id, - provider_name=provider_name, - icon=icon, - credentials=credentials, - schema_type=schema_type, - schema=schema, - privacy_policy=privacy_policy, - custom_disclaimer=custom_disclaimer, - labels=labels, - ) + with pytest.raises(ValidationError) as exc_info: + TypeAdapter(ApiProviderSchemaType).validate_python(schema_type) - assert "invalid schema type" in str(exc_info.value) + assert "validation error" in str(exc_info.value) def test_create_api_tool_provider_missing_auth_type( self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies @@ -464,7 +455,7 @@ class TestApiToolManageService: provider_name = fake.company() icon = {"type": "emoji", "value": "🔧"} credentials = {} # Missing auth_type - schema_type = "openapi" + schema_type = ApiProviderSchemaType.OPENAPI schema = self._create_test_openapi_schema() privacy_policy = "https://example.com/privacy" custom_disclaimer = "Custom disclaimer text" @@ -507,7 +498,7 @@ class TestApiToolManageService: provider_name = fake.company() icon = {"type": "emoji", "value": "🔑"} credentials = {"auth_type": "api_key", "api_key_header": "X-API-Key", "api_key_value": fake.uuid4()} - schema_type = "openapi" + schema_type = ApiProviderSchemaType.OPENAPI schema = self._create_test_openapi_schema() privacy_policy = "https://example.com/privacy" custom_disclaimer = "Custom disclaimer text" diff --git a/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py index 8c190762cf..6cae83ac37 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py @@ -1308,18 +1308,17 @@ class TestMCPToolManageService: type("MockTool", (), {"model_dump": lambda self: {"name": "test_tool_2", "description": "Test tool 2"}})(), ] - with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client: + with patch("core.mcp.mcp_client.MCPClient") as mock_mcp_client: # Setup mock client mock_client_instance = mock_mcp_client.return_value.__enter__.return_value mock_client_instance.list_tools.return_value = mock_tools # Act: Execute the method under test - from extensions.ext_database import db - - service = MCPToolManageService(db.session()) - result = service._reconnect_provider( + result = MCPToolManageService._reconnect_with_url( server_url="https://example.com/mcp", - provider=mcp_provider, + headers={"X-Test": "1"}, + timeout=mcp_provider.timeout, + sse_read_timeout=mcp_provider.sse_read_timeout, ) # Assert: Verify the expected outcomes @@ -1337,8 +1336,12 @@ class TestMCPToolManageService: assert tools_data[1]["name"] == "test_tool_2" # Verify mock interactions - provider_entity = mcp_provider.to_entity() - mock_mcp_client.assert_called_once() + mock_mcp_client.assert_called_once_with( + server_url="https://example.com/mcp", + headers={"X-Test": "1"}, + timeout=mcp_provider.timeout, + sse_read_timeout=mcp_provider.sse_read_timeout, + ) def test_re_connect_mcp_provider_auth_error(self, db_session_with_containers, mock_external_service_dependencies): """ @@ -1361,19 +1364,18 @@ class TestMCPToolManageService: ) # Mock MCPClient to raise authentication error - with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client: + with patch("core.mcp.mcp_client.MCPClient") as mock_mcp_client: from core.mcp.error import MCPAuthError mock_client_instance = mock_mcp_client.return_value.__enter__.return_value mock_client_instance.list_tools.side_effect = MCPAuthError("Authentication required") # Act: Execute the method under test - from extensions.ext_database import db - - service = MCPToolManageService(db.session()) - result = service._reconnect_provider( + result = MCPToolManageService._reconnect_with_url( server_url="https://example.com/mcp", - provider=mcp_provider, + headers={}, + timeout=mcp_provider.timeout, + sse_read_timeout=mcp_provider.sse_read_timeout, ) # Assert: Verify the expected outcomes @@ -1404,18 +1406,17 @@ class TestMCPToolManageService: ) # Mock MCPClient to raise connection error - with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client: + with patch("core.mcp.mcp_client.MCPClient") as mock_mcp_client: from core.mcp.error import MCPError mock_client_instance = mock_mcp_client.return_value.__enter__.return_value mock_client_instance.list_tools.side_effect = MCPError("Connection failed") # Act & Assert: Verify proper error handling - from extensions.ext_database import db - - service = MCPToolManageService(db.session()) with pytest.raises(ValueError, match="Failed to re-connect MCP server: Connection failed"): - service._reconnect_provider( + MCPToolManageService._reconnect_with_url( server_url="https://example.com/mcp", - provider=mcp_provider, + headers={"X-Test": "1"}, + timeout=mcp_provider.timeout, + sse_read_timeout=mcp_provider.sse_read_timeout, ) diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py new file mode 100644 index 0000000000..40f58c9ddf --- /dev/null +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py @@ -0,0 +1,420 @@ +from types import SimpleNamespace +from unittest.mock import ANY, Mock, patch + +import pytest + +from core.app.apps.base_app_queue_manager import AppQueueManager +from core.app.entities.app_invoke_entities import ChatAppGenerateEntity +from core.app.entities.queue_entities import ( + QueueAgentMessageEvent, + QueueErrorEvent, + QueueLLMChunkEvent, + QueueMessageEndEvent, + QueueMessageFileEvent, + QueuePingEvent, +) +from core.app.entities.task_entities import ( + EasyUITaskState, + ErrorStreamResponse, + MessageEndStreamResponse, + MessageFileStreamResponse, + MessageReplaceStreamResponse, + MessageStreamResponse, + PingStreamResponse, + StreamEvent, +) +from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline +from core.base.tts import AppGeneratorTTSPublisher +from core.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult +from core.model_runtime.entities.message_entities import TextPromptMessageContent +from core.ops.ops_trace_manager import TraceQueueManager +from models.model import AppMode + + +class TestEasyUIBasedGenerateTaskPipelineProcessStreamResponse: + """Test cases for EasyUIBasedGenerateTaskPipeline._process_stream_response method.""" + + @pytest.fixture + def mock_application_generate_entity(self): + """Create a mock application generate entity.""" + entity = Mock(spec=ChatAppGenerateEntity) + entity.task_id = "test-task-id" + entity.app_id = "test-app-id" + # minimal app_config used by pipeline internals + entity.app_config = SimpleNamespace( + tenant_id="test-tenant-id", + app_id="test-app-id", + app_mode=AppMode.CHAT, + app_model_config_dict={}, + additional_features=None, + sensitive_word_avoidance=None, + ) + # minimal model_conf for LLMResult init + entity.model_conf = SimpleNamespace( + model="test-model", + provider_model_bundle=SimpleNamespace(model_type_instance=Mock()), + credentials={}, + ) + return entity + + @pytest.fixture + def mock_queue_manager(self): + """Create a mock queue manager.""" + manager = Mock(spec=AppQueueManager) + return manager + + @pytest.fixture + def mock_message_cycle_manager(self): + """Create a mock message cycle manager.""" + manager = Mock() + manager.get_message_event_type.return_value = StreamEvent.MESSAGE + manager.message_to_stream_response.return_value = Mock(spec=MessageStreamResponse) + manager.message_file_to_stream_response.return_value = Mock(spec=MessageFileStreamResponse) + manager.message_replace_to_stream_response.return_value = Mock(spec=MessageReplaceStreamResponse) + manager.handle_retriever_resources = Mock() + manager.handle_annotation_reply.return_value = None + return manager + + @pytest.fixture + def mock_conversation(self): + """Create a mock conversation.""" + conversation = Mock() + conversation.id = "test-conversation-id" + conversation.mode = "chat" + return conversation + + @pytest.fixture + def mock_message(self): + """Create a mock message.""" + message = Mock() + message.id = "test-message-id" + message.created_at = Mock() + message.created_at.timestamp.return_value = 1234567890 + return message + + @pytest.fixture + def mock_task_state(self): + """Create a mock task state.""" + task_state = Mock(spec=EasyUITaskState) + + # Create LLM result mock + llm_result = Mock(spec=RuntimeLLMResult) + llm_result.prompt_messages = [] + llm_result.message = Mock() + llm_result.message.content = "" + + task_state.llm_result = llm_result + task_state.answer = "" + + return task_state + + @pytest.fixture + def pipeline( + self, + mock_application_generate_entity, + mock_queue_manager, + mock_conversation, + mock_message, + mock_message_cycle_manager, + mock_task_state, + ): + """Create an EasyUIBasedGenerateTaskPipeline instance with mocked dependencies.""" + with patch( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.EasyUITaskState", return_value=mock_task_state + ): + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=mock_application_generate_entity, + queue_manager=mock_queue_manager, + conversation=mock_conversation, + message=mock_message, + stream=True, + ) + pipeline._message_cycle_manager = mock_message_cycle_manager + pipeline._task_state = mock_task_state + return pipeline + + def test_get_message_event_type_called_once_when_first_llm_chunk_arrives( + self, pipeline, mock_message_cycle_manager + ): + """Expect get_message_event_type to be called when processing the first LLM chunk event.""" + # Setup a minimal LLM chunk event + chunk = Mock() + chunk.delta.message.content = "hi" + chunk.prompt_messages = [] + llm_chunk_event = Mock(spec=QueueLLMChunkEvent) + llm_chunk_event.chunk = chunk + mock_queue_message = Mock() + mock_queue_message.event = llm_chunk_event + pipeline.queue_manager.listen.return_value = [mock_queue_message] + + # Execute + list(pipeline._process_stream_response(publisher=None, trace_manager=None)) + + # Assert + mock_message_cycle_manager.get_message_event_type.assert_called_once_with(message_id="test-message-id") + + def test_llm_chunk_event_with_text_content(self, pipeline, mock_message_cycle_manager, mock_task_state): + """Test handling of LLM chunk events with text content.""" + # Setup + chunk = Mock() + chunk.delta.message.content = "Hello, world!" + chunk.prompt_messages = [] + + llm_chunk_event = Mock(spec=QueueLLMChunkEvent) + llm_chunk_event.chunk = chunk + + mock_queue_message = Mock() + mock_queue_message.event = llm_chunk_event + pipeline.queue_manager.listen.return_value = [mock_queue_message] + + mock_message_cycle_manager.get_message_event_type.return_value = StreamEvent.MESSAGE + + # Execute + responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None)) + + # Assert + assert len(responses) == 1 + mock_message_cycle_manager.message_to_stream_response.assert_called_once_with( + answer="Hello, world!", message_id="test-message-id", event_type=StreamEvent.MESSAGE + ) + assert mock_task_state.llm_result.message.content == "Hello, world!" + + def test_llm_chunk_event_with_list_content(self, pipeline, mock_message_cycle_manager, mock_task_state): + """Test handling of LLM chunk events with list content.""" + # Setup + text_content = Mock(spec=TextPromptMessageContent) + text_content.data = "Hello" + + chunk = Mock() + chunk.delta.message.content = [text_content, " world!"] + chunk.prompt_messages = [] + + llm_chunk_event = Mock(spec=QueueLLMChunkEvent) + llm_chunk_event.chunk = chunk + + mock_queue_message = Mock() + mock_queue_message.event = llm_chunk_event + pipeline.queue_manager.listen.return_value = [mock_queue_message] + + mock_message_cycle_manager.get_message_event_type.return_value = StreamEvent.MESSAGE + + # Execute + responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None)) + + # Assert + assert len(responses) == 1 + mock_message_cycle_manager.message_to_stream_response.assert_called_once_with( + answer="Hello world!", message_id="test-message-id", event_type=StreamEvent.MESSAGE + ) + assert mock_task_state.llm_result.message.content == "Hello world!" + + def test_agent_message_event(self, pipeline, mock_message_cycle_manager, mock_task_state): + """Test handling of agent message events.""" + # Setup + chunk = Mock() + chunk.delta.message.content = "Agent response" + + agent_message_event = Mock(spec=QueueAgentMessageEvent) + agent_message_event.chunk = chunk + + mock_queue_message = Mock() + mock_queue_message.event = agent_message_event + pipeline.queue_manager.listen.return_value = [mock_queue_message] + + # Ensure method under assertion is a mock to track calls + pipeline._agent_message_to_stream_response = Mock(return_value=Mock()) + + # Execute + responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None)) + + # Assert + assert len(responses) == 1 + # Agent messages should use _agent_message_to_stream_response + pipeline._agent_message_to_stream_response.assert_called_once_with( + answer="Agent response", message_id="test-message-id" + ) + + def test_message_end_event(self, pipeline, mock_message_cycle_manager, mock_task_state): + """Test handling of message end events.""" + # Setup + llm_result = Mock(spec=RuntimeLLMResult) + llm_result.message = Mock() + llm_result.message.content = "Final response" + + message_end_event = Mock(spec=QueueMessageEndEvent) + message_end_event.llm_result = llm_result + + mock_queue_message = Mock() + mock_queue_message.event = message_end_event + pipeline.queue_manager.listen.return_value = [mock_queue_message] + + pipeline._save_message = Mock() + pipeline._message_end_to_stream_response = Mock(return_value=Mock(spec=MessageEndStreamResponse)) + + # Patch db.engine used inside pipeline for session creation + with patch( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", new=SimpleNamespace(engine=Mock()) + ): + # Execute + responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None)) + + # Assert + assert len(responses) == 1 + assert mock_task_state.llm_result == llm_result + pipeline._save_message.assert_called_once() + pipeline._message_end_to_stream_response.assert_called_once() + + def test_error_event(self, pipeline): + """Test handling of error events.""" + # Setup + error_event = Mock(spec=QueueErrorEvent) + error_event.error = Exception("Test error") + + mock_queue_message = Mock() + mock_queue_message.event = error_event + pipeline.queue_manager.listen.return_value = [mock_queue_message] + + pipeline.handle_error = Mock(return_value=Exception("Test error")) + pipeline.error_to_stream_response = Mock(return_value=Mock(spec=ErrorStreamResponse)) + + # Patch db.engine used inside pipeline for session creation + with patch( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", new=SimpleNamespace(engine=Mock()) + ): + # Execute + responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None)) + + # Assert + assert len(responses) == 1 + pipeline.handle_error.assert_called_once() + pipeline.error_to_stream_response.assert_called_once() + + def test_ping_event(self, pipeline): + """Test handling of ping events.""" + # Setup + ping_event = Mock(spec=QueuePingEvent) + + mock_queue_message = Mock() + mock_queue_message.event = ping_event + pipeline.queue_manager.listen.return_value = [mock_queue_message] + + pipeline.ping_stream_response = Mock(return_value=Mock(spec=PingStreamResponse)) + + # Execute + responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None)) + + # Assert + assert len(responses) == 1 + pipeline.ping_stream_response.assert_called_once() + + def test_file_event(self, pipeline, mock_message_cycle_manager): + """Test handling of file events.""" + # Setup + file_event = Mock(spec=QueueMessageFileEvent) + file_event.message_file_id = "file-id" + + mock_queue_message = Mock() + mock_queue_message.event = file_event + pipeline.queue_manager.listen.return_value = [mock_queue_message] + + file_response = Mock(spec=MessageFileStreamResponse) + mock_message_cycle_manager.message_file_to_stream_response.return_value = file_response + + # Execute + responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None)) + + # Assert + assert len(responses) == 1 + assert responses[0] == file_response + mock_message_cycle_manager.message_file_to_stream_response.assert_called_once_with(file_event) + + def test_publisher_is_called_with_messages(self, pipeline): + """Test that publisher publishes messages when provided.""" + # Setup + publisher = Mock(spec=AppGeneratorTTSPublisher) + + ping_event = Mock(spec=QueuePingEvent) + mock_queue_message = Mock() + mock_queue_message.event = ping_event + pipeline.queue_manager.listen.return_value = [mock_queue_message] + + pipeline.ping_stream_response = Mock(return_value=Mock(spec=PingStreamResponse)) + + # Execute + list(pipeline._process_stream_response(publisher=publisher, trace_manager=None)) + + # Assert + # Called once with message and once with None at the end + assert publisher.publish.call_count == 2 + publisher.publish.assert_any_call(mock_queue_message) + publisher.publish.assert_any_call(None) + + def test_trace_manager_passed_to_save_message(self, pipeline): + """Test that trace manager is passed to _save_message.""" + # Setup + trace_manager = Mock(spec=TraceQueueManager) + + message_end_event = Mock(spec=QueueMessageEndEvent) + message_end_event.llm_result = None + + mock_queue_message = Mock() + mock_queue_message.event = message_end_event + pipeline.queue_manager.listen.return_value = [mock_queue_message] + + pipeline._save_message = Mock() + pipeline._message_end_to_stream_response = Mock(return_value=Mock(spec=MessageEndStreamResponse)) + + # Patch db.engine used inside pipeline for session creation + with patch( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", new=SimpleNamespace(engine=Mock()) + ): + # Execute + list(pipeline._process_stream_response(publisher=None, trace_manager=trace_manager)) + + # Assert + pipeline._save_message.assert_called_once_with(session=ANY, trace_manager=trace_manager) + + def test_multiple_events_sequence(self, pipeline, mock_message_cycle_manager, mock_task_state): + """Test handling multiple events in sequence.""" + # Setup + chunk1 = Mock() + chunk1.delta.message.content = "Hello" + chunk1.prompt_messages = [] + + chunk2 = Mock() + chunk2.delta.message.content = " world!" + chunk2.prompt_messages = [] + + llm_chunk_event1 = Mock(spec=QueueLLMChunkEvent) + llm_chunk_event1.chunk = chunk1 + + ping_event = Mock(spec=QueuePingEvent) + + llm_chunk_event2 = Mock(spec=QueueLLMChunkEvent) + llm_chunk_event2.chunk = chunk2 + + mock_queue_messages = [ + Mock(event=llm_chunk_event1), + Mock(event=ping_event), + Mock(event=llm_chunk_event2), + ] + pipeline.queue_manager.listen.return_value = mock_queue_messages + + mock_message_cycle_manager.get_message_event_type.return_value = StreamEvent.MESSAGE + pipeline.ping_stream_response = Mock(return_value=Mock(spec=PingStreamResponse)) + + # Execute + responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None)) + + # Assert + assert len(responses) == 3 + assert mock_task_state.llm_result.message.content == "Hello world!" + + # Verify calls to message_to_stream_response + assert mock_message_cycle_manager.message_to_stream_response.call_count == 2 + mock_message_cycle_manager.message_to_stream_response.assert_any_call( + answer="Hello", message_id="test-message-id", event_type=StreamEvent.MESSAGE + ) + mock_message_cycle_manager.message_to_stream_response.assert_any_call( + answer=" world!", message_id="test-message-id", event_type=StreamEvent.MESSAGE + ) diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py b/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py new file mode 100644 index 0000000000..5ef7f0d7f4 --- /dev/null +++ b/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py @@ -0,0 +1,166 @@ +"""Unit tests for the message cycle manager optimization.""" + +from types import SimpleNamespace +from unittest.mock import ANY, Mock, patch + +import pytest +from flask import current_app + +from core.app.entities.task_entities import MessageStreamResponse, StreamEvent +from core.app.task_pipeline.message_cycle_manager import MessageCycleManager + + +class TestMessageCycleManagerOptimization: + """Test cases for the message cycle manager optimization that prevents N+1 queries.""" + + @pytest.fixture + def mock_application_generate_entity(self): + """Create a mock application generate entity.""" + entity = Mock() + entity.task_id = "test-task-id" + return entity + + @pytest.fixture + def message_cycle_manager(self, mock_application_generate_entity): + """Create a message cycle manager instance.""" + task_state = Mock() + return MessageCycleManager(application_generate_entity=mock_application_generate_entity, task_state=task_state) + + def test_get_message_event_type_with_message_file(self, message_cycle_manager): + """Test get_message_event_type returns MESSAGE_FILE when message has files.""" + with ( + patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class, + patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())), + ): + # Setup mock session and message file + mock_session = Mock() + mock_session_class.return_value.__enter__.return_value = mock_session + + mock_message_file = Mock() + # Current implementation uses session.query(...).scalar() + mock_session.query.return_value.scalar.return_value = mock_message_file + + # Execute + with current_app.app_context(): + result = message_cycle_manager.get_message_event_type("test-message-id") + + # Assert + assert result == StreamEvent.MESSAGE_FILE + mock_session.query.return_value.scalar.assert_called_once() + + def test_get_message_event_type_without_message_file(self, message_cycle_manager): + """Test get_message_event_type returns MESSAGE when message has no files.""" + with ( + patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class, + patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())), + ): + # Setup mock session and no message file + mock_session = Mock() + mock_session_class.return_value.__enter__.return_value = mock_session + # Current implementation uses session.query(...).scalar() + mock_session.query.return_value.scalar.return_value = None + + # Execute + with current_app.app_context(): + result = message_cycle_manager.get_message_event_type("test-message-id") + + # Assert + assert result == StreamEvent.MESSAGE + mock_session.query.return_value.scalar.assert_called_once() + + def test_message_to_stream_response_with_precomputed_event_type(self, message_cycle_manager): + """MessageCycleManager.message_to_stream_response expects a valid event_type; callers should precompute it.""" + with ( + patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class, + patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())), + ): + # Setup mock session and message file + mock_session = Mock() + mock_session_class.return_value.__enter__.return_value = mock_session + + mock_message_file = Mock() + # Current implementation uses session.query(...).scalar() + mock_session.query.return_value.scalar.return_value = mock_message_file + + # Execute: compute event type once, then pass to message_to_stream_response + with current_app.app_context(): + event_type = message_cycle_manager.get_message_event_type("test-message-id") + result = message_cycle_manager.message_to_stream_response( + answer="Hello world", message_id="test-message-id", event_type=event_type + ) + + # Assert + assert isinstance(result, MessageStreamResponse) + assert result.answer == "Hello world" + assert result.id == "test-message-id" + assert result.event == StreamEvent.MESSAGE_FILE + mock_session.query.return_value.scalar.assert_called_once() + + def test_message_to_stream_response_with_event_type_skips_query(self, message_cycle_manager): + """Test that message_to_stream_response skips database query when event_type is provided.""" + with patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class: + # Execute with event_type provided + result = message_cycle_manager.message_to_stream_response( + answer="Hello world", message_id="test-message-id", event_type=StreamEvent.MESSAGE + ) + + # Assert + assert isinstance(result, MessageStreamResponse) + assert result.answer == "Hello world" + assert result.id == "test-message-id" + assert result.event == StreamEvent.MESSAGE + # Should not query database when event_type is provided + mock_session_class.assert_not_called() + + def test_message_to_stream_response_with_from_variable_selector(self, message_cycle_manager): + """Test message_to_stream_response with from_variable_selector parameter.""" + result = message_cycle_manager.message_to_stream_response( + answer="Hello world", + message_id="test-message-id", + from_variable_selector=["var1", "var2"], + event_type=StreamEvent.MESSAGE, + ) + + assert isinstance(result, MessageStreamResponse) + assert result.answer == "Hello world" + assert result.id == "test-message-id" + assert result.from_variable_selector == ["var1", "var2"] + assert result.event == StreamEvent.MESSAGE + + def test_optimization_usage_example(self, message_cycle_manager): + """Test the optimization pattern that should be used by callers.""" + # Step 1: Get event type once (this queries database) + with ( + patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class, + patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())), + ): + mock_session = Mock() + mock_session_class.return_value.__enter__.return_value = mock_session + # Current implementation uses session.query(...).scalar() + mock_session.query.return_value.scalar.return_value = None # No files + with current_app.app_context(): + event_type = message_cycle_manager.get_message_event_type("test-message-id") + + # Should query database once + mock_session_class.assert_called_once_with(ANY, expire_on_commit=False) + assert event_type == StreamEvent.MESSAGE + + # Step 2: Use event_type for multiple calls (no additional queries) + with patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class: + mock_session_class.return_value.__enter__.return_value = Mock() + + chunk1_response = message_cycle_manager.message_to_stream_response( + answer="Chunk 1", message_id="test-message-id", event_type=event_type + ) + + chunk2_response = message_cycle_manager.message_to_stream_response( + answer="Chunk 2", message_id="test-message-id", event_type=event_type + ) + + # Should not query database again + mock_session_class.assert_not_called() + + assert chunk1_response.event == StreamEvent.MESSAGE + assert chunk2_response.event == StreamEvent.MESSAGE + assert chunk1_response.answer == "Chunk 1" + assert chunk2_response.answer == "Chunk 2" diff --git a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py index fd0b0e2e44..3203aab8c3 100644 --- a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py +++ b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py @@ -132,3 +132,36 @@ def test_extract_images_from_docx(monkeypatch): # DB interactions should be recorded assert len(db_stub.session.added) == 2 assert db_stub.session.committed is True + + +def test_extract_images_from_docx_uses_internal_files_url(): + """Test that INTERNAL_FILES_URL takes precedence over FILES_URL for plugin access.""" + # Test the URL generation logic directly + from configs import dify_config + + # Mock the configuration values + original_files_url = getattr(dify_config, "FILES_URL", None) + original_internal_files_url = getattr(dify_config, "INTERNAL_FILES_URL", None) + + try: + # Set both URLs - INTERNAL should take precedence + dify_config.FILES_URL = "http://external.example.com" + dify_config.INTERNAL_FILES_URL = "http://internal.docker:5001" + + # Test the URL generation logic (same as in word_extractor.py) + upload_file_id = "test_file_id" + + # This is the pattern we fixed in the word extractor + base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL + generated_url = f"{base_url}/files/{upload_file_id}/file-preview" + + # Verify that INTERNAL_FILES_URL is used instead of FILES_URL + assert "http://internal.docker:5001" in generated_url, f"Expected internal URL, got: {generated_url}" + assert "http://external.example.com" not in generated_url, f"Should not use external URL, got: {generated_url}" + + finally: + # Restore original values + if original_files_url is not None: + dify_config.FILES_URL = original_files_url + if original_internal_files_url is not None: + dify_config.INTERNAL_FILES_URL = original_internal_files_url diff --git a/api/tests/unit_tests/core/rag/splitter/test_text_splitter.py b/api/tests/unit_tests/core/rag/splitter/test_text_splitter.py index 7d246ac3cc..943a9e5712 100644 --- a/api/tests/unit_tests/core/rag/splitter/test_text_splitter.py +++ b/api/tests/unit_tests/core/rag/splitter/test_text_splitter.py @@ -901,6 +901,13 @@ class TestFixedRecursiveCharacterTextSplitter: # Verify no empty chunks assert all(len(chunk) > 0 for chunk in result) + def test_double_slash_n(self): + data = "chunk 1\n\nsubchunk 1.\nsubchunk 2.\n\n---\n\nchunk 2\n\nsubchunk 1\nsubchunk 2." + separator = "\\n\\n---\\n\\n" + splitter = FixedRecursiveCharacterTextSplitter(fixed_separator=separator) + chunks = splitter.split_text(data) + assert chunks == ["chunk 1\n\nsubchunk 1.\nsubchunk 2.", "chunk 2\n\nsubchunk 1\nsubchunk 2."] + # ============================================================================ # Test Metadata Preservation diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py index f040a92b6f..27df938102 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py @@ -1,3 +1,5 @@ +import pytest + from core.workflow.nodes.http_request import ( BodyData, HttpRequestNodeAuthorization, @@ -5,6 +7,7 @@ from core.workflow.nodes.http_request import ( HttpRequestNodeData, ) from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout +from core.workflow.nodes.http_request.exc import AuthorizationConfigError from core.workflow.nodes.http_request.executor import Executor from core.workflow.runtime import VariablePool from core.workflow.system_variable import SystemVariable @@ -348,3 +351,127 @@ def test_init_params(): executor = create_executor("key1:value1\n\nkey2:value2\n\n") executor._init_params() assert executor.params == [("key1", "value1"), ("key2", "value2")] + + +def test_empty_api_key_raises_error_bearer(): + """Test that empty API key raises AuthorizationConfigError for bearer auth.""" + variable_pool = VariablePool(system_variables=SystemVariable.empty()) + node_data = HttpRequestNodeData( + title="test", + method="get", + url="http://example.com", + headers="", + params="", + authorization=HttpRequestNodeAuthorization( + type="api-key", + config={"type": "bearer", "api_key": ""}, + ), + ) + timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30) + + with pytest.raises(AuthorizationConfigError, match="API key is required"): + Executor( + node_data=node_data, + timeout=timeout, + variable_pool=variable_pool, + ) + + +def test_empty_api_key_raises_error_basic(): + """Test that empty API key raises AuthorizationConfigError for basic auth.""" + variable_pool = VariablePool(system_variables=SystemVariable.empty()) + node_data = HttpRequestNodeData( + title="test", + method="get", + url="http://example.com", + headers="", + params="", + authorization=HttpRequestNodeAuthorization( + type="api-key", + config={"type": "basic", "api_key": ""}, + ), + ) + timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30) + + with pytest.raises(AuthorizationConfigError, match="API key is required"): + Executor( + node_data=node_data, + timeout=timeout, + variable_pool=variable_pool, + ) + + +def test_empty_api_key_raises_error_custom(): + """Test that empty API key raises AuthorizationConfigError for custom auth.""" + variable_pool = VariablePool(system_variables=SystemVariable.empty()) + node_data = HttpRequestNodeData( + title="test", + method="get", + url="http://example.com", + headers="", + params="", + authorization=HttpRequestNodeAuthorization( + type="api-key", + config={"type": "custom", "api_key": "", "header": "X-Custom-Auth"}, + ), + ) + timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30) + + with pytest.raises(AuthorizationConfigError, match="API key is required"): + Executor( + node_data=node_data, + timeout=timeout, + variable_pool=variable_pool, + ) + + +def test_whitespace_only_api_key_raises_error(): + """Test that whitespace-only API key raises AuthorizationConfigError.""" + variable_pool = VariablePool(system_variables=SystemVariable.empty()) + node_data = HttpRequestNodeData( + title="test", + method="get", + url="http://example.com", + headers="", + params="", + authorization=HttpRequestNodeAuthorization( + type="api-key", + config={"type": "bearer", "api_key": " "}, + ), + ) + timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30) + + with pytest.raises(AuthorizationConfigError, match="API key is required"): + Executor( + node_data=node_data, + timeout=timeout, + variable_pool=variable_pool, + ) + + +def test_valid_api_key_works(): + """Test that valid API key works correctly for bearer auth.""" + variable_pool = VariablePool(system_variables=SystemVariable.empty()) + node_data = HttpRequestNodeData( + title="test", + method="get", + url="http://example.com", + headers="", + params="", + authorization=HttpRequestNodeAuthorization( + type="api-key", + config={"type": "bearer", "api_key": "valid-api-key-123"}, + ), + ) + timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30) + + executor = Executor( + node_data=node_data, + timeout=timeout, + variable_pool=variable_pool, + ) + + # Should not raise an error + headers = executor._assembling_headers() + assert "Authorization" in headers + assert headers["Authorization"] == "Bearer valid-api-key-123" diff --git a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py index 83799c9508..539e72edb5 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py @@ -1,3 +1,4 @@ +import json import time import pytest @@ -46,14 +47,16 @@ def make_start_node(user_inputs, variables): def test_json_object_valid_schema(): - schema = { - "type": "object", - "properties": { - "age": {"type": "number"}, - "name": {"type": "string"}, - }, - "required": ["age"], - } + schema = json.dumps( + { + "type": "object", + "properties": { + "age": {"type": "number"}, + "name": {"type": "string"}, + }, + "required": ["age"], + } + ) variables = [ VariableEntity( @@ -65,7 +68,7 @@ def test_json_object_valid_schema(): ) ] - user_inputs = {"profile": {"age": 20, "name": "Tom"}} + user_inputs = {"profile": json.dumps({"age": 20, "name": "Tom"})} node = make_start_node(user_inputs, variables) result = node._run() @@ -74,12 +77,23 @@ def test_json_object_valid_schema(): def test_json_object_invalid_json_string(): + schema = json.dumps( + { + "type": "object", + "properties": { + "age": {"type": "number"}, + "name": {"type": "string"}, + }, + "required": ["age", "name"], + } + ) variables = [ VariableEntity( variable="profile", label="profile", type=VariableEntityType.JSON_OBJECT, required=True, + json_schema=schema, ) ] @@ -88,38 +102,21 @@ def test_json_object_invalid_json_string(): node = make_start_node(user_inputs, variables) - with pytest.raises(ValueError, match="profile must be a JSON object"): - node._run() - - -@pytest.mark.parametrize("value", ["[1, 2, 3]", "123"]) -def test_json_object_valid_json_but_not_object(value): - variables = [ - VariableEntity( - variable="profile", - label="profile", - type=VariableEntityType.JSON_OBJECT, - required=True, - ) - ] - - user_inputs = {"profile": value} - - node = make_start_node(user_inputs, variables) - - with pytest.raises(ValueError, match="profile must be a JSON object"): + with pytest.raises(ValueError, match='{"age": 20, "name": "Tom" must be a valid JSON object'): node._run() def test_json_object_does_not_match_schema(): - schema = { - "type": "object", - "properties": { - "age": {"type": "number"}, - "name": {"type": "string"}, - }, - "required": ["age", "name"], - } + schema = json.dumps( + { + "type": "object", + "properties": { + "age": {"type": "number"}, + "name": {"type": "string"}, + }, + "required": ["age", "name"], + } + ) variables = [ VariableEntity( @@ -132,7 +129,7 @@ def test_json_object_does_not_match_schema(): ] # age is a string, which violates the schema (expects number) - user_inputs = {"profile": {"age": "twenty", "name": "Tom"}} + user_inputs = {"profile": json.dumps({"age": "twenty", "name": "Tom"})} node = make_start_node(user_inputs, variables) @@ -141,14 +138,16 @@ def test_json_object_does_not_match_schema(): def test_json_object_missing_required_schema_field(): - schema = { - "type": "object", - "properties": { - "age": {"type": "number"}, - "name": {"type": "string"}, - }, - "required": ["age", "name"], - } + schema = json.dumps( + { + "type": "object", + "properties": { + "age": {"type": "number"}, + "name": {"type": "string"}, + }, + "required": ["age", "name"], + } + ) variables = [ VariableEntity( @@ -161,7 +160,7 @@ def test_json_object_missing_required_schema_field(): ] # Missing required field "name" - user_inputs = {"profile": {"age": 20}} + user_inputs = {"profile": json.dumps({"age": 20})} node = make_start_node(user_inputs, variables) @@ -214,7 +213,7 @@ def test_json_object_optional_variable_not_provided(): variable="profile", label="profile", type=VariableEntityType.JSON_OBJECT, - required=False, + required=True, ) ] @@ -223,5 +222,5 @@ def test_json_object_optional_variable_not_provided(): node = make_start_node(user_inputs, variables) # Current implementation raises a validation error even when the variable is optional - with pytest.raises(ValueError, match="profile must be a JSON object"): + with pytest.raises(ValueError, match="profile is required in input form"): node._run() diff --git a/api/tests/unit_tests/services/test_billing_service.py b/api/tests/unit_tests/services/test_billing_service.py index 915aee3fa7..f50f744a75 100644 --- a/api/tests/unit_tests/services/test_billing_service.py +++ b/api/tests/unit_tests/services/test_billing_service.py @@ -1156,6 +1156,199 @@ class TestBillingServiceEdgeCases: assert "Only team owner or team admin can perform this action" in str(exc_info.value) +class TestBillingServiceSubscriptionOperations: + """Unit tests for subscription operations in BillingService. + + Tests cover: + - Bulk plan retrieval with chunking + - Expired subscription cleanup whitelist retrieval + """ + + @pytest.fixture + def mock_send_request(self): + """Mock _send_request method.""" + with patch.object(BillingService, "_send_request") as mock: + yield mock + + def test_get_plan_bulk_with_empty_list(self, mock_send_request): + """Test bulk plan retrieval with empty tenant list.""" + # Arrange + tenant_ids = [] + + # Act + result = BillingService.get_plan_bulk(tenant_ids) + + # Assert + assert result == {} + mock_send_request.assert_not_called() + + def test_get_plan_bulk_with_chunking(self, mock_send_request): + """Test bulk plan retrieval with more than 200 tenants (chunking logic).""" + # Arrange - 250 tenants to test chunking (chunk_size = 200) + tenant_ids = [f"tenant-{i}" for i in range(250)] + + # First chunk: tenants 0-199 + first_chunk_response = { + "data": {f"tenant-{i}": {"plan": "sandbox", "expiration_date": 1735689600} for i in range(200)} + } + + # Second chunk: tenants 200-249 + second_chunk_response = { + "data": {f"tenant-{i}": {"plan": "professional", "expiration_date": 1767225600} for i in range(200, 250)} + } + + mock_send_request.side_effect = [first_chunk_response, second_chunk_response] + + # Act + result = BillingService.get_plan_bulk(tenant_ids) + + # Assert + assert len(result) == 250 + assert result["tenant-0"]["plan"] == "sandbox" + assert result["tenant-199"]["plan"] == "sandbox" + assert result["tenant-200"]["plan"] == "professional" + assert result["tenant-249"]["plan"] == "professional" + assert mock_send_request.call_count == 2 + + # Verify first chunk call + first_call = mock_send_request.call_args_list[0] + assert first_call[0][0] == "POST" + assert first_call[0][1] == "/subscription/plan/batch" + assert len(first_call[1]["json"]["tenant_ids"]) == 200 + + # Verify second chunk call + second_call = mock_send_request.call_args_list[1] + assert len(second_call[1]["json"]["tenant_ids"]) == 50 + + def test_get_plan_bulk_with_partial_batch_failure(self, mock_send_request): + """Test bulk plan retrieval when one batch fails but others succeed.""" + # Arrange - 250 tenants, second batch will fail + tenant_ids = [f"tenant-{i}" for i in range(250)] + + # First chunk succeeds + first_chunk_response = { + "data": {f"tenant-{i}": {"plan": "sandbox", "expiration_date": 1735689600} for i in range(200)} + } + + # Second chunk fails - need to create a mock that raises when called + def side_effect_func(*args, **kwargs): + if mock_send_request.call_count == 1: + return first_chunk_response + else: + raise ValueError("API error") + + mock_send_request.side_effect = side_effect_func + + # Act + result = BillingService.get_plan_bulk(tenant_ids) + + # Assert - should only have data from first batch + assert len(result) == 200 + assert result["tenant-0"]["plan"] == "sandbox" + assert result["tenant-199"]["plan"] == "sandbox" + assert "tenant-200" not in result + assert mock_send_request.call_count == 2 + + def test_get_plan_bulk_with_all_batches_failing(self, mock_send_request): + """Test bulk plan retrieval when all batches fail.""" + # Arrange + tenant_ids = [f"tenant-{i}" for i in range(250)] + + # All chunks fail + def side_effect_func(*args, **kwargs): + raise ValueError("API error") + + mock_send_request.side_effect = side_effect_func + + # Act + result = BillingService.get_plan_bulk(tenant_ids) + + # Assert - should return empty dict + assert result == {} + assert mock_send_request.call_count == 2 + + def test_get_plan_bulk_with_exactly_200_tenants(self, mock_send_request): + """Test bulk plan retrieval with exactly 200 tenants (boundary condition).""" + # Arrange + tenant_ids = [f"tenant-{i}" for i in range(200)] + mock_send_request.return_value = { + "data": {f"tenant-{i}": {"plan": "sandbox", "expiration_date": 1735689600} for i in range(200)} + } + + # Act + result = BillingService.get_plan_bulk(tenant_ids) + + # Assert + assert len(result) == 200 + assert mock_send_request.call_count == 1 + + def test_get_plan_bulk_with_empty_data_response(self, mock_send_request): + """Test bulk plan retrieval with empty data in response.""" + # Arrange + tenant_ids = ["tenant-1", "tenant-2"] + mock_send_request.return_value = {"data": {}} + + # Act + result = BillingService.get_plan_bulk(tenant_ids) + + # Assert + assert result == {} + + def test_get_expired_subscription_cleanup_whitelist_success(self, mock_send_request): + """Test successful retrieval of expired subscription cleanup whitelist.""" + # Arrange + api_response = [ + { + "created_at": "2025-10-16T01:56:17", + "tenant_id": "36bd55ec-2ea9-4d75-a9ea-1f26aeb4ffe6", + "contact": "example@dify.ai", + "id": "36bd55ec-2ea9-4d75-a9ea-1f26aeb4ffe5", + "expired_at": "2026-01-01T01:56:17", + "updated_at": "2025-10-16T01:56:17", + }, + { + "created_at": "2025-10-16T02:00:00", + "tenant_id": "tenant-2", + "contact": "test@example.com", + "id": "whitelist-id-2", + "expired_at": "2026-02-01T00:00:00", + "updated_at": "2025-10-16T02:00:00", + }, + { + "created_at": "2025-10-16T03:00:00", + "tenant_id": "tenant-3", + "contact": "another@example.com", + "id": "whitelist-id-3", + "expired_at": "2026-03-01T00:00:00", + "updated_at": "2025-10-16T03:00:00", + }, + ] + mock_send_request.return_value = {"data": api_response} + + # Act + result = BillingService.get_expired_subscription_cleanup_whitelist() + + # Assert - should return only tenant_ids + assert result == ["36bd55ec-2ea9-4d75-a9ea-1f26aeb4ffe6", "tenant-2", "tenant-3"] + assert len(result) == 3 + assert result[0] == "36bd55ec-2ea9-4d75-a9ea-1f26aeb4ffe6" + assert result[1] == "tenant-2" + assert result[2] == "tenant-3" + mock_send_request.assert_called_once_with("GET", "/subscription/cleanup/whitelist") + + def test_get_expired_subscription_cleanup_whitelist_empty_list(self, mock_send_request): + """Test retrieval of empty cleanup whitelist.""" + # Arrange + mock_send_request.return_value = {"data": []} + + # Act + result = BillingService.get_expired_subscription_cleanup_whitelist() + + # Assert + assert result == [] + assert len(result) == 0 + + class TestBillingServiceIntegrationScenarios: """Integration-style tests simulating real-world usage scenarios. diff --git a/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py b/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py new file mode 100644 index 0000000000..374abe0368 --- /dev/null +++ b/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py @@ -0,0 +1,520 @@ +""" +Unit tests for document indexing sync task. + +This module tests the document indexing sync task functionality including: +- Syncing Notion documents when updated +- Validating document and data source existence +- Credential validation and retrieval +- Cleaning old segments before re-indexing +- Error handling and edge cases +""" + +import uuid +from unittest.mock import MagicMock, Mock, patch + +import pytest + +from core.indexing_runner import DocumentIsPausedError, IndexingRunner +from models.dataset import Dataset, Document, DocumentSegment +from tasks.document_indexing_sync_task import document_indexing_sync_task + +# ============================================================================ +# Fixtures +# ============================================================================ + + +@pytest.fixture +def tenant_id(): + """Generate a unique tenant ID for testing.""" + return str(uuid.uuid4()) + + +@pytest.fixture +def dataset_id(): + """Generate a unique dataset ID for testing.""" + return str(uuid.uuid4()) + + +@pytest.fixture +def document_id(): + """Generate a unique document ID for testing.""" + return str(uuid.uuid4()) + + +@pytest.fixture +def notion_workspace_id(): + """Generate a Notion workspace ID for testing.""" + return str(uuid.uuid4()) + + +@pytest.fixture +def notion_page_id(): + """Generate a Notion page ID for testing.""" + return str(uuid.uuid4()) + + +@pytest.fixture +def credential_id(): + """Generate a credential ID for testing.""" + return str(uuid.uuid4()) + + +@pytest.fixture +def mock_dataset(dataset_id, tenant_id): + """Create a mock Dataset object.""" + dataset = Mock(spec=Dataset) + dataset.id = dataset_id + dataset.tenant_id = tenant_id + dataset.indexing_technique = "high_quality" + dataset.embedding_model_provider = "openai" + dataset.embedding_model = "text-embedding-ada-002" + return dataset + + +@pytest.fixture +def mock_document(document_id, dataset_id, tenant_id, notion_workspace_id, notion_page_id, credential_id): + """Create a mock Document object with Notion data source.""" + doc = Mock(spec=Document) + doc.id = document_id + doc.dataset_id = dataset_id + doc.tenant_id = tenant_id + doc.data_source_type = "notion_import" + doc.indexing_status = "completed" + doc.error = None + doc.stopped_at = None + doc.processing_started_at = None + doc.doc_form = "text_model" + doc.data_source_info_dict = { + "notion_workspace_id": notion_workspace_id, + "notion_page_id": notion_page_id, + "type": "page", + "last_edited_time": "2024-01-01T00:00:00Z", + "credential_id": credential_id, + } + return doc + + +@pytest.fixture +def mock_document_segments(document_id): + """Create mock DocumentSegment objects.""" + segments = [] + for i in range(3): + segment = Mock(spec=DocumentSegment) + segment.id = str(uuid.uuid4()) + segment.document_id = document_id + segment.index_node_id = f"node-{document_id}-{i}" + segments.append(segment) + return segments + + +@pytest.fixture +def mock_db_session(): + """Mock database session.""" + with patch("tasks.document_indexing_sync_task.db.session") as mock_session: + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_session.scalars.return_value = MagicMock() + yield mock_session + + +@pytest.fixture +def mock_datasource_provider_service(): + """Mock DatasourceProviderService.""" + with patch("tasks.document_indexing_sync_task.DatasourceProviderService") as mock_service_class: + mock_service = MagicMock() + mock_service.get_datasource_credentials.return_value = {"integration_secret": "test_token"} + mock_service_class.return_value = mock_service + yield mock_service + + +@pytest.fixture +def mock_notion_extractor(): + """Mock NotionExtractor.""" + with patch("tasks.document_indexing_sync_task.NotionExtractor") as mock_extractor_class: + mock_extractor = MagicMock() + mock_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" # Updated time + mock_extractor_class.return_value = mock_extractor + yield mock_extractor + + +@pytest.fixture +def mock_index_processor_factory(): + """Mock IndexProcessorFactory.""" + with patch("tasks.document_indexing_sync_task.IndexProcessorFactory") as mock_factory: + mock_processor = MagicMock() + mock_processor.clean = Mock() + mock_factory.return_value.init_index_processor.return_value = mock_processor + yield mock_factory + + +@pytest.fixture +def mock_indexing_runner(): + """Mock IndexingRunner.""" + with patch("tasks.document_indexing_sync_task.IndexingRunner") as mock_runner_class: + mock_runner = MagicMock(spec=IndexingRunner) + mock_runner.run = Mock() + mock_runner_class.return_value = mock_runner + yield mock_runner + + +# ============================================================================ +# Tests for document_indexing_sync_task +# ============================================================================ + + +class TestDocumentIndexingSyncTask: + """Tests for the document_indexing_sync_task function.""" + + def test_document_not_found(self, mock_db_session, dataset_id, document_id): + """Test that task handles document not found gracefully.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.return_value = None + + # Act + document_indexing_sync_task(dataset_id, document_id) + + # Assert + mock_db_session.close.assert_called_once() + + def test_missing_notion_workspace_id(self, mock_db_session, mock_document, dataset_id, document_id): + """Test that task raises error when notion_workspace_id is missing.""" + # Arrange + mock_document.data_source_info_dict = {"notion_page_id": "page123", "type": "page"} + mock_db_session.query.return_value.where.return_value.first.return_value = mock_document + + # Act & Assert + with pytest.raises(ValueError, match="no notion page found"): + document_indexing_sync_task(dataset_id, document_id) + + def test_missing_notion_page_id(self, mock_db_session, mock_document, dataset_id, document_id): + """Test that task raises error when notion_page_id is missing.""" + # Arrange + mock_document.data_source_info_dict = {"notion_workspace_id": "ws123", "type": "page"} + mock_db_session.query.return_value.where.return_value.first.return_value = mock_document + + # Act & Assert + with pytest.raises(ValueError, match="no notion page found"): + document_indexing_sync_task(dataset_id, document_id) + + def test_empty_data_source_info(self, mock_db_session, mock_document, dataset_id, document_id): + """Test that task raises error when data_source_info is empty.""" + # Arrange + mock_document.data_source_info_dict = None + mock_db_session.query.return_value.where.return_value.first.return_value = mock_document + + # Act & Assert + with pytest.raises(ValueError, match="no notion page found"): + document_indexing_sync_task(dataset_id, document_id) + + def test_credential_not_found( + self, + mock_db_session, + mock_datasource_provider_service, + mock_document, + dataset_id, + document_id, + ): + """Test that task handles missing credentials by updating document status.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.return_value = mock_document + mock_datasource_provider_service.get_datasource_credentials.return_value = None + + # Act + document_indexing_sync_task(dataset_id, document_id) + + # Assert + assert mock_document.indexing_status == "error" + assert "Datasource credential not found" in mock_document.error + assert mock_document.stopped_at is not None + mock_db_session.commit.assert_called() + mock_db_session.close.assert_called() + + def test_page_not_updated( + self, + mock_db_session, + mock_datasource_provider_service, + mock_notion_extractor, + mock_document, + dataset_id, + document_id, + ): + """Test that task does nothing when page has not been updated.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.return_value = mock_document + # Return same time as stored in document + mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z" + + # Act + document_indexing_sync_task(dataset_id, document_id) + + # Assert + # Document status should remain unchanged + assert mock_document.indexing_status == "completed" + # No session operations should be performed beyond the initial query + mock_db_session.close.assert_not_called() + + def test_successful_sync_when_page_updated( + self, + mock_db_session, + mock_datasource_provider_service, + mock_notion_extractor, + mock_index_processor_factory, + mock_indexing_runner, + mock_dataset, + mock_document, + mock_document_segments, + dataset_id, + document_id, + ): + """Test successful sync flow when Notion page has been updated.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset] + mock_db_session.scalars.return_value.all.return_value = mock_document_segments + # NotionExtractor returns updated time + mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" + + # Act + document_indexing_sync_task(dataset_id, document_id) + + # Assert + # Verify document status was updated to parsing + assert mock_document.indexing_status == "parsing" + assert mock_document.processing_started_at is not None + + # Verify segments were cleaned + mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value + mock_processor.clean.assert_called_once() + + # Verify segments were deleted from database + for segment in mock_document_segments: + mock_db_session.delete.assert_any_call(segment) + + # Verify indexing runner was called + mock_indexing_runner.run.assert_called_once_with([mock_document]) + + # Verify session operations + assert mock_db_session.commit.called + mock_db_session.close.assert_called_once() + + def test_dataset_not_found_during_cleaning( + self, + mock_db_session, + mock_datasource_provider_service, + mock_notion_extractor, + mock_document, + dataset_id, + document_id, + ): + """Test that task handles dataset not found during cleaning phase.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, None] + mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" + + # Act + document_indexing_sync_task(dataset_id, document_id) + + # Assert + # Document should still be set to parsing + assert mock_document.indexing_status == "parsing" + # Session should be closed after error + mock_db_session.close.assert_called_once() + + def test_cleaning_error_continues_to_indexing( + self, + mock_db_session, + mock_datasource_provider_service, + mock_notion_extractor, + mock_index_processor_factory, + mock_indexing_runner, + mock_dataset, + mock_document, + dataset_id, + document_id, + ): + """Test that indexing continues even if cleaning fails.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset] + mock_db_session.scalars.return_value.all.side_effect = Exception("Cleaning error") + mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" + + # Act + document_indexing_sync_task(dataset_id, document_id) + + # Assert + # Indexing should still be attempted despite cleaning error + mock_indexing_runner.run.assert_called_once_with([mock_document]) + mock_db_session.close.assert_called_once() + + def test_indexing_runner_document_paused_error( + self, + mock_db_session, + mock_datasource_provider_service, + mock_notion_extractor, + mock_index_processor_factory, + mock_indexing_runner, + mock_dataset, + mock_document, + mock_document_segments, + dataset_id, + document_id, + ): + """Test that DocumentIsPausedError is handled gracefully.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset] + mock_db_session.scalars.return_value.all.return_value = mock_document_segments + mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" + mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document paused") + + # Act + document_indexing_sync_task(dataset_id, document_id) + + # Assert + # Session should be closed after handling error + mock_db_session.close.assert_called_once() + + def test_indexing_runner_general_error( + self, + mock_db_session, + mock_datasource_provider_service, + mock_notion_extractor, + mock_index_processor_factory, + mock_indexing_runner, + mock_dataset, + mock_document, + mock_document_segments, + dataset_id, + document_id, + ): + """Test that general exceptions during indexing are handled.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset] + mock_db_session.scalars.return_value.all.return_value = mock_document_segments + mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" + mock_indexing_runner.run.side_effect = Exception("Indexing error") + + # Act + document_indexing_sync_task(dataset_id, document_id) + + # Assert + # Session should be closed after error + mock_db_session.close.assert_called_once() + + def test_notion_extractor_initialized_with_correct_params( + self, + mock_db_session, + mock_datasource_provider_service, + mock_notion_extractor, + mock_document, + dataset_id, + document_id, + notion_workspace_id, + notion_page_id, + ): + """Test that NotionExtractor is initialized with correct parameters.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.return_value = mock_document + mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z" # No update + + # Act + with patch("tasks.document_indexing_sync_task.NotionExtractor") as mock_extractor_class: + mock_extractor = MagicMock() + mock_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z" + mock_extractor_class.return_value = mock_extractor + + document_indexing_sync_task(dataset_id, document_id) + + # Assert + mock_extractor_class.assert_called_once_with( + notion_workspace_id=notion_workspace_id, + notion_obj_id=notion_page_id, + notion_page_type="page", + notion_access_token="test_token", + tenant_id=mock_document.tenant_id, + ) + + def test_datasource_credentials_requested_correctly( + self, + mock_db_session, + mock_datasource_provider_service, + mock_notion_extractor, + mock_document, + dataset_id, + document_id, + credential_id, + ): + """Test that datasource credentials are requested with correct parameters.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.return_value = mock_document + mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z" + + # Act + document_indexing_sync_task(dataset_id, document_id) + + # Assert + mock_datasource_provider_service.get_datasource_credentials.assert_called_once_with( + tenant_id=mock_document.tenant_id, + credential_id=credential_id, + provider="notion_datasource", + plugin_id="langgenius/notion_datasource", + ) + + def test_credential_id_missing_uses_none( + self, + mock_db_session, + mock_datasource_provider_service, + mock_notion_extractor, + mock_document, + dataset_id, + document_id, + ): + """Test that task handles missing credential_id by passing None.""" + # Arrange + mock_document.data_source_info_dict = { + "notion_workspace_id": "ws123", + "notion_page_id": "page123", + "type": "page", + "last_edited_time": "2024-01-01T00:00:00Z", + } + mock_db_session.query.return_value.where.return_value.first.return_value = mock_document + mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z" + + # Act + document_indexing_sync_task(dataset_id, document_id) + + # Assert + mock_datasource_provider_service.get_datasource_credentials.assert_called_once_with( + tenant_id=mock_document.tenant_id, + credential_id=None, + provider="notion_datasource", + plugin_id="langgenius/notion_datasource", + ) + + def test_index_processor_clean_called_with_correct_params( + self, + mock_db_session, + mock_datasource_provider_service, + mock_notion_extractor, + mock_index_processor_factory, + mock_indexing_runner, + mock_dataset, + mock_document, + mock_document_segments, + dataset_id, + document_id, + ): + """Test that index processor clean is called with correct parameters.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset] + mock_db_session.scalars.return_value.all.return_value = mock_document_segments + mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" + + # Act + document_indexing_sync_task(dataset_id, document_id) + + # Assert + mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value + expected_node_ids = [seg.index_node_id for seg in mock_document_segments] + mock_processor.clean.assert_called_once_with( + mock_dataset, expected_node_ids, with_keywords=True, delete_child_chunks=True + ) diff --git a/api/uv.lock b/api/uv.lock index 726abf6920..8d0dffbd8f 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1380,7 +1380,7 @@ dependencies = [ { name = "bs4" }, { name = "cachetools" }, { name = "celery" }, - { name = "chardet" }, + { name = "charset-normalizer" }, { name = "croniter" }, { name = "flask" }, { name = "flask-compress" }, @@ -1403,6 +1403,7 @@ dependencies = [ { name = "httpx-sse" }, { name = "jieba" }, { name = "json-repair" }, + { name = "jsonschema" }, { name = "langfuse" }, { name = "langsmith" }, { name = "litellm" }, @@ -1577,7 +1578,7 @@ requires-dist = [ { name = "bs4", specifier = "~=0.0.1" }, { name = "cachetools", specifier = "~=5.3.0" }, { name = "celery", specifier = "~=5.5.2" }, - { name = "chardet", specifier = "~=5.1.0" }, + { name = "charset-normalizer", specifier = ">=3.4.4" }, { name = "croniter", specifier = ">=6.0.0" }, { name = "flask", specifier = "~=3.1.2" }, { name = "flask-compress", specifier = ">=1.17,<1.18" }, @@ -1600,6 +1601,7 @@ requires-dist = [ { name = "httpx-sse", specifier = "~=0.4.0" }, { name = "jieba", specifier = "==0.42.1" }, { name = "json-repair", specifier = ">=0.41.1" }, + { name = "jsonschema", specifier = ">=4.25.1" }, { name = "langfuse", specifier = "~=2.51.3" }, { name = "langsmith", specifier = "~=0.1.77" }, { name = "litellm", specifier = "==1.77.1" }, diff --git a/dev/start-worker b/dev/start-worker index a01da11d86..7876620188 100755 --- a/dev/start-worker +++ b/dev/start-worker @@ -37,6 +37,7 @@ show_help() { echo " pipeline - Standard pipeline tasks" echo " triggered_workflow_dispatcher - Trigger dispatcher tasks" echo " trigger_refresh_executor - Trigger refresh tasks" + echo " retention - Retention tasks" } # Parse command line arguments @@ -105,10 +106,10 @@ if [[ -z "${QUEUES}" ]]; then # Configure queues based on edition if [[ "${EDITION}" == "CLOUD" ]]; then # Cloud edition: separate queues for dataset and trigger tasks - QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor" + QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention" else # Community edition (SELF_HOSTED): dataset and workflow have separate queues - QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor" + QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention" fi echo "No queues specified, using edition-based defaults: ${QUEUES}" diff --git a/docker/.env.example b/docker/.env.example index dd0d083da3..e5cdb64dae 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -1369,7 +1369,10 @@ PLUGIN_STDIO_BUFFER_SIZE=1024 PLUGIN_STDIO_MAX_BUFFER_SIZE=5242880 PLUGIN_PYTHON_ENV_INIT_TIMEOUT=120 +# Plugin Daemon side timeout (configure to match the API side below) PLUGIN_MAX_EXECUTION_TIMEOUT=600 +# API side timeout (configure to match the Plugin Daemon side above) +PLUGIN_DAEMON_TIMEOUT=600.0 # PIP_MIRROR_URL=https://pypi.tuna.tsinghua.edu.cn/simple PIP_MIRROR_URL= @@ -1479,4 +1482,9 @@ ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR=20 ANNOTATION_IMPORT_MAX_CONCURRENT=5 # The API key of amplitude -AMPLITUDE_API_KEY= \ No newline at end of file +AMPLITUDE_API_KEY= + +# Sandbox expired records clean configuration +SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21 +SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000 +SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30 diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index 4f6194b9e4..a07ed9e8ad 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -34,6 +34,7 @@ services: PLUGIN_REMOTE_INSTALL_HOST: ${EXPOSE_PLUGIN_DEBUGGING_HOST:-localhost} PLUGIN_REMOTE_INSTALL_PORT: ${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003} PLUGIN_MAX_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800} + PLUGIN_DAEMON_TIMEOUT: ${PLUGIN_DAEMON_TIMEOUT:-600.0} INNER_API_KEY_FOR_PLUGIN: ${PLUGIN_DIFY_INNER_API_KEY:-QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1} depends_on: init_permissions: diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index aca4325880..24e1077ebe 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -591,6 +591,7 @@ x-shared-env: &shared-api-worker-env PLUGIN_STDIO_MAX_BUFFER_SIZE: ${PLUGIN_STDIO_MAX_BUFFER_SIZE:-5242880} PLUGIN_PYTHON_ENV_INIT_TIMEOUT: ${PLUGIN_PYTHON_ENV_INIT_TIMEOUT:-120} PLUGIN_MAX_EXECUTION_TIMEOUT: ${PLUGIN_MAX_EXECUTION_TIMEOUT:-600} + PLUGIN_DAEMON_TIMEOUT: ${PLUGIN_DAEMON_TIMEOUT:-600.0} PIP_MIRROR_URL: ${PIP_MIRROR_URL:-} PLUGIN_STORAGE_TYPE: ${PLUGIN_STORAGE_TYPE:-local} PLUGIN_STORAGE_LOCAL_ROOT: ${PLUGIN_STORAGE_LOCAL_ROOT:-/app/storage} @@ -663,6 +664,9 @@ x-shared-env: &shared-api-worker-env ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR: ${ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR:-20} ANNOTATION_IMPORT_MAX_CONCURRENT: ${ANNOTATION_IMPORT_MAX_CONCURRENT:-5} AMPLITUDE_API_KEY: ${AMPLITUDE_API_KEY:-} + SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD: ${SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD:-21} + SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE: ${SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE:-1000} + SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS: ${SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS:-30} services: # Init container to fix permissions @@ -699,6 +703,7 @@ services: PLUGIN_REMOTE_INSTALL_HOST: ${EXPOSE_PLUGIN_DEBUGGING_HOST:-localhost} PLUGIN_REMOTE_INSTALL_PORT: ${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003} PLUGIN_MAX_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800} + PLUGIN_DAEMON_TIMEOUT: ${PLUGIN_DAEMON_TIMEOUT:-600.0} INNER_API_KEY_FOR_PLUGIN: ${PLUGIN_DIFY_INNER_API_KEY:-QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1} depends_on: init_permissions: diff --git a/docs/fr-FR/README.md b/docs/fr-FR/README.md index 03f3221798..291c8dab40 100644 --- a/docs/fr-FR/README.md +++ b/docs/fr-FR/README.md @@ -61,14 +61,14 @@

langgenius%2Fdify | Trendshift

-Dify est une plateforme de développement d'applications LLM open source. Son interface intuitive combine un flux de travail d'IA, un pipeline RAG, des capacités d'agent, une gestion de modèles, des fonctionnalités d'observabilité, et plus encore, vous permettant de passer rapidement du prototype à la production. Voici une liste des fonctionnalités principales: +Dify est une plateforme de développement d'applications LLM open source. Sa interface intuitive combine un flux de travail d'IA, un pipeline RAG, des capacités d'agent, une gestion de modèles, des fonctionnalités d'observabilité, et plus encore, vous permettant de passer rapidement du prototype à la production. Voici une liste des fonctionnalités principales:

**1. Flux de travail** : Construisez et testez des flux de travail d'IA puissants sur un canevas visuel, en utilisant toutes les fonctionnalités suivantes et plus encore. **2. Prise en charge complète des modèles** : -Intégration transparente avec des centaines de LLM propriétaires / open source provenant de dizaines de fournisseurs d'inférence et de solutions auto-hébergées, couvrant GPT, Mistral, Llama3, et tous les modèles compatibles avec l'API OpenAI. Une liste complète des fournisseurs de modèles pris en charge se trouve [ici](https://docs.dify.ai/getting-started/readme/model-providers). +Intégration transparente avec des centaines de LLM propriétaires / open source offerts par dizaines de fournisseurs d'inférence et de solutions auto-hébergées, couvrant GPT, Mistral, Llama3, et tous les modèles compatibles avec l'API OpenAI. Une liste complète des fournisseurs de modèles pris en charge se trouve [ici](https://docs.dify.ai/getting-started/readme/model-providers). ![providers-v5](https://github.com/langgenius/dify/assets/13230914/5a17bdbe-097a-4100-8363-40255b70f6e3) @@ -79,7 +79,7 @@ Interface intuitive pour créer des prompts, comparer les performances des modè Des capacités RAG étendues qui couvrent tout, de l'ingestion de documents à la récupération, avec un support prêt à l'emploi pour l'extraction de texte à partir de PDF, PPT et autres formats de document courants. **5. Capacités d'agent** : -Vous pouvez définir des agents basés sur l'appel de fonction LLM ou ReAct, et ajouter des outils pré-construits ou personnalisés pour l'agent. Dify fournit plus de 50 outils intégrés pour les agents d'IA, tels que la recherche Google, DALL·E, Stable Diffusion et WolframAlpha. +Vous pouvez définir des agents basés sur l'appel de fonctions LLM ou ReAct, et ajouter des outils pré-construits ou personnalisés pour l'agent. Dify fournit plus de 50 outils intégrés pour les agents d'IA, tels que la recherche Google, DALL·E, Stable Diffusion et WolframAlpha. **6. LLMOps** : Surveillez et analysez les journaux d'application et les performances au fil du temps. Vous pouvez continuellement améliorer les prompts, les ensembles de données et les modèles en fonction des données de production et des annotations. diff --git a/web/.vscode/extensions.json b/web/.vscode/extensions.json index e0e72ce11e..68f5c7bf0e 100644 --- a/web/.vscode/extensions.json +++ b/web/.vscode/extensions.json @@ -1,7 +1,6 @@ { "recommendations": [ "bradlc.vscode-tailwindcss", - "firsttris.vscode-jest-runner", "kisstkondoros.vscode-codemetrics" ] } diff --git a/web/README.md b/web/README.md index 1855ebc3b8..7f5740a471 100644 --- a/web/README.md +++ b/web/README.md @@ -99,14 +99,14 @@ If your IDE is VSCode, rename `web/.vscode/settings.example.json` to `web/.vscod ## Test -We use [Jest](https://jestjs.io/) and [React Testing Library](https://testing-library.com/docs/react-testing-library/intro/) for Unit Testing. +We use [Vitest](https://vitest.dev/) and [React Testing Library](https://testing-library.com/docs/react-testing-library/intro/) for Unit Testing. **📖 Complete Testing Guide**: See [web/testing/testing.md](./testing/testing.md) for detailed testing specifications, best practices, and examples. Run test: ```bash -pnpm run test +pnpm test ``` ### Example Code diff --git a/web/__mocks__/mime.js b/web/__mocks__/mime.js deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/web/__mocks__/provider-context.ts b/web/__mocks__/provider-context.ts index 594fe38f14..05ced08ff6 100644 --- a/web/__mocks__/provider-context.ts +++ b/web/__mocks__/provider-context.ts @@ -1,9 +1,41 @@ import { merge, noop } from 'lodash-es' import { defaultPlan } from '@/app/components/billing/config' -import { baseProviderContextValue } from '@/context/provider-context' import type { ProviderContextState } from '@/context/provider-context' import type { Plan, UsagePlanInfo } from '@/app/components/billing/type' +// Avoid being mocked in tests +export const baseProviderContextValue: ProviderContextState = { + modelProviders: [], + refreshModelProviders: noop, + textGenerationModelList: [], + supportRetrievalMethods: [], + isAPIKeySet: true, + plan: defaultPlan, + isFetchedPlan: false, + enableBilling: false, + onPlanInfoChanged: noop, + enableReplaceWebAppLogo: false, + modelLoadBalancingEnabled: false, + datasetOperatorEnabled: false, + enableEducationPlan: false, + isEducationWorkspace: false, + isEducationAccount: false, + allowRefreshEducationVerify: false, + educationAccountExpireAt: null, + isLoadingEducationAccountInfo: false, + isFetchingEducationAccountInfo: false, + webappCopyrightEnabled: false, + licenseLimit: { + workspace_members: { + size: 0, + limit: 0, + }, + }, + refreshLicenseLimit: noop, + isAllowTransferWorkspace: false, + isAllowPublishAsCustomKnowledgePipelineTemplate: false, +} + export const createMockProviderContextValue = (overrides: Partial = {}): ProviderContextState => { const merged = merge({}, baseProviderContextValue, overrides) diff --git a/web/__mocks__/react-i18next.ts b/web/__mocks__/react-i18next.ts deleted file mode 100644 index 1e3f58927e..0000000000 --- a/web/__mocks__/react-i18next.ts +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Shared mock for react-i18next - * - * Jest automatically uses this mock when react-i18next is imported in tests. - * The default behavior returns the translation key as-is, which is suitable - * for most test scenarios. - * - * For tests that need custom translations, you can override with jest.mock(): - * - * @example - * jest.mock('react-i18next', () => ({ - * useTranslation: () => ({ - * t: (key: string) => { - * if (key === 'some.key') return 'Custom translation' - * return key - * }, - * }), - * })) - */ - -export const useTranslation = () => ({ - t: (key: string, options?: Record) => { - if (options?.returnObjects) - return [`${key}-feature-1`, `${key}-feature-2`] - if (options) - return `${key}:${JSON.stringify(options)}` - return key - }, - i18n: { - language: 'en', - changeLanguage: jest.fn(), - }, -}) - -export const Trans = ({ children }: { children?: React.ReactNode }) => children - -export const initReactI18next = { - type: '3rdParty', - init: jest.fn(), -} diff --git a/web/__tests__/document-detail-navigation-fix.test.tsx b/web/__tests__/document-detail-navigation-fix.test.tsx index a358744998..21673554e5 100644 --- a/web/__tests__/document-detail-navigation-fix.test.tsx +++ b/web/__tests__/document-detail-navigation-fix.test.tsx @@ -1,3 +1,4 @@ +import type { Mock } from 'vitest' /** * Document Detail Navigation Fix Verification Test * @@ -10,32 +11,32 @@ import { useRouter } from 'next/navigation' import { useDocumentDetail, useDocumentMetadata } from '@/service/knowledge/use-document' // Mock Next.js router -const mockPush = jest.fn() -jest.mock('next/navigation', () => ({ - useRouter: jest.fn(() => ({ +const mockPush = vi.fn() +vi.mock('next/navigation', () => ({ + useRouter: vi.fn(() => ({ push: mockPush, })), })) // Mock the document service hooks -jest.mock('@/service/knowledge/use-document', () => ({ - useDocumentDetail: jest.fn(), - useDocumentMetadata: jest.fn(), - useInvalidDocumentList: jest.fn(() => jest.fn()), +vi.mock('@/service/knowledge/use-document', () => ({ + useDocumentDetail: vi.fn(), + useDocumentMetadata: vi.fn(), + useInvalidDocumentList: vi.fn(() => vi.fn()), })) // Mock other dependencies -jest.mock('@/context/dataset-detail', () => ({ - useDatasetDetailContext: jest.fn(() => [null]), +vi.mock('@/context/dataset-detail', () => ({ + useDatasetDetailContext: vi.fn(() => [null]), })) -jest.mock('@/service/use-base', () => ({ - useInvalid: jest.fn(() => jest.fn()), +vi.mock('@/service/use-base', () => ({ + useInvalid: vi.fn(() => vi.fn()), })) -jest.mock('@/service/knowledge/use-segment', () => ({ - useSegmentListKey: jest.fn(), - useChildSegmentListKey: jest.fn(), +vi.mock('@/service/knowledge/use-segment', () => ({ + useSegmentListKey: vi.fn(), + useChildSegmentListKey: vi.fn(), })) // Create a minimal version of the DocumentDetail component that includes our fix @@ -66,10 +67,10 @@ const DocumentDetailWithFix = ({ datasetId, documentId }: { datasetId: string; d describe('Document Detail Navigation Fix Verification', () => { beforeEach(() => { - jest.clearAllMocks() + vi.clearAllMocks() // Mock successful API responses - ;(useDocumentDetail as jest.Mock).mockReturnValue({ + ;(useDocumentDetail as Mock).mockReturnValue({ data: { id: 'doc-123', name: 'Test Document', @@ -80,7 +81,7 @@ describe('Document Detail Navigation Fix Verification', () => { error: null, }) - ;(useDocumentMetadata as jest.Mock).mockReturnValue({ + ;(useDocumentMetadata as Mock).mockReturnValue({ data: null, error: null, }) diff --git a/web/__tests__/embedded-user-id-auth.test.tsx b/web/__tests__/embedded-user-id-auth.test.tsx index 9d6734b120..b49e3b7885 100644 --- a/web/__tests__/embedded-user-id-auth.test.tsx +++ b/web/__tests__/embedded-user-id-auth.test.tsx @@ -4,16 +4,17 @@ import { fireEvent, render, screen, waitFor } from '@testing-library/react' import MailAndPasswordAuth from '@/app/(shareLayout)/webapp-signin/components/mail-and-password-auth' import CheckCode from '@/app/(shareLayout)/webapp-signin/check-code/page' -const replaceMock = jest.fn() -const backMock = jest.fn() +const replaceMock = vi.fn() +const backMock = vi.fn() +const useSearchParamsMock = vi.fn(() => new URLSearchParams()) -jest.mock('next/navigation', () => ({ - usePathname: jest.fn(() => '/chatbot/test-app'), - useRouter: jest.fn(() => ({ +vi.mock('next/navigation', () => ({ + usePathname: vi.fn(() => '/chatbot/test-app'), + useRouter: vi.fn(() => ({ replace: replaceMock, back: backMock, })), - useSearchParams: jest.fn(), + useSearchParams: () => useSearchParamsMock(), })) const mockStoreState = { @@ -21,59 +22,55 @@ const mockStoreState = { shareCode: 'test-app', } -const useWebAppStoreMock = jest.fn((selector?: (state: typeof mockStoreState) => any) => { +const useWebAppStoreMock = vi.fn((selector?: (state: typeof mockStoreState) => any) => { return selector ? selector(mockStoreState) : mockStoreState }) -jest.mock('@/context/web-app-context', () => ({ +vi.mock('@/context/web-app-context', () => ({ useWebAppStore: (selector?: (state: typeof mockStoreState) => any) => useWebAppStoreMock(selector), })) -const webAppLoginMock = jest.fn() -const webAppEmailLoginWithCodeMock = jest.fn() -const sendWebAppEMailLoginCodeMock = jest.fn() +const webAppLoginMock = vi.fn() +const webAppEmailLoginWithCodeMock = vi.fn() +const sendWebAppEMailLoginCodeMock = vi.fn() -jest.mock('@/service/common', () => ({ +vi.mock('@/service/common', () => ({ webAppLogin: (...args: any[]) => webAppLoginMock(...args), webAppEmailLoginWithCode: (...args: any[]) => webAppEmailLoginWithCodeMock(...args), sendWebAppEMailLoginCode: (...args: any[]) => sendWebAppEMailLoginCodeMock(...args), })) -const fetchAccessTokenMock = jest.fn() +const fetchAccessTokenMock = vi.fn() -jest.mock('@/service/share', () => ({ +vi.mock('@/service/share', () => ({ fetchAccessToken: (...args: any[]) => fetchAccessTokenMock(...args), })) -const setWebAppAccessTokenMock = jest.fn() -const setWebAppPassportMock = jest.fn() +const setWebAppAccessTokenMock = vi.fn() +const setWebAppPassportMock = vi.fn() -jest.mock('@/service/webapp-auth', () => ({ +vi.mock('@/service/webapp-auth', () => ({ setWebAppAccessToken: (...args: any[]) => setWebAppAccessTokenMock(...args), setWebAppPassport: (...args: any[]) => setWebAppPassportMock(...args), - webAppLogout: jest.fn(), + webAppLogout: vi.fn(), })) -jest.mock('@/app/components/signin/countdown', () => () =>
) +vi.mock('@/app/components/signin/countdown', () => ({ default: () =>
})) -jest.mock('@remixicon/react', () => ({ +vi.mock('@remixicon/react', () => ({ RiMailSendFill: () =>
, RiArrowLeftLine: () =>
, })) -const { useSearchParams } = jest.requireMock('next/navigation') as { - useSearchParams: jest.Mock -} - beforeEach(() => { - jest.clearAllMocks() + vi.clearAllMocks() }) describe('embedded user id propagation in authentication flows', () => { it('passes embedded user id when logging in with email and password', async () => { const params = new URLSearchParams() params.set('redirect_url', encodeURIComponent('/chatbot/test-app')) - useSearchParams.mockReturnValue(params) + useSearchParamsMock.mockReturnValue(params) webAppLoginMock.mockResolvedValue({ result: 'success', data: { access_token: 'login-token' } }) fetchAccessTokenMock.mockResolvedValue({ access_token: 'passport-token' }) @@ -100,7 +97,7 @@ describe('embedded user id propagation in authentication flows', () => { params.set('redirect_url', encodeURIComponent('/chatbot/test-app')) params.set('email', encodeURIComponent('user@example.com')) params.set('token', encodeURIComponent('token-abc')) - useSearchParams.mockReturnValue(params) + useSearchParamsMock.mockReturnValue(params) webAppEmailLoginWithCodeMock.mockResolvedValue({ result: 'success', data: { access_token: 'code-token' } }) fetchAccessTokenMock.mockResolvedValue({ access_token: 'passport-token' }) diff --git a/web/__tests__/embedded-user-id-store.test.tsx b/web/__tests__/embedded-user-id-store.test.tsx index 24a815222e..c6d1400aef 100644 --- a/web/__tests__/embedded-user-id-store.test.tsx +++ b/web/__tests__/embedded-user-id-store.test.tsx @@ -1,42 +1,42 @@ import React from 'react' import { render, screen, waitFor } from '@testing-library/react' +import { AccessMode } from '@/models/access-control' import WebAppStoreProvider, { useWebAppStore } from '@/context/web-app-context' -jest.mock('next/navigation', () => ({ - usePathname: jest.fn(() => '/chatbot/sample-app'), - useSearchParams: jest.fn(() => { +vi.mock('next/navigation', () => ({ + usePathname: vi.fn(() => '/chatbot/sample-app'), + useSearchParams: vi.fn(() => { const params = new URLSearchParams() return params }), })) -jest.mock('@/service/use-share', () => { - const { AccessMode } = jest.requireActual('@/models/access-control') - return { - useGetWebAppAccessModeByCode: jest.fn(() => ({ - isLoading: false, - data: { accessMode: AccessMode.PUBLIC }, - })), - } -}) - -jest.mock('@/app/components/base/chat/utils', () => ({ - getProcessedSystemVariablesFromUrlParams: jest.fn(), +vi.mock('@/service/use-share', () => ({ + useGetWebAppAccessModeByCode: vi.fn(() => ({ + isLoading: false, + data: { accessMode: AccessMode.PUBLIC }, + })), })) -const { getProcessedSystemVariablesFromUrlParams: mockGetProcessedSystemVariablesFromUrlParams } - = jest.requireMock('@/app/components/base/chat/utils') as { - getProcessedSystemVariablesFromUrlParams: jest.Mock - } +// Store the mock implementation in a way that survives hoisting +const mockGetProcessedSystemVariablesFromUrlParams = vi.fn() -jest.mock('@/context/global-public-context', () => { - const mockGlobalStoreState = { +vi.mock('@/app/components/base/chat/utils', () => ({ + getProcessedSystemVariablesFromUrlParams: (...args: any[]) => mockGetProcessedSystemVariablesFromUrlParams(...args), +})) + +// Use vi.hoisted to define mock state before vi.mock hoisting +const { mockGlobalStoreState } = vi.hoisted(() => ({ + mockGlobalStoreState: { isGlobalPending: false, - setIsGlobalPending: jest.fn(), + setIsGlobalPending: vi.fn(), systemFeatures: {}, - setSystemFeatures: jest.fn(), - } + setSystemFeatures: vi.fn(), + }, +})) + +vi.mock('@/context/global-public-context', () => { const useGlobalPublicStore = Object.assign( (selector?: (state: typeof mockGlobalStoreState) => any) => selector ? selector(mockGlobalStoreState) : mockGlobalStoreState, @@ -56,21 +56,6 @@ jest.mock('@/context/global-public-context', () => { } }) -const { - useGlobalPublicStore: useGlobalPublicStoreMock, -} = jest.requireMock('@/context/global-public-context') as { - useGlobalPublicStore: ((selector?: (state: any) => any) => any) & { - setState: (updater: any) => void - __mockState: { - isGlobalPending: boolean - setIsGlobalPending: jest.Mock - systemFeatures: Record - setSystemFeatures: jest.Mock - } - } -} -const mockGlobalStoreState = useGlobalPublicStoreMock.__mockState - const TestConsumer = () => { const embeddedUserId = useWebAppStore(state => state.embeddedUserId) const embeddedConversationId = useWebAppStore(state => state.embeddedConversationId) diff --git a/web/__tests__/goto-anything/command-selector.test.tsx b/web/__tests__/goto-anything/command-selector.test.tsx index e502c533bb..df33ee645c 100644 --- a/web/__tests__/goto-anything/command-selector.test.tsx +++ b/web/__tests__/goto-anything/command-selector.test.tsx @@ -1,10 +1,9 @@ import React from 'react' import { fireEvent, render, screen } from '@testing-library/react' -import '@testing-library/jest-dom' import CommandSelector from '../../app/components/goto-anything/command-selector' import type { ActionItem } from '../../app/components/goto-anything/actions/types' -jest.mock('cmdk', () => ({ +vi.mock('cmdk', () => ({ Command: { Group: ({ children, className }: any) =>
{children}
, Item: ({ children, onSelect, value, className }: any) => ( @@ -27,36 +26,36 @@ describe('CommandSelector', () => { shortcut: '@app', title: 'Search Applications', description: 'Search apps', - search: jest.fn(), + search: vi.fn(), }, knowledge: { key: '@knowledge', shortcut: '@kb', title: 'Search Knowledge', description: 'Search knowledge bases', - search: jest.fn(), + search: vi.fn(), }, plugin: { key: '@plugin', shortcut: '@plugin', title: 'Search Plugins', description: 'Search plugins', - search: jest.fn(), + search: vi.fn(), }, node: { key: '@node', shortcut: '@node', title: 'Search Nodes', description: 'Search workflow nodes', - search: jest.fn(), + search: vi.fn(), }, } - const mockOnCommandSelect = jest.fn() - const mockOnCommandValueChange = jest.fn() + const mockOnCommandSelect = vi.fn() + const mockOnCommandValueChange = vi.fn() beforeEach(() => { - jest.clearAllMocks() + vi.clearAllMocks() }) describe('Basic Rendering', () => { diff --git a/web/__tests__/goto-anything/match-action.test.ts b/web/__tests__/goto-anything/match-action.test.ts index 3df9c0d533..2d1866a4b8 100644 --- a/web/__tests__/goto-anything/match-action.test.ts +++ b/web/__tests__/goto-anything/match-action.test.ts @@ -1,11 +1,12 @@ +import type { Mock } from 'vitest' import type { ActionItem } from '../../app/components/goto-anything/actions/types' // Mock the entire actions module to avoid import issues -jest.mock('../../app/components/goto-anything/actions', () => ({ - matchAction: jest.fn(), +vi.mock('../../app/components/goto-anything/actions', () => ({ + matchAction: vi.fn(), })) -jest.mock('../../app/components/goto-anything/actions/commands/registry') +vi.mock('../../app/components/goto-anything/actions/commands/registry') // Import after mocking to get mocked version import { matchAction } from '../../app/components/goto-anything/actions' @@ -39,7 +40,7 @@ const actualMatchAction = (query: string, actions: Record) = } // Replace mock with actual implementation -;(matchAction as jest.Mock).mockImplementation(actualMatchAction) +;(matchAction as Mock).mockImplementation(actualMatchAction) describe('matchAction Logic', () => { const mockActions: Record = { @@ -48,27 +49,27 @@ describe('matchAction Logic', () => { shortcut: '@a', title: 'Search Applications', description: 'Search apps', - search: jest.fn(), + search: vi.fn(), }, knowledge: { key: '@knowledge', shortcut: '@kb', title: 'Search Knowledge', description: 'Search knowledge bases', - search: jest.fn(), + search: vi.fn(), }, slash: { key: '/', shortcut: '/', title: 'Commands', description: 'Execute commands', - search: jest.fn(), + search: vi.fn(), }, } beforeEach(() => { - jest.clearAllMocks() - ;(slashCommandRegistry.getAllCommands as jest.Mock).mockReturnValue([ + vi.clearAllMocks() + ;(slashCommandRegistry.getAllCommands as Mock).mockReturnValue([ { name: 'docs', mode: 'direct' }, { name: 'community', mode: 'direct' }, { name: 'feedback', mode: 'direct' }, @@ -188,7 +189,7 @@ describe('matchAction Logic', () => { describe('Mode-based Filtering', () => { it('should filter direct mode commands from matching', () => { - ;(slashCommandRegistry.getAllCommands as jest.Mock).mockReturnValue([ + ;(slashCommandRegistry.getAllCommands as Mock).mockReturnValue([ { name: 'test', mode: 'direct' }, ]) @@ -197,7 +198,7 @@ describe('matchAction Logic', () => { }) it('should allow submenu mode commands to match', () => { - ;(slashCommandRegistry.getAllCommands as jest.Mock).mockReturnValue([ + ;(slashCommandRegistry.getAllCommands as Mock).mockReturnValue([ { name: 'test', mode: 'submenu' }, ]) @@ -206,7 +207,7 @@ describe('matchAction Logic', () => { }) it('should treat undefined mode as submenu', () => { - ;(slashCommandRegistry.getAllCommands as jest.Mock).mockReturnValue([ + ;(slashCommandRegistry.getAllCommands as Mock).mockReturnValue([ { name: 'test' }, // No mode specified ]) @@ -227,7 +228,7 @@ describe('matchAction Logic', () => { }) it('should handle empty command list', () => { - ;(slashCommandRegistry.getAllCommands as jest.Mock).mockReturnValue([]) + ;(slashCommandRegistry.getAllCommands as Mock).mockReturnValue([]) const result = matchAction('/anything', mockActions) expect(result).toBeUndefined() }) diff --git a/web/__tests__/goto-anything/scope-command-tags.test.tsx b/web/__tests__/goto-anything/scope-command-tags.test.tsx index 339e259a06..0e10019760 100644 --- a/web/__tests__/goto-anything/scope-command-tags.test.tsx +++ b/web/__tests__/goto-anything/scope-command-tags.test.tsx @@ -1,6 +1,5 @@ import React from 'react' import { render, screen } from '@testing-library/react' -import '@testing-library/jest-dom' // Type alias for search mode type SearchMode = 'scopes' | 'commands' | null diff --git a/web/__tests__/goto-anything/search-error-handling.test.ts b/web/__tests__/goto-anything/search-error-handling.test.ts index d2fd921e1c..69bd2487dd 100644 --- a/web/__tests__/goto-anything/search-error-handling.test.ts +++ b/web/__tests__/goto-anything/search-error-handling.test.ts @@ -1,3 +1,4 @@ +import type { MockedFunction } from 'vitest' /** * Test GotoAnything search error handling mechanisms * @@ -14,33 +15,33 @@ import { fetchAppList } from '@/service/apps' import { fetchDatasets } from '@/service/datasets' // Mock API functions -jest.mock('@/service/base', () => ({ - postMarketplace: jest.fn(), +vi.mock('@/service/base', () => ({ + postMarketplace: vi.fn(), })) -jest.mock('@/service/apps', () => ({ - fetchAppList: jest.fn(), +vi.mock('@/service/apps', () => ({ + fetchAppList: vi.fn(), })) -jest.mock('@/service/datasets', () => ({ - fetchDatasets: jest.fn(), +vi.mock('@/service/datasets', () => ({ + fetchDatasets: vi.fn(), })) -const mockPostMarketplace = postMarketplace as jest.MockedFunction -const mockFetchAppList = fetchAppList as jest.MockedFunction -const mockFetchDatasets = fetchDatasets as jest.MockedFunction +const mockPostMarketplace = postMarketplace as MockedFunction +const mockFetchAppList = fetchAppList as MockedFunction +const mockFetchDatasets = fetchDatasets as MockedFunction describe('GotoAnything Search Error Handling', () => { beforeEach(() => { - jest.clearAllMocks() + vi.clearAllMocks() // Suppress console.warn for clean test output - jest.spyOn(console, 'warn').mockImplementation(() => { + vi.spyOn(console, 'warn').mockImplementation(() => { // Suppress console.warn for clean test output }) }) afterEach(() => { - jest.restoreAllMocks() + vi.restoreAllMocks() }) describe('@plugin search error handling', () => { diff --git a/web/__tests__/goto-anything/slash-command-modes.test.tsx b/web/__tests__/goto-anything/slash-command-modes.test.tsx index f8126958fc..e8f3509083 100644 --- a/web/__tests__/goto-anything/slash-command-modes.test.tsx +++ b/web/__tests__/goto-anything/slash-command-modes.test.tsx @@ -1,17 +1,16 @@ -import '@testing-library/jest-dom' import { slashCommandRegistry } from '../../app/components/goto-anything/actions/commands/registry' import type { SlashCommandHandler } from '../../app/components/goto-anything/actions/commands/types' // Mock the registry -jest.mock('../../app/components/goto-anything/actions/commands/registry') +vi.mock('../../app/components/goto-anything/actions/commands/registry') describe('Slash Command Dual-Mode System', () => { const mockDirectCommand: SlashCommandHandler = { name: 'docs', description: 'Open documentation', mode: 'direct', - execute: jest.fn(), - search: jest.fn().mockResolvedValue([ + execute: vi.fn(), + search: vi.fn().mockResolvedValue([ { id: 'docs', title: 'Documentation', @@ -20,15 +19,15 @@ describe('Slash Command Dual-Mode System', () => { data: { command: 'navigation.docs', args: {} }, }, ]), - register: jest.fn(), - unregister: jest.fn(), + register: vi.fn(), + unregister: vi.fn(), } const mockSubmenuCommand: SlashCommandHandler = { name: 'theme', description: 'Change theme', mode: 'submenu', - search: jest.fn().mockResolvedValue([ + search: vi.fn().mockResolvedValue([ { id: 'theme-light', title: 'Light Theme', @@ -44,18 +43,18 @@ describe('Slash Command Dual-Mode System', () => { data: { command: 'theme.set', args: { theme: 'dark' } }, }, ]), - register: jest.fn(), - unregister: jest.fn(), + register: vi.fn(), + unregister: vi.fn(), } beforeEach(() => { - jest.clearAllMocks() - ;(slashCommandRegistry as any).findCommand = jest.fn((name: string) => { + vi.clearAllMocks() + ;(slashCommandRegistry as any).findCommand = vi.fn((name: string) => { if (name === 'docs') return mockDirectCommand if (name === 'theme') return mockSubmenuCommand return null }) - ;(slashCommandRegistry as any).getAllCommands = jest.fn(() => [ + ;(slashCommandRegistry as any).getAllCommands = vi.fn(() => [ mockDirectCommand, mockSubmenuCommand, ]) @@ -63,8 +62,8 @@ describe('Slash Command Dual-Mode System', () => { describe('Direct Mode Commands', () => { it('should execute immediately when selected', () => { - const mockSetShow = jest.fn() - const mockSetSearchQuery = jest.fn() + const mockSetShow = vi.fn() + const mockSetSearchQuery = vi.fn() // Simulate command selection const handler = slashCommandRegistry.findCommand('docs') @@ -88,7 +87,7 @@ describe('Slash Command Dual-Mode System', () => { }) it('should close modal after execution', () => { - const mockModalClose = jest.fn() + const mockModalClose = vi.fn() const handler = slashCommandRegistry.findCommand('docs') if (handler?.mode === 'direct' && handler.execute) { @@ -118,7 +117,7 @@ describe('Slash Command Dual-Mode System', () => { }) it('should keep modal open for selection', () => { - const mockModalClose = jest.fn() + const mockModalClose = vi.fn() const handler = slashCommandRegistry.findCommand('theme') // For submenu mode, modal should not close immediately @@ -141,12 +140,12 @@ describe('Slash Command Dual-Mode System', () => { const commandWithoutMode: SlashCommandHandler = { name: 'test', description: 'Test command', - search: jest.fn(), - register: jest.fn(), - unregister: jest.fn(), + search: vi.fn(), + register: vi.fn(), + unregister: vi.fn(), } - ;(slashCommandRegistry as any).findCommand = jest.fn(() => commandWithoutMode) + ;(slashCommandRegistry as any).findCommand = vi.fn(() => commandWithoutMode) const handler = slashCommandRegistry.findCommand('test') // Default behavior should be submenu when mode is not specified @@ -189,7 +188,7 @@ describe('Slash Command Dual-Mode System', () => { describe('Command Registration', () => { it('should register both direct and submenu commands', () => { mockDirectCommand.register?.({}) - mockSubmenuCommand.register?.({ setTheme: jest.fn() }) + mockSubmenuCommand.register?.({ setTheme: vi.fn() }) expect(mockDirectCommand.register).toHaveBeenCalled() expect(mockSubmenuCommand.register).toHaveBeenCalled() diff --git a/web/__tests__/navigation-utils.test.ts b/web/__tests__/navigation-utils.test.ts index 3eeba52943..866adea054 100644 --- a/web/__tests__/navigation-utils.test.ts +++ b/web/__tests__/navigation-utils.test.ts @@ -15,12 +15,12 @@ import { } from '@/utils/navigation' // Mock router for testing -const mockPush = jest.fn() +const mockPush = vi.fn() const mockRouter = { push: mockPush } describe('Navigation Utilities', () => { beforeEach(() => { - jest.clearAllMocks() + vi.clearAllMocks() }) describe('createNavigationPath', () => { @@ -63,7 +63,7 @@ describe('Navigation Utilities', () => { configurable: true, }) - const consoleSpy = jest.spyOn(console, 'warn').mockImplementation() + const consoleSpy = vi.spyOn(console, 'warn').mockImplementation(() => { /* noop */ }) const path = createNavigationPath('/datasets/123/documents') expect(path).toBe('/datasets/123/documents') @@ -134,7 +134,7 @@ describe('Navigation Utilities', () => { configurable: true, }) - const consoleSpy = jest.spyOn(console, 'warn').mockImplementation() + const consoleSpy = vi.spyOn(console, 'warn').mockImplementation(() => { /* noop */ }) const params = extractQueryParams(['page', 'limit']) expect(params).toEqual({}) @@ -169,11 +169,11 @@ describe('Navigation Utilities', () => { test('handles errors gracefully', () => { // Mock URLSearchParams to throw an error const originalURLSearchParams = globalThis.URLSearchParams - globalThis.URLSearchParams = jest.fn(() => { + globalThis.URLSearchParams = vi.fn(() => { throw new Error('URLSearchParams error') }) as any - const consoleSpy = jest.spyOn(console, 'warn').mockImplementation() + const consoleSpy = vi.spyOn(console, 'warn').mockImplementation(() => { /* noop */ }) const path = createNavigationPathWithParams('/datasets/123/documents', { page: 1 }) expect(path).toBe('/datasets/123/documents') diff --git a/web/__tests__/real-browser-flicker.test.tsx b/web/__tests__/real-browser-flicker.test.tsx index 0a0ea0c062..c0df6116e2 100644 --- a/web/__tests__/real-browser-flicker.test.tsx +++ b/web/__tests__/real-browser-flicker.test.tsx @@ -76,7 +76,7 @@ const setupMockEnvironment = (storedTheme: string | null, systemPrefersDark = fa return mediaQueryList } - jest.spyOn(window, 'matchMedia').mockImplementation(mockMatchMedia) + vi.spyOn(window, 'matchMedia').mockImplementation(mockMatchMedia) } // Helper function to create timing page component @@ -240,8 +240,8 @@ const TestThemeProvider = ({ children }: { children: React.ReactNode }) => ( describe('Real Browser Environment Dark Mode Flicker Test', () => { beforeEach(() => { - jest.restoreAllMocks() - jest.clearAllMocks() + vi.restoreAllMocks() + vi.clearAllMocks() if (typeof window !== 'undefined') { try { window.localStorage.clear() @@ -424,12 +424,12 @@ describe('Real Browser Environment Dark Mode Flicker Test', () => { setupMockEnvironment(null) const mockStorage = { - getItem: jest.fn(() => { + getItem: vi.fn(() => { throw new Error('LocalStorage access denied') }), - setItem: jest.fn(), - removeItem: jest.fn(), - clear: jest.fn(), + setItem: vi.fn(), + removeItem: vi.fn(), + clear: vi.fn(), } Object.defineProperty(window, 'localStorage', { diff --git a/web/__tests__/workflow-onboarding-integration.test.tsx b/web/__tests__/workflow-onboarding-integration.test.tsx index ded8c75bd1..e4db04148b 100644 --- a/web/__tests__/workflow-onboarding-integration.test.tsx +++ b/web/__tests__/workflow-onboarding-integration.test.tsx @@ -1,15 +1,16 @@ +import type { Mock } from 'vitest' import { BlockEnum } from '@/app/components/workflow/types' import { useWorkflowStore } from '@/app/components/workflow/store' // Type for mocked store type MockWorkflowStore = { showOnboarding: boolean - setShowOnboarding: jest.Mock + setShowOnboarding: Mock hasShownOnboarding: boolean - setHasShownOnboarding: jest.Mock + setHasShownOnboarding: Mock hasSelectedStartNode: boolean - setHasSelectedStartNode: jest.Mock - setShouldAutoOpenStartNodeSelector: jest.Mock + setHasSelectedStartNode: Mock + setShouldAutoOpenStartNodeSelector: Mock notInitialWorkflow: boolean } @@ -20,11 +21,11 @@ type MockNode = { } // Mock zustand store -jest.mock('@/app/components/workflow/store') +vi.mock('@/app/components/workflow/store') // Mock ReactFlow store -const mockGetNodes = jest.fn() -jest.mock('reactflow', () => ({ +const mockGetNodes = vi.fn() +vi.mock('reactflow', () => ({ useStoreApi: () => ({ getState: () => ({ getNodes: mockGetNodes, @@ -33,16 +34,16 @@ jest.mock('reactflow', () => ({ })) describe('Workflow Onboarding Integration Logic', () => { - const mockSetShowOnboarding = jest.fn() - const mockSetHasSelectedStartNode = jest.fn() - const mockSetHasShownOnboarding = jest.fn() - const mockSetShouldAutoOpenStartNodeSelector = jest.fn() + const mockSetShowOnboarding = vi.fn() + const mockSetHasSelectedStartNode = vi.fn() + const mockSetHasShownOnboarding = vi.fn() + const mockSetShouldAutoOpenStartNodeSelector = vi.fn() beforeEach(() => { - jest.clearAllMocks() + vi.clearAllMocks() // Mock store implementation - ;(useWorkflowStore as jest.Mock).mockReturnValue({ + ;(useWorkflowStore as Mock).mockReturnValue({ showOnboarding: false, setShowOnboarding: mockSetShowOnboarding, hasSelectedStartNode: false, @@ -373,12 +374,12 @@ describe('Workflow Onboarding Integration Logic', () => { it('should trigger onboarding for new workflow when draft does not exist', () => { // Simulate the error handling logic from use-workflow-init.ts const error = { - json: jest.fn().mockResolvedValue({ code: 'draft_workflow_not_exist' }), + json: vi.fn().mockResolvedValue({ code: 'draft_workflow_not_exist' }), bodyUsed: false, } const mockWorkflowStore = { - setState: jest.fn(), + setState: vi.fn(), } // Simulate error handling @@ -404,7 +405,7 @@ describe('Workflow Onboarding Integration Logic', () => { it('should not trigger onboarding for existing workflows', () => { // Simulate successful draft fetch const mockWorkflowStore = { - setState: jest.fn(), + setState: vi.fn(), } // Normal initialization path should not set showOnboarding: true @@ -419,7 +420,7 @@ describe('Workflow Onboarding Integration Logic', () => { }) it('should create empty draft with proper structure', () => { - const mockSyncWorkflowDraft = jest.fn() + const mockSyncWorkflowDraft = vi.fn() const appId = 'test-app-id' // Simulate the syncWorkflowDraft call from use-workflow-init.ts @@ -467,7 +468,7 @@ describe('Workflow Onboarding Integration Logic', () => { mockGetNodes.mockReturnValue([]) // Mock store with proper state for auto-detection - ;(useWorkflowStore as jest.Mock).mockReturnValue({ + ;(useWorkflowStore as Mock).mockReturnValue({ showOnboarding: false, hasShownOnboarding: false, notInitialWorkflow: false, @@ -550,7 +551,7 @@ describe('Workflow Onboarding Integration Logic', () => { mockGetNodes.mockReturnValue([]) // Mock store with hasShownOnboarding = true - ;(useWorkflowStore as jest.Mock).mockReturnValue({ + ;(useWorkflowStore as Mock).mockReturnValue({ showOnboarding: false, hasShownOnboarding: true, // Already shown in this session notInitialWorkflow: false, @@ -584,7 +585,7 @@ describe('Workflow Onboarding Integration Logic', () => { mockGetNodes.mockReturnValue([]) // Mock store with notInitialWorkflow = true (initial creation) - ;(useWorkflowStore as jest.Mock).mockReturnValue({ + ;(useWorkflowStore as Mock).mockReturnValue({ showOnboarding: false, hasShownOnboarding: false, notInitialWorkflow: true, // Initial workflow creation diff --git a/web/__tests__/workflow-parallel-limit.test.tsx b/web/__tests__/workflow-parallel-limit.test.tsx index 64e9d328f0..8d845794da 100644 --- a/web/__tests__/workflow-parallel-limit.test.tsx +++ b/web/__tests__/workflow-parallel-limit.test.tsx @@ -19,7 +19,7 @@ function setupEnvironment(value?: string) { delete process.env.NEXT_PUBLIC_MAX_PARALLEL_LIMIT // Clear module cache to force re-evaluation - jest.resetModules() + vi.resetModules() } function restoreEnvironment() { @@ -28,11 +28,11 @@ function restoreEnvironment() { else delete process.env.NEXT_PUBLIC_MAX_PARALLEL_LIMIT - jest.resetModules() + vi.resetModules() } // Mock i18next with proper implementation -jest.mock('react-i18next', () => ({ +vi.mock('react-i18next', () => ({ useTranslation: () => ({ t: (key: string) => { if (key.includes('MaxParallelismTitle')) return 'Max Parallelism' @@ -45,20 +45,20 @@ jest.mock('react-i18next', () => ({ }), initReactI18next: { type: '3rdParty', - init: jest.fn(), + init: vi.fn(), }, })) // Mock i18next module completely to prevent initialization issues -jest.mock('i18next', () => ({ - use: jest.fn().mockReturnThis(), - init: jest.fn().mockReturnThis(), - t: jest.fn(key => key), +vi.mock('i18next', () => ({ + use: vi.fn().mockReturnThis(), + init: vi.fn().mockReturnThis(), + t: vi.fn(key => key), isInitialized: true, })) // Mock the useConfig hook -jest.mock('@/app/components/workflow/nodes/iteration/use-config', () => ({ +vi.mock('@/app/components/workflow/nodes/iteration/use-config', () => ({ __esModule: true, default: () => ({ inputs: { @@ -66,82 +66,39 @@ jest.mock('@/app/components/workflow/nodes/iteration/use-config', () => ({ parallel_nums: 5, error_handle_mode: 'terminated', }, - changeParallel: jest.fn(), - changeParallelNums: jest.fn(), - changeErrorHandleMode: jest.fn(), + changeParallel: vi.fn(), + changeParallelNums: vi.fn(), + changeErrorHandleMode: vi.fn(), }), })) // Mock other components -jest.mock('@/app/components/workflow/nodes/_base/components/variable/var-reference-picker', () => { - return function MockVarReferencePicker() { +vi.mock('@/app/components/workflow/nodes/_base/components/variable/var-reference-picker', () => ({ + default: function MockVarReferencePicker() { return
VarReferencePicker
- } -}) + }, +})) -jest.mock('@/app/components/workflow/nodes/_base/components/split', () => { - return function MockSplit() { +vi.mock('@/app/components/workflow/nodes/_base/components/split', () => ({ + default: function MockSplit() { return
Split
- } -}) + }, +})) -jest.mock('@/app/components/workflow/nodes/_base/components/field', () => { - return function MockField({ title, children }: { title: string, children: React.ReactNode }) { +vi.mock('@/app/components/workflow/nodes/_base/components/field', () => ({ + default: function MockField({ title, children }: { title: string, children: React.ReactNode }) { return (
{children}
) - } -}) + }, +})) -jest.mock('@/app/components/base/switch', () => { - return function MockSwitch({ defaultValue }: { defaultValue: boolean }) { - return - } -}) - -jest.mock('@/app/components/base/select', () => { - return function MockSelect() { - return - } -}) - -// Use defaultValue to avoid controlled input warnings -jest.mock('@/app/components/base/slider', () => { - return function MockSlider({ value, max, min }: { value: number, max: number, min: number }) { - return ( - - ) - } -}) - -// Use defaultValue to avoid controlled input warnings -jest.mock('@/app/components/base/input', () => { - return function MockInput({ type, max, min, value }: { type: string, max: number, min: number, value: number }) { - return ( - - ) - } +const getParallelControls = () => ({ + numberInput: screen.getByRole('spinbutton'), + slider: screen.getByRole('slider'), }) describe('MAX_PARALLEL_LIMIT Configuration Bug', () => { @@ -160,7 +117,7 @@ describe('MAX_PARALLEL_LIMIT Configuration Bug', () => { } beforeEach(() => { - jest.clearAllMocks() + vi.clearAllMocks() }) afterEach(() => { @@ -172,115 +129,114 @@ describe('MAX_PARALLEL_LIMIT Configuration Bug', () => { }) describe('Environment Variable Parsing', () => { - it('should parse MAX_PARALLEL_LIMIT from NEXT_PUBLIC_MAX_PARALLEL_LIMIT environment variable', () => { + it('should parse MAX_PARALLEL_LIMIT from NEXT_PUBLIC_MAX_PARALLEL_LIMIT environment variable', async () => { setupEnvironment('25') - const { MAX_PARALLEL_LIMIT } = require('@/config') + const { MAX_PARALLEL_LIMIT } = await import('@/config') expect(MAX_PARALLEL_LIMIT).toBe(25) }) - it('should fallback to default when environment variable is not set', () => { + it('should fallback to default when environment variable is not set', async () => { setupEnvironment() // No environment variable - const { MAX_PARALLEL_LIMIT } = require('@/config') + const { MAX_PARALLEL_LIMIT } = await import('@/config') expect(MAX_PARALLEL_LIMIT).toBe(10) }) - it('should handle invalid environment variable values', () => { + it('should handle invalid environment variable values', async () => { setupEnvironment('invalid') - const { MAX_PARALLEL_LIMIT } = require('@/config') + const { MAX_PARALLEL_LIMIT } = await import('@/config') // Should fall back to default when parsing fails expect(MAX_PARALLEL_LIMIT).toBe(10) }) - it('should handle empty environment variable', () => { + it('should handle empty environment variable', async () => { setupEnvironment('') - const { MAX_PARALLEL_LIMIT } = require('@/config') + const { MAX_PARALLEL_LIMIT } = await import('@/config') // Should fall back to default when empty expect(MAX_PARALLEL_LIMIT).toBe(10) }) // Edge cases for boundary values - it('should clamp MAX_PARALLEL_LIMIT to MIN when env is 0 or negative', () => { + it('should clamp MAX_PARALLEL_LIMIT to MIN when env is 0 or negative', async () => { setupEnvironment('0') - let { MAX_PARALLEL_LIMIT } = require('@/config') + let { MAX_PARALLEL_LIMIT } = await import('@/config') expect(MAX_PARALLEL_LIMIT).toBe(10) // Falls back to default setupEnvironment('-5') - ;({ MAX_PARALLEL_LIMIT } = require('@/config')) + ;({ MAX_PARALLEL_LIMIT } = await import('@/config')) expect(MAX_PARALLEL_LIMIT).toBe(10) // Falls back to default }) - it('should handle float numbers by parseInt behavior', () => { + it('should handle float numbers by parseInt behavior', async () => { setupEnvironment('12.7') - const { MAX_PARALLEL_LIMIT } = require('@/config') + const { MAX_PARALLEL_LIMIT } = await import('@/config') // parseInt truncates to integer expect(MAX_PARALLEL_LIMIT).toBe(12) }) }) describe('UI Component Integration (Main Fix Verification)', () => { - it('should render iteration panel with environment-configured max value', () => { + it('should render iteration panel with environment-configured max value', async () => { // Set environment variable to a different value setupEnvironment('30') // Import Panel after setting environment - const Panel = require('@/app/components/workflow/nodes/iteration/panel').default - const { MAX_PARALLEL_LIMIT } = require('@/config') + const Panel = await import('@/app/components/workflow/nodes/iteration/panel').then(mod => mod.default) + const { MAX_PARALLEL_LIMIT } = await import('@/config') render( , ) // Behavior-focused assertion: UI max should equal MAX_PARALLEL_LIMIT - const numberInput = screen.getByTestId('number-input') - expect(numberInput).toHaveAttribute('data-max', String(MAX_PARALLEL_LIMIT)) - - const slider = screen.getByTestId('slider') - expect(slider).toHaveAttribute('data-max', String(MAX_PARALLEL_LIMIT)) + const { numberInput, slider } = getParallelControls() + expect(numberInput).toHaveAttribute('max', String(MAX_PARALLEL_LIMIT)) + expect(slider).toHaveAttribute('aria-valuemax', String(MAX_PARALLEL_LIMIT)) // Verify the actual values expect(MAX_PARALLEL_LIMIT).toBe(30) - expect(numberInput.getAttribute('data-max')).toBe('30') - expect(slider.getAttribute('data-max')).toBe('30') + expect(numberInput.getAttribute('max')).toBe('30') + expect(slider.getAttribute('aria-valuemax')).toBe('30') }) - it('should maintain UI consistency with different environment values', () => { + it('should maintain UI consistency with different environment values', async () => { setupEnvironment('15') - const Panel = require('@/app/components/workflow/nodes/iteration/panel').default - const { MAX_PARALLEL_LIMIT } = require('@/config') + const Panel = await import('@/app/components/workflow/nodes/iteration/panel').then(mod => mod.default) + const { MAX_PARALLEL_LIMIT } = await import('@/config') render( , ) // Both input and slider should use the same max value from MAX_PARALLEL_LIMIT - const numberInput = screen.getByTestId('number-input') - const slider = screen.getByTestId('slider') + const { numberInput, slider } = getParallelControls() - expect(numberInput.getAttribute('data-max')).toBe(slider.getAttribute('data-max')) - expect(numberInput.getAttribute('data-max')).toBe(String(MAX_PARALLEL_LIMIT)) + expect(numberInput.getAttribute('max')).toBe(slider.getAttribute('aria-valuemax')) + expect(numberInput.getAttribute('max')).toBe(String(MAX_PARALLEL_LIMIT)) }) }) describe('Legacy Constant Verification (For Transition Period)', () => { // Marked as transition/deprecation tests - it('should maintain MAX_ITERATION_PARALLEL_NUM for backward compatibility', () => { - const { MAX_ITERATION_PARALLEL_NUM } = require('@/app/components/workflow/constants') + it('should maintain MAX_ITERATION_PARALLEL_NUM for backward compatibility', async () => { + const { MAX_ITERATION_PARALLEL_NUM } = await import('@/app/components/workflow/constants') expect(typeof MAX_ITERATION_PARALLEL_NUM).toBe('number') expect(MAX_ITERATION_PARALLEL_NUM).toBe(10) // Hardcoded legacy value }) - it('should demonstrate MAX_PARALLEL_LIMIT vs legacy constant difference', () => { + it('should demonstrate MAX_PARALLEL_LIMIT vs legacy constant difference', async () => { setupEnvironment('50') - const { MAX_PARALLEL_LIMIT } = require('@/config') - const { MAX_ITERATION_PARALLEL_NUM } = require('@/app/components/workflow/constants') + const { MAX_PARALLEL_LIMIT } = await import('@/config') + const { MAX_ITERATION_PARALLEL_NUM } = await import('@/app/components/workflow/constants') // MAX_PARALLEL_LIMIT is configurable, MAX_ITERATION_PARALLEL_NUM is not expect(MAX_PARALLEL_LIMIT).toBe(50) @@ -290,9 +246,9 @@ describe('MAX_PARALLEL_LIMIT Configuration Bug', () => { }) describe('Constants Validation', () => { - it('should validate that required constants exist and have correct types', () => { - const { MAX_PARALLEL_LIMIT } = require('@/config') - const { MIN_ITERATION_PARALLEL_NUM } = require('@/app/components/workflow/constants') + it('should validate that required constants exist and have correct types', async () => { + const { MAX_PARALLEL_LIMIT } = await import('@/config') + const { MIN_ITERATION_PARALLEL_NUM } = await import('@/app/components/workflow/constants') expect(typeof MAX_PARALLEL_LIMIT).toBe('number') expect(typeof MIN_ITERATION_PARALLEL_NUM).toBe('number') expect(MAX_PARALLEL_LIMIT).toBeGreaterThanOrEqual(MIN_ITERATION_PARALLEL_NUM) diff --git a/web/__tests__/xss-prevention.test.tsx b/web/__tests__/xss-prevention.test.tsx index 064c6e08de..235a28af51 100644 --- a/web/__tests__/xss-prevention.test.tsx +++ b/web/__tests__/xss-prevention.test.tsx @@ -7,13 +7,14 @@ import React from 'react' import { cleanup, render } from '@testing-library/react' -import '@testing-library/jest-dom' import BlockInput from '../app/components/base/block-input' import SupportVarInput from '../app/components/workflow/nodes/_base/components/support-var-input' // Mock styles -jest.mock('../app/components/app/configuration/base/var-highlight/style.module.css', () => ({ - item: 'mock-item-class', +vi.mock('../app/components/app/configuration/base/var-highlight/style.module.css', () => ({ + default: { + item: 'mock-item-class', + }, })) describe('XSS Prevention - Block Input and Support Var Input Security', () => { diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx index 1f836de6e6..d5e3c61932 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx @@ -16,7 +16,7 @@ import { import { useTranslation } from 'react-i18next' import { useShallow } from 'zustand/react/shallow' import s from './style.module.css' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { useStore } from '@/app/components/app/store' import AppSideBar from '@/app/components/app-sidebar' import type { NavIcon } from '@/app/components/app-sidebar/navLink' diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/date-picker.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/date-picker.tsx index 2bfdece433..dda5dff2b9 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/date-picker.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/date-picker.tsx @@ -3,7 +3,7 @@ import { RiCalendarLine } from '@remixicon/react' import type { Dayjs } from 'dayjs' import type { FC } from 'react' import React, { useCallback } from 'react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { formatToLocalTime } from '@/utils/format' import { useI18N } from '@/context/i18n' import Picker from '@/app/components/base/date-and-time-picker/date-picker' diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/range-selector.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/range-selector.tsx index f99ea52492..0a80bf670d 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/range-selector.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/range-selector.tsx @@ -6,7 +6,7 @@ import { SimpleSelect } from '@/app/components/base/select' import type { Item } from '@/app/components/base/select' import dayjs from 'dayjs' import { RiArrowDownSLine, RiCheckLine } from '@remixicon/react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { useTranslation } from 'react-i18next' const today = dayjs() diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/svg-attribute-error-reproduction.spec.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/svg-attribute-error-reproduction.spec.tsx index 374dbff203..f93bef526f 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/svg-attribute-error-reproduction.spec.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/svg-attribute-error-reproduction.spec.tsx @@ -1,7 +1,8 @@ import React from 'react' import { render } from '@testing-library/react' -import '@testing-library/jest-dom' import { OpikIconBig } from '@/app/components/base/icons/src/public/tracing' +import { normalizeAttrs } from '@/app/components/base/icons/utils' +import iconData from '@/app/components/base/icons/src/public/tracing/OpikIconBig.json' describe('SVG Attribute Error Reproduction', () => { // Capture console errors @@ -10,7 +11,7 @@ describe('SVG Attribute Error Reproduction', () => { beforeEach(() => { errorMessages = [] - console.error = jest.fn((message) => { + console.error = vi.fn((message) => { errorMessages.push(message) originalError(message) }) @@ -54,9 +55,6 @@ describe('SVG Attribute Error Reproduction', () => { it('should analyze the SVG structure causing the errors', () => { console.log('\n=== ANALYZING SVG STRUCTURE ===') - // Import the JSON data directly - const iconData = require('@/app/components/base/icons/src/public/tracing/OpikIconBig.json') - console.log('Icon structure analysis:') console.log('- Root element:', iconData.icon.name) console.log('- Children count:', iconData.icon.children?.length || 0) @@ -113,8 +111,6 @@ describe('SVG Attribute Error Reproduction', () => { it('should test the normalizeAttrs function behavior', () => { console.log('\n=== TESTING normalizeAttrs FUNCTION ===') - const { normalizeAttrs } = require('@/app/components/base/icons/utils') - const testAttributes = { 'inkscape:showpageshadow': '2', 'inkscape:pageopacity': '0.0', diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-button.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-button.tsx index 246a1eb6a3..17c919bf22 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-button.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-button.tsx @@ -4,7 +4,7 @@ import React, { useCallback, useRef, useState } from 'react' import type { PopupProps } from './config-popup' import ConfigPopup from './config-popup' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { PortalToFollowElem, PortalToFollowElemContent, diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx index 628eb13071..767ccb8c59 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx @@ -12,7 +12,7 @@ import Indicator from '@/app/components/header/indicator' import Switch from '@/app/components/base/switch' import Tooltip from '@/app/components/base/tooltip' import Divider from '@/app/components/base/divider' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' const I18N_PREFIX = 'app.tracing' diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/field.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/field.tsx index eecd356e08..e170159e35 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/field.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/field.tsx @@ -1,7 +1,7 @@ 'use client' import type { FC } from 'react' import React from 'react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import Input from '@/app/components/base/input' type Props = { diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx index 2c17931b83..319ff3f423 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx @@ -12,7 +12,7 @@ import type { AliyunConfig, ArizeConfig, DatabricksConfig, LangFuseConfig, LangS import { TracingProvider } from './type' import TracingIcon from './tracing-icon' import ConfigButton from './config-button' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { AliyunIcon, ArizeIcon, DatabricksIcon, LangfuseIcon, LangsmithIcon, MlflowIcon, OpikIcon, PhoenixIcon, TencentIcon, WeaveIcon } from '@/app/components/base/icons/src/public/tracing' import Indicator from '@/app/components/header/indicator' import { fetchTracingConfig as doFetchTracingConfig, fetchTracingStatus, updateTracingStatus } from '@/service/apps' diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-panel.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-panel.tsx index ac1704d60d..0779689c76 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-panel.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-panel.tsx @@ -6,7 +6,7 @@ import { } from '@remixicon/react' import { useTranslation } from 'react-i18next' import { TracingProvider } from './type' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { AliyunIconBig, ArizeIconBig, DatabricksIconBig, LangfuseIconBig, LangsmithIconBig, MlflowIconBig, OpikIconBig, PhoenixIconBig, TencentIconBig, WeaveIconBig } from '@/app/components/base/icons/src/public/tracing' import { Eye as View } from '@/app/components/base/icons/src/vender/solid/general' diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/tracing-icon.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/tracing-icon.tsx index ec9117dd38..aeca1cd3ab 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/tracing-icon.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/tracing-icon.tsx @@ -1,7 +1,7 @@ 'use client' import type { FC } from 'react' import React from 'react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { TracingIcon as Icon } from '@/app/components/base/icons/src/public/tracing' type Props = { diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx index 3effb79f20..3581587b54 100644 --- a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx @@ -23,7 +23,7 @@ import { useDatasetDetail, useDatasetRelatedApps } from '@/service/knowledge/use import useDocumentTitle from '@/hooks/use-document-title' import ExtraInfo from '@/app/components/datasets/extra-info' import { useEventEmitterContextContext } from '@/context/event-emitter' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' export type IAppDetailLayoutProps = { children: React.ReactNode diff --git a/web/app/(shareLayout)/webapp-reset-password/layout.tsx b/web/app/(shareLayout)/webapp-reset-password/layout.tsx index e0ac6b9ad6..13073b0e6a 100644 --- a/web/app/(shareLayout)/webapp-reset-password/layout.tsx +++ b/web/app/(shareLayout)/webapp-reset-password/layout.tsx @@ -1,7 +1,7 @@ 'use client' import Header from '@/app/signin/_header' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { useGlobalPublicStore } from '@/context/global-public-context' export default function SignInLayout({ children }: any) { diff --git a/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx b/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx index 5e3f6fff1d..843f10e039 100644 --- a/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx +++ b/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx @@ -2,7 +2,7 @@ import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import { useRouter, useSearchParams } from 'next/navigation' -import cn from 'classnames' +import { cn } from '@/utils/classnames' import { RiCheckboxCircleFill } from '@remixicon/react' import { useCountDown } from 'ahooks' import Button from '@/app/components/base/button' diff --git a/web/app/(shareLayout)/webapp-signin/layout.tsx b/web/app/(shareLayout)/webapp-signin/layout.tsx index 7649982072..c75f925d40 100644 --- a/web/app/(shareLayout)/webapp-signin/layout.tsx +++ b/web/app/(shareLayout)/webapp-signin/layout.tsx @@ -1,6 +1,6 @@ 'use client' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { useGlobalPublicStore } from '@/context/global-public-context' import useDocumentTitle from '@/hooks/use-document-title' import type { PropsWithChildren } from 'react' diff --git a/web/app/(shareLayout)/webapp-signin/normalForm.tsx b/web/app/(shareLayout)/webapp-signin/normalForm.tsx index 219722eef3..a14bfcd737 100644 --- a/web/app/(shareLayout)/webapp-signin/normalForm.tsx +++ b/web/app/(shareLayout)/webapp-signin/normalForm.tsx @@ -7,7 +7,7 @@ import Loading from '@/app/components/base/loading' import MailAndCodeAuth from './components/mail-and-code-auth' import MailAndPasswordAuth from './components/mail-and-password-auth' import SSOAuth from './components/sso-auth' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { LicenseStatus } from '@/types/feature' import { IS_CE_EDITION } from '@/config' import { useGlobalPublicStore } from '@/context/global-public-context' diff --git a/web/app/account/oauth/authorize/layout.tsx b/web/app/account/oauth/authorize/layout.tsx index 2ab676d6b6..b70ab210d0 100644 --- a/web/app/account/oauth/authorize/layout.tsx +++ b/web/app/account/oauth/authorize/layout.tsx @@ -1,7 +1,7 @@ 'use client' import Header from '@/app/signin/_header' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { useGlobalPublicStore } from '@/context/global-public-context' import useDocumentTitle from '@/hooks/use-document-title' import { AppContextProvider } from '@/context/app-context' diff --git a/web/app/activate/activateForm.tsx b/web/app/activate/activateForm.tsx index d9d07cbfa1..11fc4866f3 100644 --- a/web/app/activate/activateForm.tsx +++ b/web/app/activate/activateForm.tsx @@ -1,13 +1,13 @@ 'use client' +import { useEffect } from 'react' import { useTranslation } from 'react-i18next' -import useSWR from 'swr' import { useRouter, useSearchParams } from 'next/navigation' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import Button from '@/app/components/base/button' -import { invitationCheck } from '@/service/common' import Loading from '@/app/components/base/loading' import useDocumentTitle from '@/hooks/use-document-title' +import { useInvitationCheck } from '@/service/use-common' const ActivateForm = () => { useDocumentTitle('') @@ -26,19 +26,21 @@ const ActivateForm = () => { token, }, } - const { data: checkRes } = useSWR(checkParams, invitationCheck, { - revalidateOnFocus: false, - onSuccess(data) { - if (data.is_valid) { - const params = new URLSearchParams(searchParams) - const { email, workspace_id } = data.data - params.set('email', encodeURIComponent(email)) - params.set('workspace_id', encodeURIComponent(workspace_id)) - params.set('invite_token', encodeURIComponent(token as string)) - router.replace(`/signin?${params.toString()}`) - } - }, - }) + const { data: checkRes } = useInvitationCheck({ + ...checkParams.params, + token: token || undefined, + }, true) + + useEffect(() => { + if (checkRes?.is_valid) { + const params = new URLSearchParams(searchParams) + const { email, workspace_id } = checkRes.data + params.set('email', encodeURIComponent(email)) + params.set('workspace_id', encodeURIComponent(workspace_id)) + params.set('invite_token', encodeURIComponent(token as string)) + router.replace(`/signin?${params.toString()}`) + } + }, [checkRes, router, searchParams, token]) return (
{ diff --git a/web/app/components/app-sidebar/app-info.tsx b/web/app/components/app-sidebar/app-info.tsx index f143c2fcef..1b4377c10a 100644 --- a/web/app/components/app-sidebar/app-info.tsx +++ b/web/app/components/app-sidebar/app-info.tsx @@ -29,7 +29,7 @@ import CardView from '@/app/(commonLayout)/app/(appDetailLayout)/[appId]/overvie import type { Operation } from './app-operations' import AppOperations from './app-operations' import dynamic from 'next/dynamic' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { AppModeEnum } from '@/types/app' const SwitchAppModal = dynamic(() => import('@/app/components/app/switch-app-modal'), { diff --git a/web/app/components/app-sidebar/app-sidebar-dropdown.tsx b/web/app/components/app-sidebar/app-sidebar-dropdown.tsx index 3c5d38dd82..04634906af 100644 --- a/web/app/components/app-sidebar/app-sidebar-dropdown.tsx +++ b/web/app/components/app-sidebar/app-sidebar-dropdown.tsx @@ -16,7 +16,7 @@ import AppInfo from './app-info' import NavLink from './navLink' import { useStore as useAppStore } from '@/app/components/app/store' import type { NavIcon } from './navLink' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { AppModeEnum } from '@/types/app' type Props = { diff --git a/web/app/components/app-sidebar/dataset-info/dropdown.tsx b/web/app/components/app-sidebar/dataset-info/dropdown.tsx index ff110f70bd..dc46af2d02 100644 --- a/web/app/components/app-sidebar/dataset-info/dropdown.tsx +++ b/web/app/components/app-sidebar/dataset-info/dropdown.tsx @@ -2,7 +2,7 @@ import React, { useCallback, useState } from 'react' import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '../../base/portal-to-follow-elem' import ActionButton from '../../base/action-button' import { RiMoreFill } from '@remixicon/react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import Menu from './menu' import { useSelector as useAppContextWithSelector } from '@/context/app-context' import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' diff --git a/web/app/components/app-sidebar/dataset-info/index.spec.tsx b/web/app/components/app-sidebar/dataset-info/index.spec.tsx new file mode 100644 index 0000000000..dd7d7010e8 --- /dev/null +++ b/web/app/components/app-sidebar/dataset-info/index.spec.tsx @@ -0,0 +1,379 @@ +import React from 'react' +import { render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import DatasetInfo from './index' +import Dropdown from './dropdown' +import Menu from './menu' +import MenuItem from './menu-item' +import type { DataSet } from '@/models/datasets' +import { + ChunkingMode, + DataSourceType, + DatasetPermission, +} from '@/models/datasets' +import { RETRIEVE_METHOD } from '@/types/app' +import { RiEditLine } from '@remixicon/react' + +let mockDataset: DataSet +let mockIsDatasetOperator = false +const mockReplace = vi.fn() +const mockInvalidDatasetList = vi.fn() +const mockInvalidDatasetDetail = vi.fn() +const mockExportPipeline = vi.fn() +const mockCheckIsUsedInApp = vi.fn() +const mockDeleteDataset = vi.fn() + +const createDataset = (overrides: Partial = {}): DataSet => ({ + id: 'dataset-1', + name: 'Dataset Name', + indexing_status: 'completed', + icon_info: { + icon: '📙', + icon_background: '#FFF4ED', + icon_type: 'emoji', + icon_url: '', + }, + description: 'Dataset description', + permission: DatasetPermission.onlyMe, + data_source_type: DataSourceType.FILE, + indexing_technique: 'high_quality' as DataSet['indexing_technique'], + created_by: 'user-1', + updated_by: 'user-1', + updated_at: 1690000000, + app_count: 0, + doc_form: ChunkingMode.text, + document_count: 1, + total_document_count: 1, + word_count: 1000, + provider: 'internal', + embedding_model: 'text-embedding-3', + embedding_model_provider: 'openai', + embedding_available: true, + retrieval_model_dict: { + search_method: RETRIEVE_METHOD.semantic, + reranking_enable: false, + reranking_model: { + reranking_provider_name: '', + reranking_model_name: '', + }, + top_k: 5, + score_threshold_enabled: false, + score_threshold: 0, + }, + retrieval_model: { + search_method: RETRIEVE_METHOD.semantic, + reranking_enable: false, + reranking_model: { + reranking_provider_name: '', + reranking_model_name: '', + }, + top_k: 5, + score_threshold_enabled: false, + score_threshold: 0, + }, + tags: [], + external_knowledge_info: { + external_knowledge_id: '', + external_knowledge_api_id: '', + external_knowledge_api_name: '', + external_knowledge_api_endpoint: '', + }, + external_retrieval_model: { + top_k: 0, + score_threshold: 0, + score_threshold_enabled: false, + }, + built_in_field_enabled: false, + runtime_mode: 'rag_pipeline', + enable_api: false, + is_multimodal: false, + ...overrides, +}) + +vi.mock('next/navigation', () => ({ + useRouter: () => ({ + replace: mockReplace, + }), +})) + +vi.mock('@/context/dataset-detail', () => ({ + useDatasetDetailContextWithSelector: (selector: (state: { dataset?: DataSet }) => unknown) => selector({ dataset: mockDataset }), +})) + +vi.mock('@/context/app-context', () => ({ + useSelector: (selector: (state: { isCurrentWorkspaceDatasetOperator: boolean }) => unknown) => + selector({ isCurrentWorkspaceDatasetOperator: mockIsDatasetOperator }), +})) + +vi.mock('@/service/knowledge/use-dataset', () => ({ + datasetDetailQueryKeyPrefix: ['dataset', 'detail'], + useInvalidDatasetList: () => mockInvalidDatasetList, +})) + +vi.mock('@/service/use-base', () => ({ + useInvalid: () => mockInvalidDatasetDetail, +})) + +vi.mock('@/service/use-pipeline', () => ({ + useExportPipelineDSL: () => ({ + mutateAsync: mockExportPipeline, + }), +})) + +vi.mock('@/service/datasets', () => ({ + checkIsUsedInApp: (...args: unknown[]) => mockCheckIsUsedInApp(...args), + deleteDataset: (...args: unknown[]) => mockDeleteDataset(...args), +})) + +vi.mock('@/hooks/use-knowledge', () => ({ + useKnowledge: () => ({ + formatIndexingTechniqueAndMethod: () => 'indexing-technique', + }), +})) + +vi.mock('@/app/components/datasets/rename-modal', () => ({ + __esModule: true, + default: ({ + show, + onClose, + onSuccess, + }: { + show: boolean + onClose: () => void + onSuccess?: () => void + }) => { + if (!show) + return null + return ( +
+ + +
+ ) + }, +})) + +const openMenu = async (user: ReturnType) => { + const trigger = screen.getByRole('button') + await user.click(trigger) +} + +describe('DatasetInfo', () => { + beforeEach(() => { + vi.clearAllMocks() + mockDataset = createDataset() + mockIsDatasetOperator = false + }) + + // Rendering of dataset summary details based on expand and dataset state. + describe('Rendering', () => { + it('should show dataset details when expanded', () => { + // Arrange + mockDataset = createDataset({ is_published: true }) + render() + + // Assert + expect(screen.getByText('Dataset Name')).toBeInTheDocument() + expect(screen.getByText('Dataset description')).toBeInTheDocument() + expect(screen.getByText('dataset.chunkingMode.general')).toBeInTheDocument() + expect(screen.getByText('indexing-technique')).toBeInTheDocument() + }) + + it('should show external tag when provider is external', () => { + // Arrange + mockDataset = createDataset({ provider: 'external', is_published: false }) + render() + + // Assert + expect(screen.getByText('dataset.externalTag')).toBeInTheDocument() + expect(screen.queryByText('dataset.chunkingMode.general')).not.toBeInTheDocument() + }) + + it('should hide detailed fields when collapsed', () => { + // Arrange + render() + + // Assert + expect(screen.queryByText('Dataset Name')).not.toBeInTheDocument() + expect(screen.queryByText('Dataset description')).not.toBeInTheDocument() + }) + }) +}) + +describe('MenuItem', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + // Event handling for menu item interactions. + describe('Interactions', () => { + it('should call handler when clicked', async () => { + const user = userEvent.setup() + const handleClick = vi.fn() + // Arrange + render() + + // Act + await user.click(screen.getByText('Edit')) + + // Assert + expect(handleClick).toHaveBeenCalledTimes(1) + }) + }) +}) + +describe('Menu', () => { + beforeEach(() => { + vi.clearAllMocks() + mockDataset = createDataset() + }) + + // Rendering of menu options based on runtime mode and delete visibility. + describe('Rendering', () => { + it('should show edit, export, and delete options when rag pipeline and deletable', () => { + // Arrange + mockDataset = createDataset({ runtime_mode: 'rag_pipeline' }) + render( + , + ) + + // Assert + expect(screen.getByText('common.operation.edit')).toBeInTheDocument() + expect(screen.getByText('datasetPipeline.operations.exportPipeline')).toBeInTheDocument() + expect(screen.getByText('common.operation.delete')).toBeInTheDocument() + }) + + it('should hide export and delete options when not rag pipeline and not deletable', () => { + // Arrange + mockDataset = createDataset({ runtime_mode: 'general' }) + render( + , + ) + + // Assert + expect(screen.getByText('common.operation.edit')).toBeInTheDocument() + expect(screen.queryByText('datasetPipeline.operations.exportPipeline')).not.toBeInTheDocument() + expect(screen.queryByText('common.operation.delete')).not.toBeInTheDocument() + }) + }) +}) + +describe('Dropdown', () => { + beforeEach(() => { + vi.clearAllMocks() + mockDataset = createDataset({ pipeline_id: 'pipeline-1', runtime_mode: 'rag_pipeline' }) + mockIsDatasetOperator = false + mockExportPipeline.mockResolvedValue({ data: 'pipeline-content' }) + mockCheckIsUsedInApp.mockResolvedValue({ is_using: false }) + mockDeleteDataset.mockResolvedValue({}) + if (!('createObjectURL' in URL)) { + Object.defineProperty(URL, 'createObjectURL', { + value: vi.fn(), + writable: true, + }) + } + if (!('revokeObjectURL' in URL)) { + Object.defineProperty(URL, 'revokeObjectURL', { + value: vi.fn(), + writable: true, + }) + } + }) + + // Rendering behavior based on workspace role. + describe('Rendering', () => { + it('should hide delete option when user is dataset operator', async () => { + const user = userEvent.setup() + // Arrange + mockIsDatasetOperator = true + render() + + // Act + await openMenu(user) + + // Assert + expect(screen.queryByText('common.operation.delete')).not.toBeInTheDocument() + }) + }) + + // User interactions that trigger modals and exports. + describe('Interactions', () => { + it('should open rename modal when edit is clicked', async () => { + const user = userEvent.setup() + // Arrange + render() + + // Act + await openMenu(user) + await user.click(screen.getByText('common.operation.edit')) + + // Assert + expect(screen.getByTestId('rename-modal')).toBeInTheDocument() + }) + + it('should export pipeline when export is clicked', async () => { + const user = userEvent.setup() + const anchorClickSpy = vi.spyOn(HTMLAnchorElement.prototype, 'click') + const createObjectURLSpy = vi.spyOn(URL, 'createObjectURL') + // Arrange + render() + + // Act + await openMenu(user) + await user.click(screen.getByText('datasetPipeline.operations.exportPipeline')) + + // Assert + await waitFor(() => { + expect(mockExportPipeline).toHaveBeenCalledWith({ + pipelineId: 'pipeline-1', + include: false, + }) + }) + expect(createObjectURLSpy).toHaveBeenCalledTimes(1) + expect(anchorClickSpy).toHaveBeenCalledTimes(1) + }) + + it('should show delete confirmation when delete is clicked', async () => { + const user = userEvent.setup() + // Arrange + render() + + // Act + await openMenu(user) + await user.click(screen.getByText('common.operation.delete')) + + // Assert + await waitFor(() => { + expect(screen.getByText('dataset.deleteDatasetConfirmContent')).toBeInTheDocument() + }) + }) + + it('should delete dataset and redirect when confirm is clicked', async () => { + const user = userEvent.setup() + // Arrange + render() + + // Act + await openMenu(user) + await user.click(screen.getByText('common.operation.delete')) + await user.click(await screen.findByRole('button', { name: 'common.operation.confirm' })) + + // Assert + await waitFor(() => { + expect(mockDeleteDataset).toHaveBeenCalledWith('dataset-1') + }) + expect(mockInvalidDatasetList).toHaveBeenCalledTimes(1) + expect(mockReplace).toHaveBeenCalledWith('/datasets') + }) + }) +}) diff --git a/web/app/components/app-sidebar/dataset-info/index.tsx b/web/app/components/app-sidebar/dataset-info/index.tsx index 44b0baa72b..bace656d54 100644 --- a/web/app/components/app-sidebar/dataset-info/index.tsx +++ b/web/app/components/app-sidebar/dataset-info/index.tsx @@ -8,7 +8,7 @@ import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' import type { DataSet } from '@/models/datasets' import { DOC_FORM_TEXT } from '@/models/datasets' import { useKnowledge } from '@/hooks/use-knowledge' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import Dropdown from './dropdown' type DatasetInfoProps = { diff --git a/web/app/components/app-sidebar/dataset-sidebar-dropdown.tsx b/web/app/components/app-sidebar/dataset-sidebar-dropdown.tsx index ac07333712..cf380d00d2 100644 --- a/web/app/components/app-sidebar/dataset-sidebar-dropdown.tsx +++ b/web/app/components/app-sidebar/dataset-sidebar-dropdown.tsx @@ -11,7 +11,7 @@ import AppIcon from '../base/app-icon' import Divider from '../base/divider' import NavLink from './navLink' import type { NavIcon } from './navLink' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' import Effect from '../base/effect' import Dropdown from './dataset-info/dropdown' diff --git a/web/app/components/app-sidebar/index.tsx b/web/app/components/app-sidebar/index.tsx index 86de2e2034..fe52c4cfa2 100644 --- a/web/app/components/app-sidebar/index.tsx +++ b/web/app/components/app-sidebar/index.tsx @@ -9,7 +9,7 @@ import AppSidebarDropdown from './app-sidebar-dropdown' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import { useStore as useAppStore } from '@/app/components/app/store' import { useEventEmitterContextContext } from '@/context/event-emitter' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import Divider from '../base/divider' import { useHover, useKeyPress } from 'ahooks' import ToggleButton from './toggle-button' diff --git a/web/app/components/app-sidebar/navLink.spec.tsx b/web/app/components/app-sidebar/navLink.spec.tsx index 51f62e669b..3a188eda68 100644 --- a/web/app/components/app-sidebar/navLink.spec.tsx +++ b/web/app/components/app-sidebar/navLink.spec.tsx @@ -1,24 +1,23 @@ import React from 'react' import { render, screen } from '@testing-library/react' -import '@testing-library/jest-dom' import NavLink from './navLink' import type { NavLinkProps } from './navLink' // Mock Next.js navigation -jest.mock('next/navigation', () => ({ +vi.mock('next/navigation', () => ({ useSelectedLayoutSegment: () => 'overview', })) // Mock Next.js Link component -jest.mock('next/link', () => { - return function MockLink({ children, href, className, title }: any) { +vi.mock('next/link', () => ({ + default: function MockLink({ children, href, className, title }: any) { return ( {children} ) - } -}) + }, +})) // Mock RemixIcon components const MockIcon = ({ className }: { className?: string }) => ( @@ -38,7 +37,7 @@ describe('NavLink Animation and Layout Issues', () => { beforeEach(() => { // Mock getComputedStyle for transition testing Object.defineProperty(window, 'getComputedStyle', { - value: jest.fn((element) => { + value: vi.fn((element) => { const isExpanded = element.getAttribute('data-mode') === 'expand' return { transition: 'all 0.3s ease', diff --git a/web/app/components/app-sidebar/navLink.tsx b/web/app/components/app-sidebar/navLink.tsx index ad90b91250..f6d8e57682 100644 --- a/web/app/components/app-sidebar/navLink.tsx +++ b/web/app/components/app-sidebar/navLink.tsx @@ -2,7 +2,7 @@ import React from 'react' import { useSelectedLayoutSegment } from 'next/navigation' import Link from 'next/link' -import classNames from '@/utils/classnames' +import { cn } from '@/utils/classnames' import type { RemixiconComponentType } from '@remixicon/react' export type NavIcon = React.ComponentType< @@ -42,7 +42,7 @@ const NavLink = ({ const NavIcon = isActive ? iconMap.selected : iconMap.normal const renderIcon = () => ( -
+
) @@ -53,21 +53,17 @@ const NavLink = ({ key={name} type='button' disabled - className={classNames( - 'system-sm-medium flex h-8 cursor-not-allowed items-center rounded-lg text-components-menu-item-text opacity-30 hover:bg-components-menu-item-bg-hover', - 'pl-3 pr-1', - )} + className={cn('system-sm-medium flex h-8 cursor-not-allowed items-center rounded-lg text-components-menu-item-text opacity-30 hover:bg-components-menu-item-bg-hover', + 'pl-3 pr-1')} title={mode === 'collapse' ? name : ''} aria-disabled > {renderIcon()} {name} @@ -79,22 +75,18 @@ const NavLink = ({ {renderIcon()} {name} diff --git a/web/app/components/app-sidebar/sidebar-animation-issues.spec.tsx b/web/app/components/app-sidebar/sidebar-animation-issues.spec.tsx index 54dde5fbd4..dd3b230e9b 100644 --- a/web/app/components/app-sidebar/sidebar-animation-issues.spec.tsx +++ b/web/app/components/app-sidebar/sidebar-animation-issues.spec.tsx @@ -1,6 +1,5 @@ import React from 'react' import { fireEvent, render, screen } from '@testing-library/react' -import '@testing-library/jest-dom' // Simple Mock Components that reproduce the exact UI issues const MockNavLink = ({ name, mode }: { name: string; mode: string }) => { @@ -108,7 +107,7 @@ const MockAppInfo = ({ expand }: { expand: boolean }) => { describe('Sidebar Animation Issues Reproduction', () => { beforeEach(() => { // Mock getBoundingClientRect for position testing - Element.prototype.getBoundingClientRect = jest.fn(() => ({ + Element.prototype.getBoundingClientRect = vi.fn(() => ({ width: 200, height: 40, x: 10, @@ -117,7 +116,7 @@ describe('Sidebar Animation Issues Reproduction', () => { right: 210, top: 10, bottom: 50, - toJSON: jest.fn(), + toJSON: vi.fn(), })) }) @@ -152,7 +151,7 @@ describe('Sidebar Animation Issues Reproduction', () => { }) it('should verify sidebar width animation is working correctly', () => { - const handleToggle = jest.fn() + const handleToggle = vi.fn() const { rerender } = render() const container = screen.getByTestId('sidebar-container') diff --git a/web/app/components/app-sidebar/text-squeeze-fix-verification.spec.tsx b/web/app/components/app-sidebar/text-squeeze-fix-verification.spec.tsx index 1612606e9d..c28ba26d30 100644 --- a/web/app/components/app-sidebar/text-squeeze-fix-verification.spec.tsx +++ b/web/app/components/app-sidebar/text-squeeze-fix-verification.spec.tsx @@ -5,15 +5,14 @@ import React from 'react' import { render } from '@testing-library/react' -import '@testing-library/jest-dom' // Mock Next.js navigation -jest.mock('next/navigation', () => ({ +vi.mock('next/navigation', () => ({ useSelectedLayoutSegment: () => 'overview', })) // Mock classnames utility -jest.mock('@/utils/classnames', () => ({ +vi.mock('@/utils/classnames', () => ({ __esModule: true, default: (...classes: any[]) => classes.filter(Boolean).join(' '), })) diff --git a/web/app/components/app-sidebar/toggle-button.tsx b/web/app/components/app-sidebar/toggle-button.tsx index 8de6f887f6..4f69adfc34 100644 --- a/web/app/components/app-sidebar/toggle-button.tsx +++ b/web/app/components/app-sidebar/toggle-button.tsx @@ -1,7 +1,7 @@ import React from 'react' import Button from '../base/button' import { RiArrowLeftSLine, RiArrowRightSLine } from '@remixicon/react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import Tooltip from '../base/tooltip' import { useTranslation } from 'react-i18next' import { getKeyboardKeyNameBySystem } from '../workflow/utils' diff --git a/web/app/components/app/annotation/add-annotation-modal/edit-item/index.spec.tsx b/web/app/components/app/annotation/add-annotation-modal/edit-item/index.spec.tsx index 356f813afc..1cbf5d1738 100644 --- a/web/app/components/app/annotation/add-annotation-modal/edit-item/index.spec.tsx +++ b/web/app/components/app/annotation/add-annotation-modal/edit-item/index.spec.tsx @@ -2,19 +2,13 @@ import React from 'react' import { fireEvent, render, screen } from '@testing-library/react' import EditItem, { EditItemType } from './index' -jest.mock('react-i18next', () => ({ - useTranslation: () => ({ - t: (key: string) => key, - }), -})) - describe('AddAnnotationModal/EditItem', () => { test('should render query inputs with user avatar and placeholder strings', () => { render( , ) @@ -28,7 +22,7 @@ describe('AddAnnotationModal/EditItem', () => { , ) @@ -38,7 +32,7 @@ describe('AddAnnotationModal/EditItem', () => { }) test('should propagate changes when answer content updates', () => { - const handleChange = jest.fn() + const handleChange = vi.fn() render( ({ - useProviderContext: jest.fn(), +vi.mock('@/context/provider-context', () => ({ + useProviderContext: vi.fn(), })) -const mockToastNotify = jest.fn() -jest.mock('@/app/components/base/toast', () => ({ +const mockToastNotify = vi.fn() +vi.mock('@/app/components/base/toast', () => ({ __esModule: true, default: { - notify: jest.fn(args => mockToastNotify(args)), + notify: vi.fn(args => mockToastNotify(args)), }, })) -jest.mock('@/app/components/billing/annotation-full', () => () =>
) +vi.mock('@/app/components/billing/annotation-full', () => ({ + default: () =>
, +})) -const mockUseProviderContext = useProviderContext as jest.Mock +const mockUseProviderContext = useProviderContext as Mock const getProviderContext = ({ usage = 0, total = 10, enableBilling = false } = {}) => ({ plan: { @@ -30,12 +33,12 @@ const getProviderContext = ({ usage = 0, total = 10, enableBilling = false } = { describe('AddAnnotationModal', () => { const baseProps = { isShow: true, - onHide: jest.fn(), - onAdd: jest.fn(), + onHide: vi.fn(), + onAdd: vi.fn(), } beforeEach(() => { - jest.clearAllMocks() + vi.clearAllMocks() mockUseProviderContext.mockReturnValue(getProviderContext()) }) @@ -78,7 +81,7 @@ describe('AddAnnotationModal', () => { }) test('should call onAdd with form values when create next enabled', async () => { - const onAdd = jest.fn().mockResolvedValue(undefined) + const onAdd = vi.fn().mockResolvedValue(undefined) render() typeQuestion('Question value') @@ -93,7 +96,7 @@ describe('AddAnnotationModal', () => { }) test('should reset fields after saving when create next enabled', async () => { - const onAdd = jest.fn().mockResolvedValue(undefined) + const onAdd = vi.fn().mockResolvedValue(undefined) render() typeQuestion('Question value') @@ -133,7 +136,7 @@ describe('AddAnnotationModal', () => { }) test('should close modal when save completes and create next unchecked', async () => { - const onAdd = jest.fn().mockResolvedValue(undefined) + const onAdd = vi.fn().mockResolvedValue(undefined) render() typeQuestion('Q') diff --git a/web/app/components/app/annotation/batch-action.spec.tsx b/web/app/components/app/annotation/batch-action.spec.tsx new file mode 100644 index 0000000000..70765f6a32 --- /dev/null +++ b/web/app/components/app/annotation/batch-action.spec.tsx @@ -0,0 +1,42 @@ +import React from 'react' +import { act, fireEvent, render, screen, waitFor } from '@testing-library/react' +import BatchAction from './batch-action' + +describe('BatchAction', () => { + const baseProps = { + selectedIds: ['1', '2', '3'], + onBatchDelete: vi.fn(), + onCancel: vi.fn(), + } + + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should show the selected count and trigger cancel action', () => { + render() + + expect(screen.getByText('3')).toBeInTheDocument() + expect(screen.getByText('appAnnotation.batchAction.selected')).toBeInTheDocument() + + fireEvent.click(screen.getByRole('button', { name: 'common.operation.cancel' })) + + expect(baseProps.onCancel).toHaveBeenCalledTimes(1) + }) + + it('should confirm before running batch delete', async () => { + const onBatchDelete = vi.fn().mockResolvedValue(undefined) + render() + + fireEvent.click(screen.getByRole('button', { name: 'common.operation.delete' })) + await screen.findByText('appAnnotation.list.delete.title') + + await act(async () => { + fireEvent.click(screen.getAllByRole('button', { name: 'common.operation.delete' })[1]) + }) + + await waitFor(() => { + expect(onBatchDelete).toHaveBeenCalledTimes(1) + }) + }) +}) diff --git a/web/app/components/app/annotation/batch-action.tsx b/web/app/components/app/annotation/batch-action.tsx index 6e80d0c4c8..6ff392d17e 100644 --- a/web/app/components/app/annotation/batch-action.tsx +++ b/web/app/components/app/annotation/batch-action.tsx @@ -3,7 +3,7 @@ import { RiDeleteBinLine } from '@remixicon/react' import { useTranslation } from 'react-i18next' import { useBoolean } from 'ahooks' import Divider from '@/app/components/base/divider' -import classNames from '@/utils/classnames' +import { cn } from '@/utils/classnames' import Confirm from '@/app/components/base/confirm' const i18nPrefix = 'appAnnotation.batchAction' @@ -38,7 +38,7 @@ const BatchAction: FC = ({ setIsNotDeleting() } return ( -
+
diff --git a/web/app/components/app/annotation/batch-add-annotation-modal/csv-downloader.spec.tsx b/web/app/components/app/annotation/batch-add-annotation-modal/csv-downloader.spec.tsx new file mode 100644 index 0000000000..eeeed8dcb4 --- /dev/null +++ b/web/app/components/app/annotation/batch-add-annotation-modal/csv-downloader.spec.tsx @@ -0,0 +1,72 @@ +import React from 'react' +import { render, screen } from '@testing-library/react' +import CSVDownload from './csv-downloader' +import I18nContext from '@/context/i18n' +import { LanguagesSupported } from '@/i18n-config/language' +import type { Locale } from '@/i18n-config' + +const downloaderProps: any[] = [] + +vi.mock('react-papaparse', () => ({ + useCSVDownloader: vi.fn(() => ({ + CSVDownloader: ({ children, ...props }: any) => { + downloaderProps.push(props) + return
{children}
+ }, + Type: { Link: 'link' }, + })), +})) + +const renderWithLocale = (locale: Locale) => { + return render( + + + , + ) +} + +describe('CSVDownload', () => { + const englishTemplate = [ + ['question', 'answer'], + ['question1', 'answer1'], + ['question2', 'answer2'], + ] + const chineseTemplate = [ + ['问题', '答案'], + ['问题 1', '答案 1'], + ['问题 2', '答案 2'], + ] + + beforeEach(() => { + downloaderProps.length = 0 + }) + + it('should render the structure preview and pass English template data by default', () => { + renderWithLocale('en-US' as Locale) + + expect(screen.getByText('share.generation.csvStructureTitle')).toBeInTheDocument() + expect(screen.getByText('appAnnotation.batchModal.template')).toBeInTheDocument() + + expect(downloaderProps[0]).toMatchObject({ + filename: 'template-en-US', + type: 'link', + bom: true, + data: englishTemplate, + }) + }) + + it('should switch to the Chinese template when locale matches the secondary language', () => { + const locale = LanguagesSupported[1] as Locale + renderWithLocale(locale) + + expect(downloaderProps[0]).toMatchObject({ + filename: `template-${locale}`, + data: chineseTemplate, + }) + }) +}) diff --git a/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.spec.tsx b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.spec.tsx index d94295c31c..041cd7ec71 100644 --- a/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.spec.tsx +++ b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.spec.tsx @@ -4,8 +4,8 @@ import CSVUploader, { type Props } from './csv-uploader' import { ToastContext } from '@/app/components/base/toast' describe('CSVUploader', () => { - const notify = jest.fn() - const updateFile = jest.fn() + const notify = vi.fn() + const updateFile = vi.fn() const getDropElements = () => { const title = screen.getByText('appAnnotation.batchModal.csvUploadTitle') @@ -23,18 +23,18 @@ describe('CSVUploader', () => { ...props, } return render( - + , ) } beforeEach(() => { - jest.clearAllMocks() + vi.clearAllMocks() }) it('should open the file picker when clicking browse', () => { - const clickSpy = jest.spyOn(HTMLInputElement.prototype, 'click') + const clickSpy = vi.spyOn(HTMLInputElement.prototype, 'click') renderComponent() fireEvent.click(screen.getByText('appAnnotation.batchModal.browse')) @@ -100,12 +100,12 @@ describe('CSVUploader', () => { expect(screen.getByText('report')).toBeInTheDocument() expect(screen.getByText('.csv')).toBeInTheDocument() - const clickSpy = jest.spyOn(HTMLInputElement.prototype, 'click') + const clickSpy = vi.spyOn(HTMLInputElement.prototype, 'click') fireEvent.click(screen.getByText('datasetCreation.stepOne.uploader.change')) expect(clickSpy).toHaveBeenCalled() clickSpy.mockRestore() - const valueSetter = jest.spyOn(fileInput, 'value', 'set') + const valueSetter = vi.spyOn(fileInput, 'value', 'set') const removeTrigger = screen.getByTestId('remove-file-button') fireEvent.click(removeTrigger) diff --git a/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx index ccad46b860..c9766135df 100644 --- a/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx +++ b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx @@ -4,7 +4,7 @@ import React, { useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import { RiDeleteBinLine } from '@remixicon/react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { Csv as CSVIcon } from '@/app/components/base/icons/src/public/files' import { ToastContext } from '@/app/components/base/toast' import Button from '@/app/components/base/button' diff --git a/web/app/components/app/annotation/batch-add-annotation-modal/index.spec.tsx b/web/app/components/app/annotation/batch-add-annotation-modal/index.spec.tsx new file mode 100644 index 0000000000..3d0e799801 --- /dev/null +++ b/web/app/components/app/annotation/batch-add-annotation-modal/index.spec.tsx @@ -0,0 +1,165 @@ +import React from 'react' +import { act, fireEvent, render, screen, waitFor } from '@testing-library/react' +import BatchModal, { ProcessStatus } from './index' +import { useProviderContext } from '@/context/provider-context' +import { annotationBatchImport, checkAnnotationBatchImportProgress } from '@/service/annotation' +import type { IBatchModalProps } from './index' +import Toast from '@/app/components/base/toast' +import type { Mock } from 'vitest' + +vi.mock('@/app/components/base/toast', () => ({ + __esModule: true, + default: { + notify: vi.fn(), + }, +})) + +vi.mock('@/service/annotation', () => ({ + annotationBatchImport: vi.fn(), + checkAnnotationBatchImportProgress: vi.fn(), +})) + +vi.mock('@/context/provider-context', () => ({ + useProviderContext: vi.fn(), +})) + +vi.mock('./csv-downloader', () => ({ + __esModule: true, + default: () =>
, +})) + +let lastUploadedFile: File | undefined + +vi.mock('./csv-uploader', () => ({ + __esModule: true, + default: ({ file, updateFile }: { file?: File; updateFile: (file?: File) => void }) => ( +
+ + {file && {file.name}} +
+ ), +})) + +vi.mock('@/app/components/billing/annotation-full', () => ({ + __esModule: true, + default: () =>
, +})) + +const mockNotify = Toast.notify as Mock +const useProviderContextMock = useProviderContext as Mock +const annotationBatchImportMock = annotationBatchImport as Mock +const checkAnnotationBatchImportProgressMock = checkAnnotationBatchImportProgress as Mock + +const renderComponent = (props: Partial = {}) => { + const mergedProps: IBatchModalProps = { + appId: 'app-id', + isShow: true, + onCancel: vi.fn(), + onAdded: vi.fn(), + ...props, + } + return { + ...render(), + props: mergedProps, + } +} + +describe('BatchModal', () => { + beforeEach(() => { + vi.clearAllMocks() + lastUploadedFile = undefined + useProviderContextMock.mockReturnValue({ + plan: { + usage: { annotatedResponse: 0 }, + total: { annotatedResponse: 10 }, + }, + enableBilling: false, + }) + }) + + it('should disable run action and show billing hint when annotation quota is full', () => { + useProviderContextMock.mockReturnValue({ + plan: { + usage: { annotatedResponse: 10 }, + total: { annotatedResponse: 10 }, + }, + enableBilling: true, + }) + + renderComponent() + + expect(screen.getByTestId('annotation-full')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'appAnnotation.batchModal.run' })).toBeDisabled() + }) + + it('should reset uploader state when modal closes and allow manual cancellation', () => { + const { rerender, props } = renderComponent() + + fireEvent.click(screen.getByTestId('mock-uploader')) + expect(screen.getByTestId('selected-file')).toHaveTextContent('batch.csv') + + rerender() + rerender() + + expect(screen.queryByTestId('selected-file')).toBeNull() + + fireEvent.click(screen.getByRole('button', { name: 'appAnnotation.batchModal.cancel' })) + expect(props.onCancel).toHaveBeenCalledTimes(1) + }) + + it('should submit the csv file, poll status, and notify when import completes', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + const { props } = renderComponent() + const fileTrigger = screen.getByTestId('mock-uploader') + fireEvent.click(fileTrigger) + + const runButton = screen.getByRole('button', { name: 'appAnnotation.batchModal.run' }) + expect(runButton).not.toBeDisabled() + + annotationBatchImportMock.mockResolvedValue({ job_id: 'job-1', job_status: ProcessStatus.PROCESSING }) + checkAnnotationBatchImportProgressMock + .mockResolvedValueOnce({ job_id: 'job-1', job_status: ProcessStatus.PROCESSING }) + .mockResolvedValueOnce({ job_id: 'job-1', job_status: ProcessStatus.COMPLETED }) + + await act(async () => { + fireEvent.click(runButton) + }) + + await waitFor(() => { + expect(annotationBatchImportMock).toHaveBeenCalledTimes(1) + }) + + const formData = annotationBatchImportMock.mock.calls[0][0].body as FormData + expect(formData.get('file')).toBe(lastUploadedFile) + + await waitFor(() => { + expect(checkAnnotationBatchImportProgressMock).toHaveBeenCalledTimes(1) + }) + + await act(async () => { + vi.runOnlyPendingTimers() + }) + + await waitFor(() => { + expect(checkAnnotationBatchImportProgressMock).toHaveBeenCalledTimes(2) + }) + + await waitFor(() => { + expect(mockNotify).toHaveBeenCalledWith({ + type: 'success', + message: 'appAnnotation.batchModal.completed', + }) + expect(props.onAdded).toHaveBeenCalledTimes(1) + expect(props.onCancel).toHaveBeenCalledTimes(1) + }) + vi.useRealTimers() + }) +}) diff --git a/web/app/components/app/annotation/clear-all-annotations-confirm-modal/index.spec.tsx b/web/app/components/app/annotation/clear-all-annotations-confirm-modal/index.spec.tsx index fd6d900aa4..8722f682eb 100644 --- a/web/app/components/app/annotation/clear-all-annotations-confirm-modal/index.spec.tsx +++ b/web/app/components/app/annotation/clear-all-annotations-confirm-modal/index.spec.tsx @@ -2,7 +2,7 @@ import React from 'react' import { fireEvent, render, screen } from '@testing-library/react' import ClearAllAnnotationsConfirmModal from './index' -jest.mock('react-i18next', () => ({ +vi.mock('react-i18next', () => ({ useTranslation: () => ({ t: (key: string) => { const translations: Record = { @@ -16,7 +16,7 @@ jest.mock('react-i18next', () => ({ })) beforeEach(() => { - jest.clearAllMocks() + vi.clearAllMocks() }) describe('ClearAllAnnotationsConfirmModal', () => { @@ -27,8 +27,8 @@ describe('ClearAllAnnotationsConfirmModal', () => { render( , ) @@ -43,8 +43,8 @@ describe('ClearAllAnnotationsConfirmModal', () => { render( , ) @@ -56,8 +56,8 @@ describe('ClearAllAnnotationsConfirmModal', () => { // User confirms or cancels clearing annotations describe('Interactions', () => { test('should trigger onHide when cancel is clicked', () => { - const onHide = jest.fn() - const onConfirm = jest.fn() + const onHide = vi.fn() + const onConfirm = vi.fn() // Arrange render( { }) test('should trigger onConfirm when confirm is clicked', () => { - const onHide = jest.fn() - const onConfirm = jest.fn() + const onHide = vi.fn() + const onConfirm = vi.fn() // Arrange render( { const defaultProps = { type: EditItemType.Query, content: 'Test content', - onSave: jest.fn(), + onSave: vi.fn(), } beforeEach(() => { - jest.clearAllMocks() + vi.clearAllMocks() }) // Rendering tests (REQUIRED) @@ -167,7 +167,7 @@ describe('EditItem', () => { it('should save new content when save button is clicked', async () => { // Arrange - const mockSave = jest.fn().mockResolvedValue(undefined) + const mockSave = vi.fn().mockResolvedValue(undefined) const props = { ...defaultProps, onSave: mockSave, @@ -223,7 +223,7 @@ describe('EditItem', () => { it('should call onSave with correct content when saving', async () => { // Arrange - const mockSave = jest.fn().mockResolvedValue(undefined) + const mockSave = vi.fn().mockResolvedValue(undefined) const props = { ...defaultProps, onSave: mockSave, @@ -245,9 +245,9 @@ describe('EditItem', () => { expect(mockSave).toHaveBeenCalledWith('Test save content') }) - it('should show delete option when content changes', async () => { + it('should show delete option and restore original content when delete is clicked', async () => { // Arrange - const mockSave = jest.fn().mockResolvedValue(undefined) + const mockSave = vi.fn().mockResolvedValue(undefined) const props = { ...defaultProps, onSave: mockSave, @@ -267,7 +267,13 @@ describe('EditItem', () => { await user.click(screen.getByRole('button', { name: 'common.operation.save' })) // Assert - expect(mockSave).toHaveBeenCalledWith('Modified content') + expect(mockSave).toHaveBeenNthCalledWith(1, 'Modified content') + expect(await screen.findByText('common.operation.delete')).toBeInTheDocument() + + await user.click(screen.getByText('common.operation.delete')) + + expect(mockSave).toHaveBeenNthCalledWith(2, 'Test content') + expect(screen.queryByText('common.operation.delete')).not.toBeInTheDocument() }) it('should handle keyboard interactions in edit mode', async () => { @@ -393,5 +399,68 @@ describe('EditItem', () => { expect(screen.queryByRole('textbox')).not.toBeInTheDocument() expect(screen.getByText('Test content')).toBeInTheDocument() }) + + it('should handle save failure gracefully in edit mode', async () => { + // Arrange + const mockSave = vi.fn().mockRejectedValueOnce(new Error('Save failed')) + const props = { + ...defaultProps, + onSave: mockSave, + } + const user = userEvent.setup() + + // Act + render() + + // Enter edit mode and save (should fail) + await user.click(screen.getByText('common.operation.edit')) + const textarea = screen.getByRole('textbox') + await user.type(textarea, 'New content') + + // Save should fail but not throw + await user.click(screen.getByRole('button', { name: 'common.operation.save' })) + + // Assert - Should remain in edit mode when save fails + expect(screen.getByRole('textbox')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeInTheDocument() + expect(mockSave).toHaveBeenCalledWith('New content') + }) + + it('should handle delete action failure gracefully', async () => { + // Arrange + const mockSave = vi.fn() + .mockResolvedValueOnce(undefined) // First save succeeds + .mockRejectedValueOnce(new Error('Delete failed')) // Delete fails + const props = { + ...defaultProps, + onSave: mockSave, + } + const user = userEvent.setup() + + // Act + render() + + // Edit content to show delete button + await user.click(screen.getByText('common.operation.edit')) + const textarea = screen.getByRole('textbox') + await user.clear(textarea) + await user.type(textarea, 'Modified content') + + // Save to create new content + await user.click(screen.getByRole('button', { name: 'common.operation.save' })) + await screen.findByText('common.operation.delete') + + // Click delete (should fail but not throw) + await user.click(screen.getByText('common.operation.delete')) + + // Assert - Delete action should handle error gracefully + expect(mockSave).toHaveBeenCalledTimes(2) + expect(mockSave).toHaveBeenNthCalledWith(1, 'Modified content') + expect(mockSave).toHaveBeenNthCalledWith(2, 'Test content') + + // When delete fails, the delete button should still be visible (state not changed) + expect(screen.getByText('common.operation.delete')).toBeInTheDocument() + expect(screen.getByText('Modified content')).toBeInTheDocument() + }) }) }) diff --git a/web/app/components/app/annotation/edit-annotation-modal/edit-item/index.tsx b/web/app/components/app/annotation/edit-annotation-modal/edit-item/index.tsx index e808d0b48a..6ba830967d 100644 --- a/web/app/components/app/annotation/edit-annotation-modal/edit-item/index.tsx +++ b/web/app/components/app/annotation/edit-annotation-modal/edit-item/index.tsx @@ -6,7 +6,7 @@ import { RiDeleteBinLine, RiEditFill, RiEditLine } from '@remixicon/react' import { Robot, User } from '@/app/components/base/icons/src/public/avatar' import Textarea from '@/app/components/base/textarea' import Button from '@/app/components/base/button' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' export enum EditItemType { Query = 'query', @@ -52,8 +52,14 @@ const EditItem: FC = ({ }, [content]) const handleSave = async () => { - await onSave(newContent) - setIsEdit(false) + try { + await onSave(newContent) + setIsEdit(false) + } + catch { + // Keep edit mode open when save fails + // Error notification is handled by the parent component + } } const handleCancel = () => { @@ -96,9 +102,16 @@ const EditItem: FC = ({
·
{ - setNewContent(content) - onSave(content) + onClick={async () => { + try { + await onSave(content) + // Only update UI state after successful delete + setNewContent(content) + } + catch { + // Delete action failed - error is already handled by parent + // UI state remains unchanged, user can retry + } }} >
diff --git a/web/app/components/app/annotation/edit-annotation-modal/index.spec.tsx b/web/app/components/app/annotation/edit-annotation-modal/index.spec.tsx index a2e2527605..e4e9f23505 100644 --- a/web/app/components/app/annotation/edit-annotation-modal/index.spec.tsx +++ b/web/app/components/app/annotation/edit-annotation-modal/index.spec.tsx @@ -1,15 +1,20 @@ -import { render, screen } from '@testing-library/react' +import { render, screen, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' import Toast, { type IToastProps, type ToastHandle } from '@/app/components/base/toast' import EditAnnotationModal from './index' -// Mock only external dependencies -jest.mock('@/service/annotation', () => ({ - addAnnotation: jest.fn(), - editAnnotation: jest.fn(), +const { mockAddAnnotation, mockEditAnnotation } = vi.hoisted(() => ({ + mockAddAnnotation: vi.fn(), + mockEditAnnotation: vi.fn(), })) -jest.mock('@/context/provider-context', () => ({ +// Mock only external dependencies +vi.mock('@/service/annotation', () => ({ + addAnnotation: mockAddAnnotation, + editAnnotation: mockEditAnnotation, +})) + +vi.mock('@/context/provider-context', () => ({ useProviderContext: () => ({ plan: { usage: { annotatedResponse: 5 }, @@ -19,16 +24,16 @@ jest.mock('@/context/provider-context', () => ({ }), })) -jest.mock('@/hooks/use-timestamp', () => ({ +vi.mock('@/hooks/use-timestamp', () => ({ __esModule: true, default: () => ({ formatTime: () => '2023-12-01 10:30:00', }), })) -// Note: i18n is automatically mocked by Jest via __mocks__/react-i18next.ts +// Note: i18n is automatically mocked by Vitest via web/vitest.setup.ts -jest.mock('@/app/components/billing/annotation-full', () => ({ +vi.mock('@/app/components/billing/annotation-full', () => ({ __esModule: true, default: () =>
, })) @@ -36,23 +41,18 @@ jest.mock('@/app/components/billing/annotation-full', () => ({ type ToastNotifyProps = Pick type ToastWithNotify = typeof Toast & { notify: (props: ToastNotifyProps) => ToastHandle } const toastWithNotify = Toast as unknown as ToastWithNotify -const toastNotifySpy = jest.spyOn(toastWithNotify, 'notify').mockReturnValue({ clear: jest.fn() }) - -const { addAnnotation: mockAddAnnotation, editAnnotation: mockEditAnnotation } = jest.requireMock('@/service/annotation') as { - addAnnotation: jest.Mock - editAnnotation: jest.Mock -} +const toastNotifySpy = vi.spyOn(toastWithNotify, 'notify').mockReturnValue({ clear: vi.fn() }) describe('EditAnnotationModal', () => { const defaultProps = { isShow: true, - onHide: jest.fn(), + onHide: vi.fn(), appId: 'test-app-id', query: 'Test query', answer: 'Test answer', - onEdited: jest.fn(), - onAdded: jest.fn(), - onRemove: jest.fn(), + onEdited: vi.fn(), + onAdded: vi.fn(), + onRemove: vi.fn(), } afterAll(() => { @@ -60,7 +60,7 @@ describe('EditAnnotationModal', () => { }) beforeEach(() => { - jest.clearAllMocks() + vi.clearAllMocks() mockAddAnnotation.mockResolvedValue({ id: 'test-id', account: { name: 'Test User' }, @@ -168,7 +168,7 @@ describe('EditAnnotationModal', () => { it('should save content when edited', async () => { // Arrange - const mockOnAdded = jest.fn() + const mockOnAdded = vi.fn() const props = { ...defaultProps, onAdded: mockOnAdded, @@ -210,7 +210,7 @@ describe('EditAnnotationModal', () => { describe('API Calls', () => { it('should call addAnnotation when saving new annotation', async () => { // Arrange - const mockOnAdded = jest.fn() + const mockOnAdded = vi.fn() const props = { ...defaultProps, onAdded: mockOnAdded, @@ -247,7 +247,7 @@ describe('EditAnnotationModal', () => { it('should call editAnnotation when updating existing annotation', async () => { // Arrange - const mockOnEdited = jest.fn() + const mockOnEdited = vi.fn() const props = { ...defaultProps, annotationId: 'test-annotation-id', @@ -314,7 +314,7 @@ describe('EditAnnotationModal', () => { it('should call onRemove when removal is confirmed', async () => { // Arrange - const mockOnRemove = jest.fn() + const mockOnRemove = vi.fn() const props = { ...defaultProps, annotationId: 'test-annotation-id', @@ -405,4 +405,276 @@ describe('EditAnnotationModal', () => { expect(editLinks).toHaveLength(1) // Only answer should have edit button }) }) + + // Error Handling (CRITICAL for coverage) + describe('Error Handling', () => { + it('should show error toast and skip callbacks when addAnnotation fails', async () => { + // Arrange + const mockOnAdded = vi.fn() + const props = { + ...defaultProps, + onAdded: mockOnAdded, + } + const user = userEvent.setup() + + // Mock API failure + mockAddAnnotation.mockRejectedValueOnce(new Error('API Error')) + + // Act + render() + + // Find and click edit link for query + const editLinks = screen.getAllByText(/common\.operation\.edit/i) + await user.click(editLinks[0]) + + // Find textarea and enter new content + const textarea = screen.getByRole('textbox') + await user.clear(textarea) + await user.type(textarea, 'New query content') + + // Click save button + const saveButton = screen.getByRole('button', { name: 'common.operation.save' }) + await user.click(saveButton) + + // Assert + await waitFor(() => { + expect(toastNotifySpy).toHaveBeenCalledWith({ + message: 'API Error', + type: 'error', + }) + }) + expect(mockOnAdded).not.toHaveBeenCalled() + + // Verify edit mode remains open (textarea should still be visible) + expect(screen.getByRole('textbox')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeInTheDocument() + }) + + it('should show fallback error message when addAnnotation error has no message', async () => { + // Arrange + const mockOnAdded = vi.fn() + const props = { + ...defaultProps, + onAdded: mockOnAdded, + } + const user = userEvent.setup() + + mockAddAnnotation.mockRejectedValueOnce({}) + + // Act + render() + + const editLinks = screen.getAllByText(/common\.operation\.edit/i) + await user.click(editLinks[0]) + + const textarea = screen.getByRole('textbox') + await user.clear(textarea) + await user.type(textarea, 'New query content') + + const saveButton = screen.getByRole('button', { name: 'common.operation.save' }) + await user.click(saveButton) + + // Assert + await waitFor(() => { + expect(toastNotifySpy).toHaveBeenCalledWith({ + message: 'common.api.actionFailed', + type: 'error', + }) + }) + expect(mockOnAdded).not.toHaveBeenCalled() + + // Verify edit mode remains open (textarea should still be visible) + expect(screen.getByRole('textbox')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeInTheDocument() + }) + + it('should show error toast and skip callbacks when editAnnotation fails', async () => { + // Arrange + const mockOnEdited = vi.fn() + const props = { + ...defaultProps, + annotationId: 'test-annotation-id', + messageId: 'test-message-id', + onEdited: mockOnEdited, + } + const user = userEvent.setup() + + // Mock API failure + mockEditAnnotation.mockRejectedValueOnce(new Error('API Error')) + + // Act + render() + + // Edit query content + const editLinks = screen.getAllByText(/common\.operation\.edit/i) + await user.click(editLinks[0]) + + const textarea = screen.getByRole('textbox') + await user.clear(textarea) + await user.type(textarea, 'Modified query') + + const saveButton = screen.getByRole('button', { name: 'common.operation.save' }) + await user.click(saveButton) + + // Assert + await waitFor(() => { + expect(toastNotifySpy).toHaveBeenCalledWith({ + message: 'API Error', + type: 'error', + }) + }) + expect(mockOnEdited).not.toHaveBeenCalled() + + // Verify edit mode remains open (textarea should still be visible) + expect(screen.getByRole('textbox')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeInTheDocument() + }) + + it('should show fallback error message when editAnnotation error is not an Error instance', async () => { + // Arrange + const mockOnEdited = vi.fn() + const props = { + ...defaultProps, + annotationId: 'test-annotation-id', + messageId: 'test-message-id', + onEdited: mockOnEdited, + } + const user = userEvent.setup() + + mockEditAnnotation.mockRejectedValueOnce('oops') + + // Act + render() + + const editLinks = screen.getAllByText(/common\.operation\.edit/i) + await user.click(editLinks[0]) + + const textarea = screen.getByRole('textbox') + await user.clear(textarea) + await user.type(textarea, 'Modified query') + + const saveButton = screen.getByRole('button', { name: 'common.operation.save' }) + await user.click(saveButton) + + // Assert + await waitFor(() => { + expect(toastNotifySpy).toHaveBeenCalledWith({ + message: 'common.api.actionFailed', + type: 'error', + }) + }) + expect(mockOnEdited).not.toHaveBeenCalled() + + // Verify edit mode remains open (textarea should still be visible) + expect(screen.getByRole('textbox')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeInTheDocument() + }) + }) + + // Billing & Plan Features + describe('Billing & Plan Features', () => { + it('should show createdAt time when provided', () => { + // Arrange + const props = { + ...defaultProps, + annotationId: 'test-annotation-id', + createdAt: 1701381000, // 2023-12-01 10:30:00 + } + + // Act + render() + + // Assert - Check that the formatted time appears somewhere in the component + const container = screen.getByRole('dialog') + expect(container).toHaveTextContent('2023-12-01 10:30:00') + }) + + it('should not show createdAt when not provided', () => { + // Arrange + const props = { + ...defaultProps, + annotationId: 'test-annotation-id', + // createdAt is undefined + } + + // Act + render() + + // Assert - Should not contain any timestamp + const container = screen.getByRole('dialog') + expect(container).not.toHaveTextContent('2023-12-01 10:30:00') + }) + + it('should display remove section when annotationId exists', () => { + // Arrange + const props = { + ...defaultProps, + annotationId: 'test-annotation-id', + } + + // Act + render() + + // Assert - Should have remove functionality + expect(screen.getByText('appAnnotation.editModal.removeThisCache')).toBeInTheDocument() + }) + }) + + // Toast Notifications (Success) + describe('Toast Notifications', () => { + it('should show success notification when save operation completes', async () => { + // Arrange + const props = { ...defaultProps } + const user = userEvent.setup() + + // Act + render() + + const editLinks = screen.getAllByText(/common\.operation\.edit/i) + await user.click(editLinks[0]) + + const textarea = screen.getByRole('textbox') + await user.clear(textarea) + await user.type(textarea, 'Updated query') + + const saveButton = screen.getByRole('button', { name: 'common.operation.save' }) + await user.click(saveButton) + + // Assert + await waitFor(() => { + expect(toastNotifySpy).toHaveBeenCalledWith({ + message: 'common.api.actionSuccess', + type: 'success', + }) + }) + }) + }) + + // React.memo Performance Testing + describe('React.memo Performance', () => { + it('should not re-render when props are the same', () => { + // Arrange + const props = { ...defaultProps } + const { rerender } = render() + + // Act - Re-render with same props + rerender() + + // Assert - Component should still be visible (no errors thrown) + expect(screen.getByText('appAnnotation.editModal.title')).toBeInTheDocument() + }) + + it('should re-render when props change', () => { + // Arrange + const props = { ...defaultProps } + const { rerender } = render() + + // Act - Re-render with different props + const newProps = { ...props, query: 'New query content' } + rerender() + + // Assert - Should show new content + expect(screen.getByText('New query content')).toBeInTheDocument() + }) + }) }) diff --git a/web/app/components/app/annotation/edit-annotation-modal/index.tsx b/web/app/components/app/annotation/edit-annotation-modal/index.tsx index 2961ce393c..6172a215e4 100644 --- a/web/app/components/app/annotation/edit-annotation-modal/index.tsx +++ b/web/app/components/app/annotation/edit-annotation-modal/index.tsx @@ -53,27 +53,39 @@ const EditAnnotationModal: FC = ({ postQuery = editedContent else postAnswer = editedContent - if (!isAdd) { - await editAnnotation(appId, annotationId, { - message_id: messageId, - question: postQuery, - answer: postAnswer, - }) - onEdited(postQuery, postAnswer) - } - else { - const res: any = await addAnnotation(appId, { - question: postQuery, - answer: postAnswer, - message_id: messageId, - }) - onAdded(res.id, res.account?.name, postQuery, postAnswer) - } + try { + if (!isAdd) { + await editAnnotation(appId, annotationId, { + message_id: messageId, + question: postQuery, + answer: postAnswer, + }) + onEdited(postQuery, postAnswer) + } + else { + const res = await addAnnotation(appId, { + question: postQuery, + answer: postAnswer, + message_id: messageId, + }) + onAdded(res.id, res.account?.name ?? '', postQuery, postAnswer) + } - Toast.notify({ - message: t('common.api.actionSuccess') as string, - type: 'success', - }) + Toast.notify({ + message: t('common.api.actionSuccess') as string, + type: 'success', + }) + } + catch (error) { + const fallbackMessage = t('common.api.actionFailed') as string + const message = error instanceof Error && error.message ? error.message : fallbackMessage + Toast.notify({ + message, + type: 'error', + }) + // Re-throw to preserve edit mode behavior for UI components + throw error + } } const [showModal, setShowModal] = useState(false) diff --git a/web/app/components/app/annotation/empty-element.spec.tsx b/web/app/components/app/annotation/empty-element.spec.tsx new file mode 100644 index 0000000000..56ebb96121 --- /dev/null +++ b/web/app/components/app/annotation/empty-element.spec.tsx @@ -0,0 +1,13 @@ +import React from 'react' +import { render, screen } from '@testing-library/react' +import EmptyElement from './empty-element' + +describe('EmptyElement', () => { + it('should render the empty state copy and supporting icon', () => { + const { container } = render() + + expect(screen.getByText('appAnnotation.noData.title')).toBeInTheDocument() + expect(screen.getByText('appAnnotation.noData.description')).toBeInTheDocument() + expect(container.querySelector('svg')).not.toBeNull() + }) +}) diff --git a/web/app/components/app/annotation/filter.spec.tsx b/web/app/components/app/annotation/filter.spec.tsx new file mode 100644 index 0000000000..47a758b17a --- /dev/null +++ b/web/app/components/app/annotation/filter.spec.tsx @@ -0,0 +1,71 @@ +import type { Mock } from 'vitest' +import React from 'react' +import { fireEvent, render, screen } from '@testing-library/react' +import Filter, { type QueryParam } from './filter' +import useSWR from 'swr' + +vi.mock('swr', () => ({ + __esModule: true, + default: vi.fn(), +})) + +vi.mock('@/service/log', () => ({ + fetchAnnotationsCount: vi.fn(), +})) + +const mockUseSWR = useSWR as unknown as Mock + +describe('Filter', () => { + const appId = 'app-1' + const childContent = 'child-content' + + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should render nothing until annotation count is fetched', () => { + mockUseSWR.mockReturnValue({ data: undefined }) + + const { container } = render( + +
{childContent}
+
, + ) + + expect(container.firstChild).toBeNull() + expect(mockUseSWR).toHaveBeenCalledWith( + { url: `/apps/${appId}/annotations/count` }, + expect.any(Function), + ) + }) + + it('should propagate keyword changes and clearing behavior', () => { + mockUseSWR.mockReturnValue({ data: { total: 20 } }) + const queryParams: QueryParam = { keyword: 'prefill' } + const setQueryParams = vi.fn() + + const { container } = render( + +
{childContent}
+
, + ) + + const input = screen.getByPlaceholderText('common.operation.search') as HTMLInputElement + fireEvent.change(input, { target: { value: 'updated' } }) + expect(setQueryParams).toHaveBeenCalledWith({ ...queryParams, keyword: 'updated' }) + + const clearButton = input.parentElement?.querySelector('div.cursor-pointer') as HTMLElement + fireEvent.click(clearButton) + expect(setQueryParams).toHaveBeenCalledWith({ ...queryParams, keyword: '' }) + + expect(container).toHaveTextContent(childContent) + }) +}) diff --git a/web/app/components/app/annotation/header-opts/index.spec.tsx b/web/app/components/app/annotation/header-opts/index.spec.tsx new file mode 100644 index 0000000000..84a1aa86d5 --- /dev/null +++ b/web/app/components/app/annotation/header-opts/index.spec.tsx @@ -0,0 +1,457 @@ +import * as React from 'react' +import { render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import type { ComponentProps } from 'react' +import HeaderOptions from './index' +import I18NContext from '@/context/i18n' +import { LanguagesSupported } from '@/i18n-config/language' +import type { AnnotationItemBasic } from '../type' +import { clearAllAnnotations, fetchExportAnnotationList } from '@/service/annotation' + +vi.mock('@headlessui/react', () => { + type PopoverContextValue = { open: boolean; setOpen: (open: boolean) => void } + type MenuContextValue = { open: boolean; setOpen: (open: boolean) => void } + const PopoverContext = React.createContext(null) + const MenuContext = React.createContext(null) + + const Popover = ({ children }: { children: React.ReactNode | ((props: { open: boolean }) => React.ReactNode) }) => { + const [open, setOpen] = React.useState(false) + const value = React.useMemo(() => ({ open, setOpen }), [open]) + return ( + + {typeof children === 'function' ? children({ open }) : children} + + ) + } + + const PopoverButton = React.forwardRef(({ onClick, children, ...props }: { onClick?: () => void; children?: React.ReactNode }, ref: React.Ref) => { + const context = React.useContext(PopoverContext) + const handleClick = () => { + context?.setOpen(!context.open) + onClick?.() + } + return ( + + ) + }) + + const PopoverPanel = React.forwardRef(({ children, ...props }: { children: React.ReactNode | ((props: { close: () => void }) => React.ReactNode) }, ref: React.Ref) => { + const context = React.useContext(PopoverContext) + if (!context?.open) return null + const content = typeof children === 'function' ? children({ close: () => context.setOpen(false) }) : children + return ( +
+ {content} +
+ ) + }) + + const Menu = ({ children }: { children: React.ReactNode }) => { + const [open, setOpen] = React.useState(false) + const value = React.useMemo(() => ({ open, setOpen }), [open]) + return ( + + {children} + + ) + } + + const MenuButton = ({ onClick, children, ...props }: { onClick?: () => void; children?: React.ReactNode }) => { + const context = React.useContext(MenuContext) + const handleClick = () => { + context?.setOpen(!context.open) + onClick?.() + } + return ( + + ) + } + + const MenuItems = ({ children, ...props }: { children: React.ReactNode }) => { + const context = React.useContext(MenuContext) + if (!context?.open) return null + return ( +
+ {children} +
+ ) + } + + return { + Dialog: ({ open, children, className }: { open?: boolean; children: React.ReactNode; className?: string }) => { + if (open === false) return null + return ( +
+ {children} +
+ ) + }, + DialogBackdrop: ({ children, className, onClick }: { children?: React.ReactNode; className?: string; onClick?: () => void }) => ( +
+ {children} +
+ ), + DialogPanel: ({ children, className, ...props }: { children: React.ReactNode; className?: string }) => ( +
+ {children} +
+ ), + DialogTitle: ({ children, className, ...props }: { children: React.ReactNode; className?: string }) => ( +
+ {children} +
+ ), + Popover, + PopoverButton, + PopoverPanel, + Menu, + MenuButton, + MenuItems, + Transition: ({ show = true, children }: { show?: boolean; children: React.ReactNode }) => (show ? <>{children} : null), + TransitionChild: ({ children }: { children: React.ReactNode }) => <>{children}, + } +}) + +let lastCSVDownloaderProps: Record | undefined +const mockCSVDownloader = vi.fn(({ children, ...props }) => { + lastCSVDownloaderProps = props + return ( +
+ {children} +
+ ) +}) + +vi.mock('react-papaparse', () => ({ + useCSVDownloader: () => ({ + CSVDownloader: (props: any) => mockCSVDownloader(props), + Type: { Link: 'link' }, + }), +})) + +vi.mock('@/service/annotation', () => ({ + fetchExportAnnotationList: vi.fn(), + clearAllAnnotations: vi.fn(), +})) + +vi.mock('@/context/provider-context', () => ({ + useProviderContext: () => ({ + plan: { + usage: { annotatedResponse: 0 }, + total: { annotatedResponse: 10 }, + }, + enableBilling: false, + }), +})) + +vi.mock('@/app/components/billing/annotation-full', () => ({ + __esModule: true, + default: () =>
, +})) + +type HeaderOptionsProps = ComponentProps + +const renderComponent = ( + props: Partial = {}, + locale: string = LanguagesSupported[0] as string, +) => { + const defaultProps: HeaderOptionsProps = { + appId: 'test-app-id', + onAdd: vi.fn(), + onAdded: vi.fn(), + controlUpdateList: 0, + ...props, + } + + return render( + + + , + ) +} + +const openOperationsPopover = async (user: ReturnType) => { + const trigger = document.querySelector('button.btn.btn-secondary') as HTMLButtonElement + expect(trigger).toBeTruthy() + await user.click(trigger) +} + +const expandExportMenu = async (user: ReturnType) => { + await openOperationsPopover(user) + const exportLabel = await screen.findByText('appAnnotation.table.header.bulkExport') + const exportButton = exportLabel.closest('button') as HTMLButtonElement + expect(exportButton).toBeTruthy() + await user.click(exportButton) +} + +const getExportButtons = async () => { + const csvLabel = await screen.findByText('CSV') + const jsonLabel = await screen.findByText('JSONL') + const csvButton = csvLabel.closest('button') as HTMLButtonElement + const jsonButton = jsonLabel.closest('button') as HTMLButtonElement + expect(csvButton).toBeTruthy() + expect(jsonButton).toBeTruthy() + return { + csvButton, + jsonButton, + } +} + +const clickOperationAction = async ( + user: ReturnType, + translationKey: string, +) => { + const label = await screen.findByText(translationKey) + const button = label.closest('button') as HTMLButtonElement + expect(button).toBeTruthy() + await user.click(button) +} + +const mockAnnotations: AnnotationItemBasic[] = [ + { + question: 'Question 1', + answer: 'Answer 1', + }, +] + +const mockedFetchAnnotations = vi.mocked(fetchExportAnnotationList) +const mockedClearAllAnnotations = vi.mocked(clearAllAnnotations) + +describe('HeaderOptions', () => { + beforeEach(() => { + vi.clearAllMocks() + vi.useRealTimers() + mockCSVDownloader.mockClear() + lastCSVDownloaderProps = undefined + mockedFetchAnnotations.mockResolvedValue({ data: [] }) + }) + + it('should fetch annotations on mount and render enabled export actions when data exist', async () => { + mockedFetchAnnotations.mockResolvedValue({ data: mockAnnotations }) + const user = userEvent.setup() + renderComponent() + + await waitFor(() => { + expect(mockedFetchAnnotations).toHaveBeenCalledWith('test-app-id') + }) + + await expandExportMenu(user) + + const { csvButton, jsonButton } = await getExportButtons() + + expect(csvButton).not.toBeDisabled() + expect(jsonButton).not.toBeDisabled() + + await waitFor(() => { + expect(lastCSVDownloaderProps).toMatchObject({ + bom: true, + filename: 'annotations-en-US', + type: 'link', + data: [ + ['Question', 'Answer'], + ['Question 1', 'Answer 1'], + ], + }) + }) + }) + + it('should disable export actions when there are no annotations', async () => { + const user = userEvent.setup() + renderComponent() + + await expandExportMenu(user) + + const { csvButton, jsonButton } = await getExportButtons() + + expect(csvButton).toBeDisabled() + expect(jsonButton).toBeDisabled() + + expect(lastCSVDownloaderProps).toMatchObject({ + data: [['Question', 'Answer']], + }) + }) + + it('should open the add annotation modal and forward the onAdd callback', async () => { + mockedFetchAnnotations.mockResolvedValue({ data: mockAnnotations }) + const user = userEvent.setup() + const onAdd = vi.fn().mockResolvedValue(undefined) + renderComponent({ onAdd }) + + await waitFor(() => expect(mockedFetchAnnotations).toHaveBeenCalled()) + + await user.click( + screen.getByRole('button', { name: 'appAnnotation.table.header.addAnnotation' }), + ) + + await screen.findByText('appAnnotation.addModal.title') + const questionInput = screen.getByPlaceholderText('appAnnotation.addModal.queryPlaceholder') + const answerInput = screen.getByPlaceholderText('appAnnotation.addModal.answerPlaceholder') + + await user.type(questionInput, 'Integration question') + await user.type(answerInput, 'Integration answer') + await user.click(screen.getByRole('button', { name: 'common.operation.add' })) + + await waitFor(() => { + expect(onAdd).toHaveBeenCalledWith({ + question: 'Integration question', + answer: 'Integration answer', + }) + }) + }) + + it('should allow bulk import through the batch modal', async () => { + const user = userEvent.setup() + const onAdded = vi.fn() + renderComponent({ onAdded }) + + await openOperationsPopover(user) + await clickOperationAction(user, 'appAnnotation.table.header.bulkImport') + + expect(await screen.findByText('appAnnotation.batchModal.title')).toBeInTheDocument() + await user.click( + screen.getByRole('button', { name: 'appAnnotation.batchModal.cancel' }), + ) + expect(onAdded).not.toHaveBeenCalled() + }) + + it('should trigger JSONL download with locale-specific filename', async () => { + mockedFetchAnnotations.mockResolvedValue({ data: mockAnnotations }) + const user = userEvent.setup() + const originalCreateElement = document.createElement.bind(document) + const anchor = originalCreateElement('a') as HTMLAnchorElement + const clickSpy = vi.spyOn(anchor, 'click').mockImplementation(vi.fn()) + const createElementSpy = vi.spyOn(document, 'createElement') + .mockImplementation((tagName: Parameters[0]) => { + if (tagName === 'a') + return anchor + return originalCreateElement(tagName) + }) + let capturedBlob: Blob | null = null + const objectURLSpy = vi.spyOn(URL, 'createObjectURL') + .mockImplementation((blob) => { + capturedBlob = blob as Blob + return 'blob://mock-url' + }) + const revokeSpy = vi.spyOn(URL, 'revokeObjectURL').mockImplementation(vi.fn()) + + renderComponent({}, LanguagesSupported[1] as string) + + await expandExportMenu(user) + + await waitFor(() => expect(mockCSVDownloader).toHaveBeenCalled()) + + const { jsonButton } = await getExportButtons() + await user.click(jsonButton) + + expect(createElementSpy).toHaveBeenCalled() + expect(anchor.download).toBe(`annotations-${LanguagesSupported[1]}.jsonl`) + expect(clickSpy).toHaveBeenCalled() + expect(revokeSpy).toHaveBeenCalledWith('blob://mock-url') + + // Verify the blob was created with correct content + expect(capturedBlob).toBeInstanceOf(Blob) + expect(capturedBlob!.type).toBe('application/jsonl') + + const blobContent = await new Promise((resolve) => { + const reader = new FileReader() + reader.onload = () => resolve(reader.result as string) + reader.readAsText(capturedBlob!) + }) + const lines = blobContent.trim().split('\n') + expect(lines).toHaveLength(1) + expect(JSON.parse(lines[0])).toEqual({ + messages: [ + { role: 'system', content: '' }, + { role: 'user', content: 'Question 1' }, + { role: 'assistant', content: 'Answer 1' }, + ], + }) + + clickSpy.mockRestore() + createElementSpy.mockRestore() + objectURLSpy.mockRestore() + revokeSpy.mockRestore() + }) + + it('should clear all annotations when confirmation succeeds', async () => { + mockedClearAllAnnotations.mockResolvedValue(undefined) + const user = userEvent.setup() + const onAdded = vi.fn() + renderComponent({ onAdded }) + + await openOperationsPopover(user) + await clickOperationAction(user, 'appAnnotation.table.header.clearAll') + + await screen.findByText('appAnnotation.table.header.clearAllConfirm') + const confirmButton = screen.getByRole('button', { name: 'common.operation.confirm' }) + await user.click(confirmButton) + + await waitFor(() => { + expect(mockedClearAllAnnotations).toHaveBeenCalledWith('test-app-id') + expect(onAdded).toHaveBeenCalled() + }) + }) + + it('should handle clear all failures gracefully', async () => { + const consoleSpy = vi.spyOn(console, 'error').mockImplementation(vi.fn()) + mockedClearAllAnnotations.mockRejectedValue(new Error('network')) + const user = userEvent.setup() + const onAdded = vi.fn() + renderComponent({ onAdded }) + + await openOperationsPopover(user) + await clickOperationAction(user, 'appAnnotation.table.header.clearAll') + await screen.findByText('appAnnotation.table.header.clearAllConfirm') + const confirmButton = screen.getByRole('button', { name: 'common.operation.confirm' }) + await user.click(confirmButton) + + await waitFor(() => { + expect(mockedClearAllAnnotations).toHaveBeenCalled() + expect(onAdded).not.toHaveBeenCalled() + expect(consoleSpy).toHaveBeenCalled() + }) + + consoleSpy.mockRestore() + }) + + it('should refetch annotations when controlUpdateList changes', async () => { + const view = renderComponent({ controlUpdateList: 0 }) + + await waitFor(() => expect(mockedFetchAnnotations).toHaveBeenCalledTimes(1)) + + view.rerender( + + + , + ) + + await waitFor(() => expect(mockedFetchAnnotations).toHaveBeenCalledTimes(2)) + }) +}) diff --git a/web/app/components/app/annotation/header-opts/index.tsx b/web/app/components/app/annotation/header-opts/index.tsx index 024f75867c..5f8ef658e7 100644 --- a/web/app/components/app/annotation/header-opts/index.tsx +++ b/web/app/components/app/annotation/header-opts/index.tsx @@ -17,7 +17,7 @@ import Button from '../../../base/button' import AddAnnotationModal from '../add-annotation-modal' import type { AnnotationItemBasic } from '../type' import BatchAddModal from '../batch-add-annotation-modal' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import CustomPopover from '@/app/components/base/popover' import { FileDownload02, FilePlus02 } from '@/app/components/base/icons/src/vender/line/files' import { ChevronRight } from '@/app/components/base/icons/src/vender/line/arrows' diff --git a/web/app/components/app/annotation/index.spec.tsx b/web/app/components/app/annotation/index.spec.tsx new file mode 100644 index 0000000000..43c718d235 --- /dev/null +++ b/web/app/components/app/annotation/index.spec.tsx @@ -0,0 +1,242 @@ +import type { Mock } from 'vitest' +import React from 'react' +import { act, fireEvent, render, screen, waitFor } from '@testing-library/react' +import Annotation from './index' +import type { AnnotationItem } from './type' +import { JobStatus } from './type' +import { type App, AppModeEnum } from '@/types/app' +import { + addAnnotation, + delAnnotation, + delAnnotations, + fetchAnnotationConfig, + fetchAnnotationList, + queryAnnotationJobStatus, +} from '@/service/annotation' +import { useProviderContext } from '@/context/provider-context' +import Toast from '@/app/components/base/toast' + +vi.mock('@/app/components/base/toast', () => ({ + __esModule: true, + default: { notify: vi.fn() }, +})) + +vi.mock('ahooks', () => ({ + useDebounce: (value: any) => value, +})) + +vi.mock('@/service/annotation', () => ({ + addAnnotation: vi.fn(), + delAnnotation: vi.fn(), + delAnnotations: vi.fn(), + fetchAnnotationConfig: vi.fn(), + editAnnotation: vi.fn(), + fetchAnnotationList: vi.fn(), + queryAnnotationJobStatus: vi.fn(), + updateAnnotationScore: vi.fn(), + updateAnnotationStatus: vi.fn(), +})) + +vi.mock('@/context/provider-context', () => ({ + useProviderContext: vi.fn(), +})) + +vi.mock('./filter', () => ({ + default: ({ children }: { children: React.ReactNode }) => ( +
{children}
+ ), +})) + +vi.mock('./empty-element', () => ({ + default: () =>
, +})) + +vi.mock('./header-opts', () => ({ + default: (props: any) => ( +
+ +
+ ), +})) + +let latestListProps: any + +vi.mock('./list', () => ({ + default: (props: any) => { + latestListProps = props + if (!props.list.length) + return
+ return ( +
+ + + +
+ ) + }, +})) + +vi.mock('./view-annotation-modal', () => ({ + default: (props: any) => { + if (!props.isShow) + return null + return ( +
+
{props.item.question}
+ + +
+ ) + }, +})) + +vi.mock('@/app/components/base/features/new-feature-panel/annotation-reply/config-param-modal', () => ({ default: (props: any) => props.isShow ?
: null })) +vi.mock('@/app/components/billing/annotation-full/modal', () => ({ default: (props: any) => props.show ?
: null })) + +const mockNotify = Toast.notify as Mock +const addAnnotationMock = addAnnotation as Mock +const delAnnotationMock = delAnnotation as Mock +const delAnnotationsMock = delAnnotations as Mock +const fetchAnnotationConfigMock = fetchAnnotationConfig as Mock +const fetchAnnotationListMock = fetchAnnotationList as Mock +const queryAnnotationJobStatusMock = queryAnnotationJobStatus as Mock +const useProviderContextMock = useProviderContext as Mock + +const appDetail = { + id: 'app-id', + mode: AppModeEnum.CHAT, +} as App + +const createAnnotation = (overrides: Partial = {}): AnnotationItem => ({ + id: overrides.id ?? 'annotation-1', + question: overrides.question ?? 'Question 1', + answer: overrides.answer ?? 'Answer 1', + created_at: overrides.created_at ?? 1700000000, + hit_count: overrides.hit_count ?? 0, +}) + +const renderComponent = () => render() + +describe('Annotation', () => { + beforeEach(() => { + vi.clearAllMocks() + latestListProps = undefined + fetchAnnotationConfigMock.mockResolvedValue({ + id: 'config-id', + enabled: false, + embedding_model: { + embedding_model_name: 'model', + embedding_provider_name: 'provider', + }, + score_threshold: 0.5, + }) + fetchAnnotationListMock.mockResolvedValue({ data: [], total: 0 }) + queryAnnotationJobStatusMock.mockResolvedValue({ job_status: JobStatus.completed }) + useProviderContextMock.mockReturnValue({ + plan: { + usage: { annotatedResponse: 0 }, + total: { annotatedResponse: 10 }, + }, + enableBilling: false, + }) + }) + + it('should render empty element when no annotations are returned', async () => { + renderComponent() + + expect(await screen.findByTestId('empty-element')).toBeInTheDocument() + expect(fetchAnnotationListMock).toHaveBeenCalledWith(appDetail.id, expect.objectContaining({ + page: 1, + keyword: '', + })) + }) + + it('should handle annotation creation and refresh list data', async () => { + const annotation = createAnnotation() + fetchAnnotationListMock.mockResolvedValue({ data: [annotation], total: 1 }) + addAnnotationMock.mockResolvedValue(undefined) + + renderComponent() + + await screen.findByTestId('list') + fireEvent.click(screen.getByTestId('trigger-add')) + + await waitFor(() => { + expect(addAnnotationMock).toHaveBeenCalledWith(appDetail.id, { question: 'new question', answer: 'new answer' }) + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + message: 'common.api.actionSuccess', + type: 'success', + })) + }) + expect(fetchAnnotationListMock).toHaveBeenCalledTimes(2) + }) + + it('should support viewing items and running batch deletion success flow', async () => { + const annotation = createAnnotation() + fetchAnnotationListMock.mockResolvedValue({ data: [annotation], total: 1 }) + delAnnotationsMock.mockResolvedValue(undefined) + delAnnotationMock.mockResolvedValue(undefined) + + renderComponent() + await screen.findByTestId('list') + + await act(async () => { + latestListProps.onSelectedIdsChange([annotation.id]) + }) + await waitFor(() => { + expect(latestListProps.selectedIds).toEqual([annotation.id]) + }) + + await act(async () => { + await latestListProps.onBatchDelete() + }) + await waitFor(() => { + expect(delAnnotationsMock).toHaveBeenCalledWith(appDetail.id, [annotation.id]) + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'success', + })) + expect(latestListProps.selectedIds).toEqual([]) + }) + + fireEvent.click(screen.getByTestId('list-view')) + expect(screen.getByTestId('view-modal')).toBeInTheDocument() + + await act(async () => { + fireEvent.click(screen.getByTestId('view-modal-remove')) + }) + await waitFor(() => { + expect(delAnnotationMock).toHaveBeenCalledWith(appDetail.id, annotation.id) + }) + }) + + it('should show an error notification when batch deletion fails', async () => { + const annotation = createAnnotation() + fetchAnnotationListMock.mockResolvedValue({ data: [annotation], total: 1 }) + const error = new Error('failed') + delAnnotationsMock.mockRejectedValue(error) + + renderComponent() + await screen.findByTestId('list') + + await act(async () => { + latestListProps.onSelectedIdsChange([annotation.id]) + }) + await waitFor(() => { + expect(latestListProps.selectedIds).toEqual([annotation.id]) + }) + + await act(async () => { + await latestListProps.onBatchDelete() + }) + + await waitFor(() => { + expect(mockNotify).toHaveBeenCalledWith({ + type: 'error', + message: error.message, + }) + expect(latestListProps.selectedIds).toEqual([annotation.id]) + }) + }) +}) diff --git a/web/app/components/app/annotation/index.tsx b/web/app/components/app/annotation/index.tsx index 32d0c799fc..2d639c91e4 100644 --- a/web/app/components/app/annotation/index.tsx +++ b/web/app/components/app/annotation/index.tsx @@ -25,7 +25,7 @@ import { sleep } from '@/utils' import { useProviderContext } from '@/context/provider-context' import AnnotationFullModal from '@/app/components/billing/annotation-full/modal' import { type App, AppModeEnum } from '@/types/app' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { delAnnotations } from '@/service/annotation' type Props = { diff --git a/web/app/components/app/annotation/list.spec.tsx b/web/app/components/app/annotation/list.spec.tsx new file mode 100644 index 0000000000..8f8eb97d67 --- /dev/null +++ b/web/app/components/app/annotation/list.spec.tsx @@ -0,0 +1,116 @@ +import React from 'react' +import { fireEvent, render, screen, within } from '@testing-library/react' +import List from './list' +import type { AnnotationItem } from './type' + +const mockFormatTime = vi.fn(() => 'formatted-time') + +vi.mock('@/hooks/use-timestamp', () => ({ + __esModule: true, + default: () => ({ + formatTime: mockFormatTime, + }), +})) + +const createAnnotation = (overrides: Partial = {}): AnnotationItem => ({ + id: overrides.id ?? 'annotation-id', + question: overrides.question ?? 'question 1', + answer: overrides.answer ?? 'answer 1', + created_at: overrides.created_at ?? 1700000000, + hit_count: overrides.hit_count ?? 2, +}) + +const getCheckboxes = (container: HTMLElement) => container.querySelectorAll('[data-testid^="checkbox"]') + +describe('List', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should render annotation rows and call onView when clicking a row', () => { + const item = createAnnotation() + const onView = vi.fn() + + render( + , + ) + + fireEvent.click(screen.getByText(item.question)) + + expect(onView).toHaveBeenCalledWith(item) + expect(mockFormatTime).toHaveBeenCalledWith(item.created_at, 'appLog.dateTimeFormat') + }) + + it('should toggle single and bulk selection states', () => { + const list = [createAnnotation({ id: 'a', question: 'A' }), createAnnotation({ id: 'b', question: 'B' })] + const onSelectedIdsChange = vi.fn() + const { container, rerender } = render( + , + ) + + const checkboxes = getCheckboxes(container) + fireEvent.click(checkboxes[1]) + expect(onSelectedIdsChange).toHaveBeenCalledWith(['a']) + + rerender( + , + ) + const updatedCheckboxes = getCheckboxes(container) + fireEvent.click(updatedCheckboxes[1]) + expect(onSelectedIdsChange).toHaveBeenCalledWith([]) + + fireEvent.click(updatedCheckboxes[0]) + expect(onSelectedIdsChange).toHaveBeenCalledWith(['a', 'b']) + }) + + it('should confirm before removing an annotation and expose batch actions', async () => { + const item = createAnnotation({ id: 'to-delete', question: 'Delete me' }) + const onRemove = vi.fn() + render( + , + ) + + const row = screen.getByText(item.question).closest('tr') as HTMLTableRowElement + const actionButtons = within(row).getAllByRole('button') + fireEvent.click(actionButtons[1]) + + expect(await screen.findByText('appDebug.feature.annotation.removeConfirm')).toBeInTheDocument() + const confirmButton = await screen.findByRole('button', { name: 'common.operation.confirm' }) + fireEvent.click(confirmButton) + expect(onRemove).toHaveBeenCalledWith(item.id) + + expect(screen.getByText('appAnnotation.batchAction.selected')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/app/annotation/list.tsx b/web/app/components/app/annotation/list.tsx index 4135b4362e..62a0c50e60 100644 --- a/web/app/components/app/annotation/list.tsx +++ b/web/app/components/app/annotation/list.tsx @@ -7,7 +7,7 @@ import type { AnnotationItem } from './type' import RemoveAnnotationConfirmModal from './remove-annotation-confirm-modal' import ActionButton from '@/app/components/base/action-button' import useTimestamp from '@/hooks/use-timestamp' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import Checkbox from '@/app/components/base/checkbox' import BatchAction from './batch-action' diff --git a/web/app/components/app/annotation/remove-annotation-confirm-modal/index.spec.tsx b/web/app/components/app/annotation/remove-annotation-confirm-modal/index.spec.tsx index 347ba7880b..77648ace02 100644 --- a/web/app/components/app/annotation/remove-annotation-confirm-modal/index.spec.tsx +++ b/web/app/components/app/annotation/remove-annotation-confirm-modal/index.spec.tsx @@ -2,7 +2,7 @@ import React from 'react' import { fireEvent, render, screen } from '@testing-library/react' import RemoveAnnotationConfirmModal from './index' -jest.mock('react-i18next', () => ({ +vi.mock('react-i18next', () => ({ useTranslation: () => ({ t: (key: string) => { const translations: Record = { @@ -16,7 +16,7 @@ jest.mock('react-i18next', () => ({ })) beforeEach(() => { - jest.clearAllMocks() + vi.clearAllMocks() }) describe('RemoveAnnotationConfirmModal', () => { @@ -27,8 +27,8 @@ describe('RemoveAnnotationConfirmModal', () => { render( , ) @@ -43,8 +43,8 @@ describe('RemoveAnnotationConfirmModal', () => { render( , ) @@ -56,8 +56,8 @@ describe('RemoveAnnotationConfirmModal', () => { // User interactions with confirm and cancel buttons describe('Interactions', () => { test('should call onHide when cancel button is clicked', () => { - const onHide = jest.fn() - const onRemove = jest.fn() + const onHide = vi.fn() + const onRemove = vi.fn() // Arrange render( { }) test('should call onRemove when confirm button is clicked', () => { - const onHide = jest.fn() - const onRemove = jest.fn() + const onHide = vi.fn() + const onRemove = vi.fn() // Arrange render( 'formatted-time') + +vi.mock('@/hooks/use-timestamp', () => ({ + __esModule: true, + default: () => ({ + formatTime: mockFormatTime, + }), +})) + +vi.mock('@/service/annotation', () => ({ + fetchHitHistoryList: vi.fn(), +})) + +vi.mock('../edit-annotation-modal/edit-item', () => { + const EditItemType = { + Query: 'query', + Answer: 'answer', + } + return { + __esModule: true, + default: ({ type, content, onSave }: { type: string; content: string; onSave: (value: string) => void }) => ( +
+
{content}
+ +
+ ), + EditItemType, + } +}) + +const fetchHitHistoryListMock = fetchHitHistoryList as Mock + +const createAnnotationItem = (overrides: Partial = {}): AnnotationItem => ({ + id: overrides.id ?? 'annotation-id', + question: overrides.question ?? 'question', + answer: overrides.answer ?? 'answer', + created_at: overrides.created_at ?? 1700000000, + hit_count: overrides.hit_count ?? 0, +}) + +const createHitHistoryItem = (overrides: Partial = {}): HitHistoryItem => ({ + id: overrides.id ?? 'hit-id', + question: overrides.question ?? 'query', + match: overrides.match ?? 'match', + response: overrides.response ?? 'response', + source: overrides.source ?? 'source', + score: overrides.score ?? 0.42, + created_at: overrides.created_at ?? 1700000000, +}) + +const renderComponent = (props?: Partial>) => { + const item = createAnnotationItem() + const mergedProps: React.ComponentProps = { + appId: 'app-id', + isShow: true, + onHide: vi.fn(), + item, + onSave: vi.fn().mockResolvedValue(undefined), + onRemove: vi.fn().mockResolvedValue(undefined), + ...props, + } + return { + ...render(), + props: mergedProps, + } +} + +describe('ViewAnnotationModal', () => { + beforeEach(() => { + vi.clearAllMocks() + fetchHitHistoryListMock.mockResolvedValue({ data: [], total: 0 }) + }) + + it('should render annotation tab and allow saving updated query', async () => { + // Arrange + const { props } = renderComponent() + + await waitFor(() => { + expect(fetchHitHistoryListMock).toHaveBeenCalled() + }) + + // Act + fireEvent.click(screen.getByTestId('edit-query')) + + // Assert + await waitFor(() => { + expect(props.onSave).toHaveBeenCalledWith('query-updated', props.item.answer) + }) + }) + + it('should render annotation tab and allow saving updated answer', async () => { + // Arrange + const { props } = renderComponent() + + await waitFor(() => { + expect(fetchHitHistoryListMock).toHaveBeenCalled() + }) + + // Act + fireEvent.click(screen.getByTestId('edit-answer')) + + // Assert + await waitFor(() => { + expect(props.onSave).toHaveBeenCalledWith(props.item.question, 'answer-updated') + }, + ) + }) + + it('should switch to hit history tab and show no data message', async () => { + // Arrange + const { props } = renderComponent() + + await waitFor(() => { + expect(fetchHitHistoryListMock).toHaveBeenCalled() + }) + + // Act + fireEvent.click(screen.getByText('appAnnotation.viewModal.hitHistory')) + + // Assert + expect(await screen.findByText('appAnnotation.viewModal.noHitHistory')).toBeInTheDocument() + expect(mockFormatTime).toHaveBeenCalledWith(props.item.created_at, 'appLog.dateTimeFormat') + }) + + it('should render hit history entries with pagination badge when data exists', async () => { + const hits = [createHitHistoryItem({ question: 'user input' }), createHitHistoryItem({ id: 'hit-2', question: 'second' })] + fetchHitHistoryListMock.mockResolvedValue({ data: hits, total: 15 }) + + renderComponent() + + fireEvent.click(await screen.findByText('appAnnotation.viewModal.hitHistory')) + + expect(await screen.findByText('user input')).toBeInTheDocument() + expect(screen.getByText('15 appAnnotation.viewModal.hits')).toBeInTheDocument() + expect(mockFormatTime).toHaveBeenCalledWith(hits[0].created_at, 'appLog.dateTimeFormat') + }) + + it('should confirm before removing the annotation and hide on success', async () => { + const { props } = renderComponent() + + fireEvent.click(screen.getByText('appAnnotation.editModal.removeThisCache')) + expect(await screen.findByText('appDebug.feature.annotation.removeConfirm')).toBeInTheDocument() + + const confirmButton = await screen.findByRole('button', { name: 'common.operation.confirm' }) + fireEvent.click(confirmButton) + + await waitFor(() => { + expect(props.onRemove).toHaveBeenCalledTimes(1) + expect(props.onHide).toHaveBeenCalledTimes(1) + }) + }) +}) diff --git a/web/app/components/app/annotation/view-annotation-modal/index.tsx b/web/app/components/app/annotation/view-annotation-modal/index.tsx index 8426ab0005..d21b177098 100644 --- a/web/app/components/app/annotation/view-annotation-modal/index.tsx +++ b/web/app/components/app/annotation/view-annotation-modal/index.tsx @@ -14,7 +14,7 @@ import TabSlider from '@/app/components/base/tab-slider-plain' import { fetchHitHistoryList } from '@/service/annotation' import { APP_PAGE_LIMIT } from '@/config' import useTimestamp from '@/hooks/use-timestamp' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' type Props = { appId: string diff --git a/web/app/components/app/app-access-control/access-control-dialog.tsx b/web/app/components/app/app-access-control/access-control-dialog.tsx index ee3fa9650b..99cf6d7074 100644 --- a/web/app/components/app/app-access-control/access-control-dialog.tsx +++ b/web/app/components/app/app-access-control/access-control-dialog.tsx @@ -2,7 +2,7 @@ import { Fragment, useCallback } from 'react' import type { ReactNode } from 'react' import { Dialog, Transition } from '@headlessui/react' import { RiCloseLine } from '@remixicon/react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' type DialogProps = { className?: string diff --git a/web/app/components/app/app-access-control/access-control.spec.tsx b/web/app/components/app/app-access-control/access-control.spec.tsx new file mode 100644 index 0000000000..0948361413 --- /dev/null +++ b/web/app/components/app/app-access-control/access-control.spec.tsx @@ -0,0 +1,389 @@ +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import AccessControl from './index' +import AccessControlDialog from './access-control-dialog' +import AccessControlItem from './access-control-item' +import AddMemberOrGroupDialog from './add-member-or-group-pop' +import SpecificGroupsOrMembers from './specific-groups-or-members' +import useAccessControlStore from '@/context/access-control-store' +import { useGlobalPublicStore } from '@/context/global-public-context' +import type { AccessControlAccount, AccessControlGroup, Subject } from '@/models/access-control' +import { AccessMode, SubjectType } from '@/models/access-control' +import Toast from '../../base/toast' +import { defaultSystemFeatures } from '@/types/feature' +import type { App } from '@/types/app' + +const mockUseAppWhiteListSubjects = vi.fn() +const mockUseSearchForWhiteListCandidates = vi.fn() +const mockMutateAsync = vi.fn() +const mockUseUpdateAccessMode = vi.fn(() => ({ + isPending: false, + mutateAsync: mockMutateAsync, +})) + +vi.mock('@/context/app-context', () => ({ + useSelector: (selector: (value: { userProfile: { email: string; id?: string; name?: string; avatar?: string; avatar_url?: string; is_password_set?: boolean } }) => T) => selector({ + userProfile: { + id: 'current-user', + name: 'Current User', + email: 'member@example.com', + avatar: '', + avatar_url: '', + is_password_set: true, + }, + }), +})) + +vi.mock('@/service/common', () => ({ + fetchCurrentWorkspace: vi.fn(), + fetchLangGeniusVersion: vi.fn(), + fetchUserProfile: vi.fn(), + getSystemFeatures: vi.fn(), +})) + +vi.mock('@/service/access-control', () => ({ + useAppWhiteListSubjects: (...args: unknown[]) => mockUseAppWhiteListSubjects(...args), + useSearchForWhiteListCandidates: (...args: unknown[]) => mockUseSearchForWhiteListCandidates(...args), + useUpdateAccessMode: () => mockUseUpdateAccessMode(), +})) + +vi.mock('@headlessui/react', () => { + const DialogComponent: any = ({ children, className, ...rest }: any) => ( +
{children}
+ ) + DialogComponent.Panel = ({ children, className, ...rest }: any) => ( +
{children}
+ ) + const DialogTitle = ({ children, className, ...rest }: any) => ( +
{children}
+ ) + const DialogDescription = ({ children, className, ...rest }: any) => ( +
{children}
+ ) + const TransitionChild = ({ children }: any) => ( + <>{typeof children === 'function' ? children({}) : children} + ) + const Transition = ({ show = true, children }: any) => ( + show ? <>{typeof children === 'function' ? children({}) : children} : null + ) + Transition.Child = TransitionChild + return { + Dialog: DialogComponent, + Transition, + DialogTitle, + Description: DialogDescription, + } +}) + +vi.mock('ahooks', async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + useDebounce: (value: unknown) => value, + } +}) + +const createGroup = (overrides: Partial = {}): AccessControlGroup => ({ + id: 'group-1', + name: 'Group One', + groupSize: 5, + ...overrides, +} as AccessControlGroup) + +const createMember = (overrides: Partial = {}): AccessControlAccount => ({ + id: 'member-1', + name: 'Member One', + email: 'member@example.com', + avatar: '', + avatarUrl: '', + ...overrides, +} as AccessControlAccount) + +const baseGroup = createGroup() +const baseMember = createMember() +const groupSubject: Subject = { + subjectId: baseGroup.id, + subjectType: SubjectType.GROUP, + groupData: baseGroup, +} as Subject +const memberSubject: Subject = { + subjectId: baseMember.id, + subjectType: SubjectType.ACCOUNT, + accountData: baseMember, +} as Subject + +const resetAccessControlStore = () => { + useAccessControlStore.setState({ + appId: '', + specificGroups: [], + specificMembers: [], + currentMenu: AccessMode.SPECIFIC_GROUPS_MEMBERS, + selectedGroupsForBreadcrumb: [], + }) +} + +const resetGlobalStore = () => { + useGlobalPublicStore.setState({ + systemFeatures: defaultSystemFeatures, + isGlobalPending: false, + }) +} + +beforeAll(() => { + class MockIntersectionObserver { + observe = vi.fn(() => undefined) + disconnect = vi.fn(() => undefined) + unobserve = vi.fn(() => undefined) + } + // @ts-expect-error jsdom does not implement IntersectionObserver + globalThis.IntersectionObserver = MockIntersectionObserver +}) + +beforeEach(() => { + vi.clearAllMocks() + resetAccessControlStore() + resetGlobalStore() + mockMutateAsync.mockResolvedValue(undefined) + mockUseUpdateAccessMode.mockReturnValue({ + isPending: false, + mutateAsync: mockMutateAsync, + }) + mockUseAppWhiteListSubjects.mockReturnValue({ + isPending: false, + data: { + groups: [baseGroup], + members: [baseMember], + }, + }) + mockUseSearchForWhiteListCandidates.mockReturnValue({ + isLoading: false, + isFetchingNextPage: false, + fetchNextPage: vi.fn(), + data: { pages: [{ currPage: 1, subjects: [groupSubject, memberSubject], hasMore: false }] }, + }) +}) + +// AccessControlItem handles selected vs. unselected styling and click state updates +describe('AccessControlItem', () => { + it('should update current menu when selecting a different access type', () => { + useAccessControlStore.setState({ currentMenu: AccessMode.PUBLIC }) + render( + + Organization Only + , + ) + + const option = screen.getByText('Organization Only').parentElement as HTMLElement + expect(option).toHaveClass('cursor-pointer') + + fireEvent.click(option) + + expect(useAccessControlStore.getState().currentMenu).toBe(AccessMode.ORGANIZATION) + }) + + it('should keep current menu when clicking the selected access type', () => { + useAccessControlStore.setState({ currentMenu: AccessMode.ORGANIZATION }) + render( + + Organization Only + , + ) + + const option = screen.getByText('Organization Only').parentElement as HTMLElement + fireEvent.click(option) + + expect(useAccessControlStore.getState().currentMenu).toBe(AccessMode.ORGANIZATION) + }) +}) + +// AccessControlDialog renders a headless UI dialog with a manual close control +describe('AccessControlDialog', () => { + it('should render dialog content when visible', () => { + render( + +
Dialog Content
+
, + ) + + expect(screen.getByRole('dialog')).toBeInTheDocument() + expect(screen.getByText('Dialog Content')).toBeInTheDocument() + }) + + it('should trigger onClose when clicking the close control', async () => { + const handleClose = vi.fn() + const { container } = render( + +
Dialog Content
+
, + ) + + const closeButton = container.querySelector('.absolute.right-5.top-5') as HTMLElement + fireEvent.click(closeButton) + + await waitFor(() => { + expect(handleClose).toHaveBeenCalledTimes(1) + }) + }) +}) + +// SpecificGroupsOrMembers syncs store state with fetched data and supports removals +describe('SpecificGroupsOrMembers', () => { + it('should render collapsed view when not in specific selection mode', () => { + useAccessControlStore.setState({ currentMenu: AccessMode.ORGANIZATION }) + + render() + + expect(screen.getByText('app.accessControlDialog.accessItems.specific')).toBeInTheDocument() + expect(screen.queryByText(baseGroup.name)).not.toBeInTheDocument() + }) + + it('should show loading state while pending', async () => { + useAccessControlStore.setState({ appId: 'app-1', currentMenu: AccessMode.SPECIFIC_GROUPS_MEMBERS }) + mockUseAppWhiteListSubjects.mockReturnValue({ + isPending: true, + data: undefined, + }) + + const { container } = render() + + await waitFor(() => { + expect(container.querySelector('.spin-animation')).toBeInTheDocument() + }) + }) + + it('should render fetched groups and members and support removal', async () => { + useAccessControlStore.setState({ appId: 'app-1', currentMenu: AccessMode.SPECIFIC_GROUPS_MEMBERS }) + + render() + + await waitFor(() => { + expect(screen.getByText(baseGroup.name)).toBeInTheDocument() + expect(screen.getByText(baseMember.name)).toBeInTheDocument() + }) + + const groupItem = screen.getByText(baseGroup.name).closest('div') + const groupRemove = groupItem?.querySelector('.h-4.w-4.cursor-pointer') as HTMLElement + fireEvent.click(groupRemove) + + await waitFor(() => { + expect(screen.queryByText(baseGroup.name)).not.toBeInTheDocument() + }) + + const memberItem = screen.getByText(baseMember.name).closest('div') + const memberRemove = memberItem?.querySelector('.h-4.w-4.cursor-pointer') as HTMLElement + fireEvent.click(memberRemove) + + await waitFor(() => { + expect(screen.queryByText(baseMember.name)).not.toBeInTheDocument() + }) + }) +}) + +// AddMemberOrGroupDialog renders search results and updates store selections +describe('AddMemberOrGroupDialog', () => { + it('should open search popover and display candidates', async () => { + const user = userEvent.setup() + + render() + + await user.click(screen.getByText('common.operation.add')) + + expect(screen.getByPlaceholderText('app.accessControlDialog.operateGroupAndMember.searchPlaceholder')).toBeInTheDocument() + expect(screen.getByText(baseGroup.name)).toBeInTheDocument() + expect(screen.getByText(baseMember.name)).toBeInTheDocument() + }) + + it('should allow selecting members and expanding groups', async () => { + const user = userEvent.setup() + render() + + await user.click(screen.getByText('common.operation.add')) + + const expandButton = screen.getByText('app.accessControlDialog.operateGroupAndMember.expand') + await user.click(expandButton) + expect(useAccessControlStore.getState().selectedGroupsForBreadcrumb).toEqual([baseGroup]) + + const memberLabel = screen.getByText(baseMember.name) + const memberCheckbox = memberLabel.parentElement?.previousElementSibling as HTMLElement + fireEvent.click(memberCheckbox) + + expect(useAccessControlStore.getState().specificMembers).toEqual([baseMember]) + }) + + it('should show empty state when no candidates are returned', async () => { + mockUseSearchForWhiteListCandidates.mockReturnValue({ + isLoading: false, + isFetchingNextPage: false, + fetchNextPage: vi.fn(), + data: { pages: [] }, + }) + + const user = userEvent.setup() + render() + + await user.click(screen.getByText('common.operation.add')) + + expect(screen.getByText('app.accessControlDialog.operateGroupAndMember.noResult')).toBeInTheDocument() + }) +}) + +// AccessControl integrates dialog, selection items, and confirm flow +describe('AccessControl', () => { + it('should initialize menu from app and call update on confirm', async () => { + const onClose = vi.fn() + const onConfirm = vi.fn() + const toastSpy = vi.spyOn(Toast, 'notify').mockReturnValue({}) + useAccessControlStore.setState({ + specificGroups: [baseGroup], + specificMembers: [baseMember], + }) + const app = { + id: 'app-id-1', + access_mode: AccessMode.SPECIFIC_GROUPS_MEMBERS, + } as App + + render( + , + ) + + await waitFor(() => { + expect(useAccessControlStore.getState().currentMenu).toBe(AccessMode.SPECIFIC_GROUPS_MEMBERS) + }) + + fireEvent.click(screen.getByText('common.operation.confirm')) + + await waitFor(() => { + expect(mockMutateAsync).toHaveBeenCalledWith({ + appId: app.id, + accessMode: AccessMode.SPECIFIC_GROUPS_MEMBERS, + subjects: [ + { subjectId: baseGroup.id, subjectType: SubjectType.GROUP }, + { subjectId: baseMember.id, subjectType: SubjectType.ACCOUNT }, + ], + }) + expect(toastSpy).toHaveBeenCalled() + expect(onConfirm).toHaveBeenCalled() + }) + }) + + it('should expose the external members tip when SSO is disabled', () => { + const app = { + id: 'app-id-2', + access_mode: AccessMode.PUBLIC, + } as App + + render( + , + ) + + expect(screen.getByText('app.accessControlDialog.accessItems.external')).toBeInTheDocument() + expect(screen.getByText('app.accessControlDialog.accessItems.anyone')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/app/app-access-control/add-member-or-group-pop.tsx b/web/app/components/app/app-access-control/add-member-or-group-pop.tsx index e9519aeedf..17263fdd46 100644 --- a/web/app/components/app/app-access-control/add-member-or-group-pop.tsx +++ b/web/app/components/app/app-access-control/add-member-or-group-pop.tsx @@ -11,7 +11,7 @@ import Input from '../../base/input' import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '../../base/portal-to-follow-elem' import Loading from '../../base/loading' import useAccessControlStore from '../../../../context/access-control-store' -import classNames from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { useSearchForWhiteListCandidates } from '@/service/access-control' import type { AccessControlAccount, AccessControlGroup, Subject, SubjectAccount, SubjectGroup } from '@/models/access-control' import { SubjectType } from '@/models/access-control' @@ -32,7 +32,7 @@ export default function AddMemberOrGroupDialog() { const anchorRef = useRef(null) useEffect(() => { - const hasMore = data?.pages?.[0].hasMore ?? false + const hasMore = data?.pages?.[0]?.hasMore ?? false let observer: IntersectionObserver | undefined if (anchorRef.current) { observer = new IntersectionObserver((entries) => { @@ -106,7 +106,7 @@ function SelectedGroupsBreadCrumb() { setSelectedGroupsForBreadcrumb([]) }, [setSelectedGroupsForBreadcrumb]) return
- 0 && 'cursor-pointer text-text-accent')} onClick={handleReset}>{t('app.accessControlDialog.operateGroupAndMember.allMembers')} + 0 && 'cursor-pointer text-text-accent')} onClick={handleReset}>{t('app.accessControlDialog.operateGroupAndMember.allMembers')} {selectedGroupsForBreadcrumb.map((group, index) => { return
/ @@ -198,7 +198,7 @@ type BaseItemProps = { children: React.ReactNode } function BaseItem({ children, className }: BaseItemProps) { - return
+ return
{children}
} diff --git a/web/app/components/app/app-publisher/suggested-action.tsx b/web/app/components/app/app-publisher/suggested-action.tsx index 2535de6654..154bacc361 100644 --- a/web/app/components/app/app-publisher/suggested-action.tsx +++ b/web/app/components/app/app-publisher/suggested-action.tsx @@ -1,6 +1,6 @@ import type { HTMLProps, PropsWithChildren } from 'react' import { RiArrowRightUpLine } from '@remixicon/react' -import classNames from '@/utils/classnames' +import { cn } from '@/utils/classnames' export type SuggestedActionProps = PropsWithChildren & { icon?: React.ReactNode @@ -19,11 +19,9 @@ const SuggestedAction = ({ icon, link, disabled, children, className, onClick, . href={disabled ? undefined : link} target='_blank' rel='noreferrer' - className={classNames( - 'flex items-center justify-start gap-2 rounded-lg bg-background-section-burn px-2.5 py-2 text-text-secondary transition-colors [&:not(:first-child)]:mt-1', + className={cn('flex items-center justify-start gap-2 rounded-lg bg-background-section-burn px-2.5 py-2 text-text-secondary transition-colors [&:not(:first-child)]:mt-1', disabled ? 'cursor-not-allowed opacity-30 shadow-xs' : 'cursor-pointer text-text-secondary hover:bg-state-accent-hover hover:text-text-accent', - className, - )} + className)} onClick={handleClick} {...props} > diff --git a/web/app/components/app/configuration/base/feature-panel/index.tsx b/web/app/components/app/configuration/base/feature-panel/index.tsx index ec5ab96d76..c9ebfefbe5 100644 --- a/web/app/components/app/configuration/base/feature-panel/index.tsx +++ b/web/app/components/app/configuration/base/feature-panel/index.tsx @@ -1,7 +1,7 @@ 'use client' import type { FC, ReactNode } from 'react' import React from 'react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' export type IFeaturePanelProps = { className?: string diff --git a/web/app/components/app/configuration/base/group-name/index.spec.tsx b/web/app/components/app/configuration/base/group-name/index.spec.tsx index ac504247f2..be698c3233 100644 --- a/web/app/components/app/configuration/base/group-name/index.spec.tsx +++ b/web/app/components/app/configuration/base/group-name/index.spec.tsx @@ -3,7 +3,7 @@ import GroupName from './index' describe('GroupName', () => { beforeEach(() => { - jest.clearAllMocks() + vi.clearAllMocks() }) describe('Rendering', () => { diff --git a/web/app/components/app/configuration/base/operation-btn/index.spec.tsx b/web/app/components/app/configuration/base/operation-btn/index.spec.tsx index 615a1769e8..5a16135c55 100644 --- a/web/app/components/app/configuration/base/operation-btn/index.spec.tsx +++ b/web/app/components/app/configuration/base/operation-btn/index.spec.tsx @@ -1,7 +1,7 @@ import { fireEvent, render, screen } from '@testing-library/react' import OperationBtn from './index' -jest.mock('@remixicon/react', () => ({ +vi.mock('@remixicon/react', () => ({ RiAddLine: (props: { className?: string }) => ( ), @@ -12,7 +12,7 @@ jest.mock('@remixicon/react', () => ({ describe('OperationBtn', () => { beforeEach(() => { - jest.clearAllMocks() + vi.clearAllMocks() }) // Rendering icons and translation labels @@ -29,7 +29,7 @@ describe('OperationBtn', () => { }) it('should render add icon when type is add', () => { // Arrange - const onClick = jest.fn() + const onClick = vi.fn() // Act render() @@ -57,7 +57,7 @@ describe('OperationBtn', () => { describe('Interactions', () => { it('should execute click handler when button is clicked', () => { // Arrange - const onClick = jest.fn() + const onClick = vi.fn() render() // Act diff --git a/web/app/components/app/configuration/base/operation-btn/index.tsx b/web/app/components/app/configuration/base/operation-btn/index.tsx index aba35cded2..db19d2976e 100644 --- a/web/app/components/app/configuration/base/operation-btn/index.tsx +++ b/web/app/components/app/configuration/base/operation-btn/index.tsx @@ -6,7 +6,7 @@ import { RiAddLine, RiEditLine, } from '@remixicon/react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { noop } from 'lodash-es' export type IOperationBtnProps = { diff --git a/web/app/components/app/configuration/base/var-highlight/index.spec.tsx b/web/app/components/app/configuration/base/var-highlight/index.spec.tsx index 9e84aa09ac..77fe1f2b28 100644 --- a/web/app/components/app/configuration/base/var-highlight/index.spec.tsx +++ b/web/app/components/app/configuration/base/var-highlight/index.spec.tsx @@ -3,7 +3,7 @@ import VarHighlight, { varHighlightHTML } from './index' describe('VarHighlight', () => { beforeEach(() => { - jest.clearAllMocks() + vi.clearAllMocks() }) // Rendering highlighted variable tags @@ -19,7 +19,9 @@ describe('VarHighlight', () => { expect(screen.getByText('userInput')).toBeInTheDocument() expect(screen.getAllByText('{{')[0]).toBeInTheDocument() expect(screen.getAllByText('}}')[0]).toBeInTheDocument() - expect(container.firstChild).toHaveClass('item') + // CSS modules add a hash to class names, so we check that the class attribute contains 'item' + const firstChild = container.firstChild as HTMLElement + expect(firstChild.className).toContain('item') }) it('should apply custom class names when provided', () => { @@ -56,7 +58,9 @@ describe('VarHighlight', () => { const html = varHighlightHTML(props) // Assert - expect(html).toContain('class="item text-primary') + // CSS modules add a hash to class names, so the class attribute may contain _item_xxx + expect(html).toContain('text-primary') + expect(html).toContain('item') }) }) }) diff --git a/web/app/components/app/configuration/base/warning-mask/cannot-query-dataset.spec.tsx b/web/app/components/app/configuration/base/warning-mask/cannot-query-dataset.spec.tsx index d625e9fb72..accbcf9f5d 100644 --- a/web/app/components/app/configuration/base/warning-mask/cannot-query-dataset.spec.tsx +++ b/web/app/components/app/configuration/base/warning-mask/cannot-query-dataset.spec.tsx @@ -4,7 +4,7 @@ import CannotQueryDataset from './cannot-query-dataset' describe('CannotQueryDataset WarningMask', () => { test('should render dataset warning copy and action button', () => { - const onConfirm = jest.fn() + const onConfirm = vi.fn() render() expect(screen.getByText('appDebug.feature.dataSet.queryVariable.unableToQueryDataSet')).toBeInTheDocument() @@ -13,7 +13,7 @@ describe('CannotQueryDataset WarningMask', () => { }) test('should invoke onConfirm when OK button clicked', () => { - const onConfirm = jest.fn() + const onConfirm = vi.fn() render() fireEvent.click(screen.getByRole('button', { name: 'appDebug.feature.dataSet.queryVariable.ok' })) diff --git a/web/app/components/app/configuration/base/warning-mask/formatting-changed.spec.tsx b/web/app/components/app/configuration/base/warning-mask/formatting-changed.spec.tsx index a968bde272..0db857d7c4 100644 --- a/web/app/components/app/configuration/base/warning-mask/formatting-changed.spec.tsx +++ b/web/app/components/app/configuration/base/warning-mask/formatting-changed.spec.tsx @@ -4,8 +4,8 @@ import FormattingChanged from './formatting-changed' describe('FormattingChanged WarningMask', () => { test('should display translation text and both actions', () => { - const onConfirm = jest.fn() - const onCancel = jest.fn() + const onConfirm = vi.fn() + const onCancel = vi.fn() render( { }) test('should call callbacks when buttons are clicked', () => { - const onConfirm = jest.fn() - const onCancel = jest.fn() + const onConfirm = vi.fn() + const onCancel = vi.fn() render( { test('should show default title when trial not finished', () => { - render() + render() expect(screen.getByText('appDebug.notSetAPIKey.title')).toBeInTheDocument() expect(screen.getByText('appDebug.notSetAPIKey.description')).toBeInTheDocument() }) test('should show trail finished title when flag is true', () => { - render() + render() expect(screen.getByText('appDebug.notSetAPIKey.trailFinished')).toBeInTheDocument() }) test('should call onSetting when primary button clicked', () => { - const onSetting = jest.fn() + const onSetting = vi.fn() render() fireEvent.click(screen.getByRole('button', { name: 'appDebug.notSetAPIKey.settingBtn' })) diff --git a/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx b/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx index 5bf2f177ff..6492864ce2 100644 --- a/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx +++ b/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx @@ -14,7 +14,7 @@ import s from './style.module.css' import MessageTypeSelector from './message-type-selector' import ConfirmAddVar from './confirm-add-var' import PromptEditorHeightResizeWrap from './prompt-editor-height-resize-wrap' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import type { PromptRole, PromptVariable } from '@/models/debug' import { Copy, diff --git a/web/app/components/app/configuration/config-prompt/confirm-add-var/index.spec.tsx b/web/app/components/app/configuration/config-prompt/confirm-add-var/index.spec.tsx index 211b43c5ba..2c15a2b9b4 100644 --- a/web/app/components/app/configuration/config-prompt/confirm-add-var/index.spec.tsx +++ b/web/app/components/app/configuration/config-prompt/confirm-add-var/index.spec.tsx @@ -2,18 +2,18 @@ import React from 'react' import { fireEvent, render, screen } from '@testing-library/react' import ConfirmAddVar from './index' -jest.mock('../../base/var-highlight', () => ({ +vi.mock('../../base/var-highlight', () => ({ __esModule: true, default: ({ name }: { name: string }) => {name}, })) describe('ConfirmAddVar', () => { beforeEach(() => { - jest.clearAllMocks() + vi.clearAllMocks() }) it('should render variable names', () => { - render() + render() const highlights = screen.getAllByTestId('var-highlight') expect(highlights).toHaveLength(2) @@ -22,9 +22,9 @@ describe('ConfirmAddVar', () => { }) it('should trigger cancel actions', () => { - const onConfirm = jest.fn() - const onCancel = jest.fn() - render() + const onConfirm = vi.fn() + const onCancel = vi.fn() + render() fireEvent.click(screen.getByText('common.operation.cancel')) @@ -32,9 +32,9 @@ describe('ConfirmAddVar', () => { }) it('should trigger confirm actions', () => { - const onConfirm = jest.fn() - const onCancel = jest.fn() - render() + const onConfirm = vi.fn() + const onCancel = vi.fn() + render() fireEvent.click(screen.getByText('common.operation.add')) diff --git a/web/app/components/app/configuration/config-prompt/conversation-history/edit-modal.spec.tsx b/web/app/components/app/configuration/config-prompt/conversation-history/edit-modal.spec.tsx index 2e75cd62ca..a0175dc710 100644 --- a/web/app/components/app/configuration/config-prompt/conversation-history/edit-modal.spec.tsx +++ b/web/app/components/app/configuration/config-prompt/conversation-history/edit-modal.spec.tsx @@ -3,7 +3,7 @@ import { fireEvent, render, screen } from '@testing-library/react' import EditModal from './edit-modal' import type { ConversationHistoriesRole } from '@/models/debug' -jest.mock('@/app/components/base/modal', () => ({ +vi.mock('@/app/components/base/modal', () => ({ __esModule: true, default: ({ children }: { children: React.ReactNode }) =>
{children}
, })) @@ -15,19 +15,19 @@ describe('Conversation history edit modal', () => { } beforeEach(() => { - jest.clearAllMocks() + vi.clearAllMocks() }) it('should render provided prefixes', () => { - render() + render() expect(screen.getByDisplayValue('user')).toBeInTheDocument() expect(screen.getByDisplayValue('assistant')).toBeInTheDocument() }) it('should update prefixes and save changes', () => { - const onSave = jest.fn() - render() + const onSave = vi.fn() + render() fireEvent.change(screen.getByDisplayValue('user'), { target: { value: 'member' } }) fireEvent.change(screen.getByDisplayValue('assistant'), { target: { value: 'helper' } }) @@ -40,8 +40,8 @@ describe('Conversation history edit modal', () => { }) it('should call close handler', () => { - const onClose = jest.fn() - render() + const onClose = vi.fn() + render() fireEvent.click(screen.getByText('common.operation.cancel')) diff --git a/web/app/components/app/configuration/config-prompt/conversation-history/history-panel.spec.tsx b/web/app/components/app/configuration/config-prompt/conversation-history/history-panel.spec.tsx index c92bb48e4a..eaae6bb5b9 100644 --- a/web/app/components/app/configuration/config-prompt/conversation-history/history-panel.spec.tsx +++ b/web/app/components/app/configuration/config-prompt/conversation-history/history-panel.spec.tsx @@ -2,12 +2,12 @@ import React from 'react' import { render, screen } from '@testing-library/react' import HistoryPanel from './history-panel' -const mockDocLink = jest.fn(() => 'doc-link') -jest.mock('@/context/i18n', () => ({ +const mockDocLink = vi.fn(() => 'doc-link') +vi.mock('@/context/i18n', () => ({ useDocLink: () => mockDocLink, })) -jest.mock('@/app/components/app/configuration/base/operation-btn', () => ({ +vi.mock('@/app/components/app/configuration/base/operation-btn', () => ({ __esModule: true, default: ({ onClick }: { onClick: () => void }) => ( + +
+ ) + }, +})) + +const createAgentConfig = (overrides: Partial = {}): AgentConfig => ({ + enabled: true, + strategy: AgentStrategy.react, + max_iteration: 3, + tools: [], + ...overrides, +}) + +const setup = (overrides: Partial> = {}) => { + const props: React.ComponentProps = { + isFunctionCall: false, + isChatModel: true, + onAgentSettingChange: vi.fn(), + agentConfig: createAgentConfig(), + ...overrides, + } + + const user = userEvent.setup() + render() + return { props, user } +} + +beforeEach(() => { + vi.clearAllMocks() + latestAgentSettingProps = undefined +}) + +describe('AgentSettingButton', () => { + it('should render button label from translation key', () => { + setup() + + expect(screen.getByRole('button', { name: 'appDebug.agent.setting.name' })).toBeInTheDocument() + }) + + it('should open AgentSetting with the provided configuration when clicked', async () => { + const { user, props } = setup({ isFunctionCall: true, isChatModel: false }) + + await user.click(screen.getByRole('button', { name: 'appDebug.agent.setting.name' })) + + expect(screen.getByTestId('agent-setting')).toBeInTheDocument() + expect(latestAgentSettingProps.isFunctionCall).toBe(true) + expect(latestAgentSettingProps.isChatModel).toBe(false) + expect(latestAgentSettingProps.payload).toEqual(props.agentConfig) + }) + + it('should call onAgentSettingChange and close when AgentSetting saves', async () => { + const { user, props } = setup() + + await user.click(screen.getByRole('button', { name: 'appDebug.agent.setting.name' })) + await user.click(screen.getByText('save-agent')) + + expect(props.onAgentSettingChange).toHaveBeenCalledTimes(1) + expect(props.onAgentSettingChange).toHaveBeenCalledWith({ + ...props.agentConfig, + max_iteration: 9, + }) + expect(screen.queryByTestId('agent-setting')).not.toBeInTheDocument() + }) + + it('should close AgentSetting without saving when cancel is triggered', async () => { + const { user, props } = setup() + + await user.click(screen.getByRole('button', { name: 'appDebug.agent.setting.name' })) + await user.click(screen.getByText('cancel-agent')) + + expect(props.onAgentSettingChange).not.toHaveBeenCalled() + expect(screen.queryByTestId('agent-setting')).not.toBeInTheDocument() + }) +}) diff --git a/web/app/components/app/configuration/config/agent/agent-setting/index.spec.tsx b/web/app/components/app/configuration/config/agent/agent-setting/index.spec.tsx new file mode 100644 index 0000000000..c76ede41e8 --- /dev/null +++ b/web/app/components/app/configuration/config/agent/agent-setting/index.spec.tsx @@ -0,0 +1,108 @@ +import React from 'react' +import { act, fireEvent, render, screen } from '@testing-library/react' +import AgentSetting from './index' +import { MAX_ITERATIONS_NUM } from '@/config' +import type { AgentConfig } from '@/models/debug' + +vi.mock('ahooks', async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + useClickAway: vi.fn(), + } +}) + +vi.mock('react-slider', () => ({ + default: (props: { className?: string; min?: number; max?: number; value: number; onChange: (value: number) => void }) => ( + props.onChange(Number(e.target.value))} + /> + ), +})) + +const basePayload = { + enabled: true, + strategy: 'react', + max_iteration: 5, + tools: [], +} + +const renderModal = (props?: Partial>) => { + const onCancel = vi.fn() + const onSave = vi.fn() + const utils = render( + , + ) + return { ...utils, onCancel, onSave } +} + +describe('AgentSetting', () => { + test('should render agent mode description and default prompt section when not function call', () => { + renderModal() + + expect(screen.getByText('appDebug.agent.agentMode')).toBeInTheDocument() + expect(screen.getByText('appDebug.agent.agentModeType.ReACT')).toBeInTheDocument() + expect(screen.getByText('tools.builtInPromptTitle')).toBeInTheDocument() + }) + + test('should display function call mode when isFunctionCall true', () => { + renderModal({ isFunctionCall: true }) + + expect(screen.getByText('appDebug.agent.agentModeType.functionCall')).toBeInTheDocument() + expect(screen.queryByText('tools.builtInPromptTitle')).not.toBeInTheDocument() + }) + + test('should update iteration via slider and number input', () => { + const { container } = renderModal() + const slider = container.querySelector('.slider') as HTMLInputElement + const numberInput = screen.getByRole('spinbutton') + + fireEvent.change(slider, { target: { value: '7' } }) + expect(screen.getAllByDisplayValue('7')).toHaveLength(2) + + fireEvent.change(numberInput, { target: { value: '2' } }) + expect(screen.getAllByDisplayValue('2')).toHaveLength(2) + }) + + test('should clamp iteration value within min/max range', () => { + renderModal() + + const numberInput = screen.getByRole('spinbutton') + + fireEvent.change(numberInput, { target: { value: '0' } }) + expect(screen.getAllByDisplayValue('1')).toHaveLength(2) + + fireEvent.change(numberInput, { target: { value: '999' } }) + expect(screen.getAllByDisplayValue(String(MAX_ITERATIONS_NUM))).toHaveLength(2) + }) + + test('should call onCancel when cancel button clicked', () => { + const { onCancel } = renderModal() + fireEvent.click(screen.getByRole('button', { name: 'common.operation.cancel' })) + expect(onCancel).toHaveBeenCalled() + }) + + test('should call onSave with updated payload', async () => { + const { onSave } = renderModal() + const numberInput = screen.getByRole('spinbutton') + fireEvent.change(numberInput, { target: { value: '6' } }) + + await act(async () => { + fireEvent.click(screen.getByRole('button', { name: 'common.operation.save' })) + }) + + expect(onSave).toHaveBeenCalledWith(expect.objectContaining({ max_iteration: 6 })) + }) +}) diff --git a/web/app/components/app/configuration/config/agent/agent-setting/item-panel.spec.tsx b/web/app/components/app/configuration/config/agent/agent-setting/item-panel.spec.tsx new file mode 100644 index 0000000000..242f249738 --- /dev/null +++ b/web/app/components/app/configuration/config/agent/agent-setting/item-panel.spec.tsx @@ -0,0 +1,21 @@ +import React from 'react' +import { render, screen } from '@testing-library/react' +import ItemPanel from './item-panel' + +describe('AgentSetting/ItemPanel', () => { + test('should render icon, name, and children content', () => { + render( + icon} + name="Panel name" + description="More info" + children={
child content
} + />, + ) + + expect(screen.getByText('Panel name')).toBeInTheDocument() + expect(screen.getByText('child content')).toBeInTheDocument() + expect(screen.getByText('icon')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/app/configuration/config/agent/agent-setting/item-panel.tsx b/web/app/components/app/configuration/config/agent/agent-setting/item-panel.tsx index 6512e11545..6193392026 100644 --- a/web/app/components/app/configuration/config/agent/agent-setting/item-panel.tsx +++ b/web/app/components/app/configuration/config/agent/agent-setting/item-panel.tsx @@ -1,7 +1,7 @@ 'use client' import type { FC } from 'react' import React from 'react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import Tooltip from '@/app/components/base/tooltip' type Props = { className?: string diff --git a/web/app/components/app/configuration/config/agent/agent-tools/index.spec.tsx b/web/app/components/app/configuration/config/agent/agent-tools/index.spec.tsx new file mode 100644 index 0000000000..f4ef5f050b --- /dev/null +++ b/web/app/components/app/configuration/config/agent/agent-tools/index.spec.tsx @@ -0,0 +1,467 @@ +import type { Mock } from 'vitest' +import type { + PropsWithChildren, +} from 'react' +import React, { + useEffect, + useMemo, + useState, +} from 'react' +import { act, render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import AgentTools from './index' +import ConfigContext from '@/context/debug-configuration' +import type { AgentTool } from '@/types/app' +import { CollectionType, type Tool, type ToolParameter } from '@/app/components/tools/types' +import type { ToolWithProvider } from '@/app/components/workflow/types' +import type { ToolDefaultValue } from '@/app/components/workflow/block-selector/types' +import type { ModelConfig } from '@/models/debug' +import { ModelModeType } from '@/types/app' +import { + DEFAULT_AGENT_SETTING, + DEFAULT_CHAT_PROMPT_CONFIG, + DEFAULT_COMPLETION_PROMPT_CONFIG, +} from '@/config' +import copy from 'copy-to-clipboard' +import type ToolPickerType from '@/app/components/workflow/block-selector/tool-picker' +import type SettingBuiltInToolType from './setting-built-in-tool' + +const formattingDispatcherMock = vi.fn() +vi.mock('@/app/components/app/configuration/debug/hooks', () => ({ + useFormattingChangedDispatcher: () => formattingDispatcherMock, +})) + +let pluginInstallHandler: ((names: string[]) => void) | null = null +const subscribeMock = vi.fn((event: string, handler: any) => { + if (event === 'plugin:install:success') + pluginInstallHandler = handler +}) +vi.mock('@/context/mitt-context', () => ({ + useMittContextSelector: (selector: any) => selector({ + useSubscribe: subscribeMock, + }), +})) + +let builtInTools: ToolWithProvider[] = [] +let customTools: ToolWithProvider[] = [] +let workflowTools: ToolWithProvider[] = [] +let mcpTools: ToolWithProvider[] = [] +vi.mock('@/service/use-tools', () => ({ + useAllBuiltInTools: () => ({ data: builtInTools }), + useAllCustomTools: () => ({ data: customTools }), + useAllWorkflowTools: () => ({ data: workflowTools }), + useAllMCPTools: () => ({ data: mcpTools }), +})) + +type ToolPickerProps = React.ComponentProps +let singleToolSelection: ToolDefaultValue | null = null +let multipleToolSelection: ToolDefaultValue[] = [] +const ToolPickerMock = (props: ToolPickerProps) => ( +
+
{props.trigger}
+ + +
+) +vi.mock('@/app/components/workflow/block-selector/tool-picker', () => ({ + __esModule: true, + default: (props: ToolPickerProps) => , +})) + +type SettingBuiltInToolProps = React.ComponentProps +let latestSettingPanelProps: SettingBuiltInToolProps | null = null +let settingPanelSavePayload: Record = {} +let settingPanelCredentialId = 'credential-from-panel' +const SettingBuiltInToolMock = (props: SettingBuiltInToolProps) => { + latestSettingPanelProps = props + return ( +
+ {props.toolName} + + + +
+ ) +} +vi.mock('./setting-built-in-tool', () => ({ + __esModule: true, + default: (props: SettingBuiltInToolProps) => , +})) + +vi.mock('copy-to-clipboard') + +const copyMock = copy as Mock + +const createToolParameter = (overrides?: Partial): ToolParameter => ({ + name: 'api_key', + label: { + en_US: 'API Key', + zh_Hans: 'API Key', + }, + human_description: { + en_US: 'desc', + zh_Hans: 'desc', + }, + type: 'string', + form: 'config', + llm_description: '', + required: true, + multiple: false, + default: 'default', + ...overrides, +}) + +const createToolDefinition = (overrides?: Partial): Tool => ({ + name: 'search', + author: 'tester', + label: { + en_US: 'Search', + zh_Hans: 'Search', + }, + description: { + en_US: 'desc', + zh_Hans: 'desc', + }, + parameters: [createToolParameter()], + labels: [], + output_schema: {}, + ...overrides, +}) + +const createCollection = (overrides?: Partial): ToolWithProvider => ({ + id: overrides?.id || 'provider-1', + name: overrides?.name || 'vendor/provider-1', + author: 'tester', + description: { + en_US: 'desc', + zh_Hans: 'desc', + }, + icon: 'https://example.com/icon.png', + label: { + en_US: 'Provider Label', + zh_Hans: 'Provider Label', + }, + type: overrides?.type || CollectionType.builtIn, + team_credentials: {}, + is_team_authorization: true, + allow_delete: true, + labels: [], + tools: overrides?.tools || [createToolDefinition()], + meta: { + version: '1.0.0', + }, + ...overrides, +}) + +const createAgentTool = (overrides?: Partial): AgentTool => ({ + provider_id: overrides?.provider_id || 'provider-1', + provider_type: overrides?.provider_type || CollectionType.builtIn, + provider_name: overrides?.provider_name || 'vendor/provider-1', + tool_name: overrides?.tool_name || 'search', + tool_label: overrides?.tool_label || 'Search Tool', + tool_parameters: overrides?.tool_parameters || { api_key: 'key' }, + enabled: overrides?.enabled ?? true, + ...overrides, +}) + +const createModelConfig = (tools: AgentTool[]): ModelConfig => ({ + provider: 'OPENAI', + model_id: 'gpt-3.5-turbo', + mode: ModelModeType.chat, + configs: { + prompt_template: '', + prompt_variables: [], + }, + chat_prompt_config: DEFAULT_CHAT_PROMPT_CONFIG, + completion_prompt_config: DEFAULT_COMPLETION_PROMPT_CONFIG, + opening_statement: '', + more_like_this: null, + suggested_questions: [], + suggested_questions_after_answer: null, + speech_to_text: null, + text_to_speech: null, + file_upload: null, + retriever_resource: null, + sensitive_word_avoidance: null, + annotation_reply: null, + external_data_tools: [], + system_parameters: { + audio_file_size_limit: 0, + file_size_limit: 0, + image_file_size_limit: 0, + video_file_size_limit: 0, + workflow_file_upload_limit: 0, + }, + dataSets: [], + agentConfig: { + ...DEFAULT_AGENT_SETTING, + tools, + }, +}) + +const renderAgentTools = (initialTools?: AgentTool[]) => { + const tools = initialTools ?? [createAgentTool()] + const modelConfigRef = { current: createModelConfig(tools) } + const Wrapper = ({ children }: PropsWithChildren) => { + const [modelConfig, setModelConfig] = useState(modelConfigRef.current) + useEffect(() => { + modelConfigRef.current = modelConfig + }, [modelConfig]) + const value = useMemo(() => ({ + modelConfig, + setModelConfig, + }), [modelConfig]) + return ( + + {children} + + ) + } + const renderResult = render( + + + , + ) + return { + ...renderResult, + getModelConfig: () => modelConfigRef.current, + } +} + +const hoverInfoIcon = async (rowIndex = 0) => { + const rows = document.querySelectorAll('.group') + const infoTrigger = rows.item(rowIndex)?.querySelector('[data-testid="tool-info-tooltip"]') + if (!infoTrigger) + throw new Error('Info trigger not found') + await userEvent.hover(infoTrigger as HTMLElement) +} + +describe('AgentTools', () => { + beforeEach(() => { + vi.clearAllMocks() + builtInTools = [ + createCollection(), + createCollection({ + id: 'provider-2', + name: 'vendor/provider-2', + tools: [createToolDefinition({ + name: 'translate', + label: { + en_US: 'Translate', + zh_Hans: 'Translate', + }, + })], + }), + createCollection({ + id: 'provider-3', + name: 'vendor/provider-3', + tools: [createToolDefinition({ + name: 'summarize', + label: { + en_US: 'Summary', + zh_Hans: 'Summary', + }, + })], + }), + ] + customTools = [] + workflowTools = [] + mcpTools = [] + singleToolSelection = { + provider_id: 'provider-3', + provider_type: CollectionType.builtIn, + provider_name: 'vendor/provider-3', + tool_name: 'summarize', + tool_label: 'Summary Tool', + tool_description: 'desc', + title: 'Summary Tool', + is_team_authorization: true, + params: { api_key: 'picker-value' }, + paramSchemas: [], + output_schema: {}, + } + multipleToolSelection = [ + { + provider_id: 'provider-2', + provider_type: CollectionType.builtIn, + provider_name: 'vendor/provider-2', + tool_name: 'translate', + tool_label: 'Translate Tool', + tool_description: 'desc', + title: 'Translate Tool', + is_team_authorization: true, + params: { api_key: 'multi-a' }, + paramSchemas: [], + output_schema: {}, + }, + { + provider_id: 'provider-3', + provider_type: CollectionType.builtIn, + provider_name: 'vendor/provider-3', + tool_name: 'summarize', + tool_label: 'Summary Tool', + tool_description: 'desc', + title: 'Summary Tool', + is_team_authorization: true, + params: { api_key: 'multi-b' }, + paramSchemas: [], + output_schema: {}, + }, + ] + latestSettingPanelProps = null + settingPanelSavePayload = {} + settingPanelCredentialId = 'credential-from-panel' + pluginInstallHandler = null + }) + + test('should show enabled count and provider information', () => { + renderAgentTools([ + createAgentTool(), + createAgentTool({ + provider_id: 'provider-2', + provider_name: 'vendor/provider-2', + tool_name: 'translate', + tool_label: 'Translate Tool', + enabled: false, + }), + ]) + + const enabledText = screen.getByText(content => content.includes('appDebug.agent.tools.enabled')) + expect(enabledText).toHaveTextContent('1/2') + expect(screen.getByText('provider-1')).toBeInTheDocument() + expect(screen.getByText('Translate Tool')).toBeInTheDocument() + }) + + test('should copy tool name from tooltip action', async () => { + renderAgentTools() + + await hoverInfoIcon() + const copyButton = await screen.findByText('tools.copyToolName') + await userEvent.click(copyButton) + expect(copyMock).toHaveBeenCalledWith('search') + }) + + test('should toggle tool enabled state via switch', async () => { + const { getModelConfig } = renderAgentTools() + + const switchButton = screen.getByRole('switch') + await userEvent.click(switchButton) + + await waitFor(() => { + const tools = getModelConfig().agentConfig.tools as Array<{ tool_name?: string; enabled?: boolean }> + const toggledTool = tools.find(tool => tool.tool_name === 'search') + expect(toggledTool?.enabled).toBe(false) + }) + expect(formattingDispatcherMock).toHaveBeenCalled() + }) + + test('should remove tool when delete action is clicked', async () => { + const { getModelConfig } = renderAgentTools() + const deleteButton = screen.getByTestId('delete-removed-tool') + if (!deleteButton) + throw new Error('Delete button not found') + await userEvent.click(deleteButton) + await waitFor(() => { + expect(getModelConfig().agentConfig.tools).toHaveLength(0) + }) + expect(formattingDispatcherMock).toHaveBeenCalled() + }) + + test('should add a tool when ToolPicker selects one', async () => { + const { getModelConfig } = renderAgentTools([]) + const addSingleButton = screen.getByRole('button', { name: 'pick-single' }) + await userEvent.click(addSingleButton) + + await waitFor(() => { + expect(screen.getByText('Summary Tool')).toBeInTheDocument() + }) + expect(getModelConfig().agentConfig.tools).toHaveLength(1) + }) + + test('should append multiple selected tools at once', async () => { + const { getModelConfig } = renderAgentTools([]) + await userEvent.click(screen.getByRole('button', { name: 'pick-multiple' })) + + await waitFor(() => { + expect(screen.getByText('Translate Tool')).toBeInTheDocument() + expect(screen.getAllByText('Summary Tool')).toHaveLength(1) + }) + expect(getModelConfig().agentConfig.tools).toHaveLength(2) + }) + + test('should open settings panel for not authorized tool', async () => { + renderAgentTools([ + createAgentTool({ + notAuthor: true, + }), + ]) + + const notAuthorizedButton = screen.getByRole('button', { name: /tools.notAuthorized/ }) + await userEvent.click(notAuthorizedButton) + expect(screen.getByTestId('setting-built-in-tool')).toBeInTheDocument() + expect(latestSettingPanelProps?.toolName).toBe('search') + }) + + test('should persist tool parameters when SettingBuiltInTool saves values', async () => { + const { getModelConfig } = renderAgentTools([ + createAgentTool({ + notAuthor: true, + }), + ]) + await userEvent.click(screen.getByRole('button', { name: /tools.notAuthorized/ })) + settingPanelSavePayload = { api_key: 'updated' } + await userEvent.click(screen.getByRole('button', { name: 'save-from-panel' })) + + await waitFor(() => { + expect((getModelConfig().agentConfig.tools[0] as { tool_parameters: Record }).tool_parameters).toEqual({ api_key: 'updated' }) + }) + }) + + test('should update credential id when authorization selection changes', async () => { + const { getModelConfig } = renderAgentTools([ + createAgentTool({ + notAuthor: true, + }), + ]) + await userEvent.click(screen.getByRole('button', { name: /tools.notAuthorized/ })) + settingPanelCredentialId = 'credential-123' + await userEvent.click(screen.getByRole('button', { name: 'auth-from-panel' })) + + await waitFor(() => { + expect((getModelConfig().agentConfig.tools[0] as { credential_id: string }).credential_id).toBe('credential-123') + }) + expect(formattingDispatcherMock).toHaveBeenCalled() + }) + + test('should reinstate deleted tools after plugin install success event', async () => { + const { getModelConfig } = renderAgentTools([ + createAgentTool({ + provider_id: 'provider-1', + provider_name: 'vendor/provider-1', + tool_name: 'search', + tool_label: 'Search Tool', + isDeleted: true, + }), + ]) + if (!pluginInstallHandler) + throw new Error('Plugin handler not registered') + + await act(async () => { + pluginInstallHandler?.(['provider-1']) + }) + + await waitFor(() => { + expect((getModelConfig().agentConfig.tools[0] as { isDeleted: boolean }).isDeleted).toBe(false) + }) + }) +}) diff --git a/web/app/components/app/configuration/config/agent/agent-tools/index.tsx b/web/app/components/app/configuration/config/agent/agent-tools/index.tsx index 5716bfd92d..8dfa2f194b 100644 --- a/web/app/components/app/configuration/config/agent/agent-tools/index.tsx +++ b/web/app/components/app/configuration/config/agent/agent-tools/index.tsx @@ -25,7 +25,7 @@ import { MAX_TOOLS_NUM } from '@/config' import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback' import Tooltip from '@/app/components/base/tooltip' import { DefaultToolIcon } from '@/app/components/base/icons/src/public/other' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import ToolPicker from '@/app/components/workflow/block-selector/tool-picker' import type { ToolDefaultValue, ToolValue } from '@/app/components/workflow/block-selector/types' import { canFindTool } from '@/utils' @@ -217,7 +217,7 @@ const AgentTools: FC = () => { } >
-
+
@@ -277,6 +277,7 @@ const AgentTools: FC = () => { }} onMouseOver={() => setIsDeleting(index)} onMouseLeave={() => setIsDeleting(-1)} + data-testid='delete-removed-tool' >
diff --git a/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.spec.tsx b/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.spec.tsx new file mode 100644 index 0000000000..4d82c29cdc --- /dev/null +++ b/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.spec.tsx @@ -0,0 +1,248 @@ +import React from 'react' +import { render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import SettingBuiltInTool from './setting-built-in-tool' +import I18n from '@/context/i18n' +import { CollectionType, type Tool, type ToolParameter } from '@/app/components/tools/types' + +const fetchModelToolList = vi.fn() +const fetchBuiltInToolList = vi.fn() +const fetchCustomToolList = vi.fn() +const fetchWorkflowToolList = vi.fn() +vi.mock('@/service/tools', () => ({ + fetchModelToolList: (collectionName: string) => fetchModelToolList(collectionName), + fetchBuiltInToolList: (collectionName: string) => fetchBuiltInToolList(collectionName), + fetchCustomToolList: (collectionName: string) => fetchCustomToolList(collectionName), + fetchWorkflowToolList: (appId: string) => fetchWorkflowToolList(appId), +})) + +type MockFormProps = { + value: Record + onChange: (val: Record) => void +} +let nextFormValue: Record = {} +const FormMock = ({ value, onChange }: MockFormProps) => { + return ( +
+
{JSON.stringify(value)}
+ +
+ ) +} +vi.mock('@/app/components/header/account-setting/model-provider-page/model-modal/Form', () => ({ + __esModule: true, + default: (props: MockFormProps) => , +})) + +let pluginAuthClickValue = 'credential-from-plugin' +vi.mock('@/app/components/plugins/plugin-auth', () => ({ + AuthCategory: { tool: 'tool' }, + PluginAuthInAgent: (props: { onAuthorizationItemClick?: (id: string) => void }) => ( +
+ +
+ ), +})) + +vi.mock('@/app/components/plugins/readme-panel/entrance', () => ({ + ReadmeEntrance: ({ className }: { className?: string }) =>
readme
, +})) + +const createParameter = (overrides?: Partial): ToolParameter => ({ + name: 'settingParam', + label: { + en_US: 'Setting Param', + zh_Hans: 'Setting Param', + }, + human_description: { + en_US: 'desc', + zh_Hans: 'desc', + }, + type: 'string', + form: 'config', + llm_description: '', + required: true, + multiple: false, + default: '', + ...overrides, +}) + +const createTool = (overrides?: Partial): Tool => ({ + name: 'search', + author: 'tester', + label: { + en_US: 'Search Tool', + zh_Hans: 'Search Tool', + }, + description: { + en_US: 'tool description', + zh_Hans: 'tool description', + }, + parameters: [ + createParameter({ + name: 'infoParam', + label: { + en_US: 'Info Param', + zh_Hans: 'Info Param', + }, + form: 'llm', + required: false, + }), + createParameter(), + ], + labels: [], + output_schema: {}, + ...overrides, +}) + +const baseCollection = { + id: 'provider-1', + name: 'vendor/provider-1', + author: 'tester', + description: { + en_US: 'desc', + zh_Hans: 'desc', + }, + icon: 'https://example.com/icon.png', + label: { + en_US: 'Provider Label', + zh_Hans: 'Provider Label', + }, + type: CollectionType.builtIn, + team_credentials: {}, + is_team_authorization: true, + allow_delete: true, + labels: [], + tools: [createTool()], +} + +const renderComponent = (props?: Partial>) => { + const onHide = vi.fn() + const onSave = vi.fn() + const onAuthorizationItemClick = vi.fn() + const utils = render( + + + , + ) + return { + ...utils, + onHide, + onSave, + onAuthorizationItemClick, + } +} + +describe('SettingBuiltInTool', () => { + beforeEach(() => { + vi.clearAllMocks() + nextFormValue = {} + pluginAuthClickValue = 'credential-from-plugin' + }) + + test('should fetch tool list when collection has no tools', async () => { + fetchModelToolList.mockResolvedValueOnce([createTool()]) + renderComponent({ + collection: { + ...baseCollection, + tools: [], + }, + }) + + await waitFor(() => { + expect(fetchModelToolList).toHaveBeenCalledTimes(1) + expect(fetchModelToolList).toHaveBeenCalledWith('vendor/provider-1') + }) + expect(await screen.findByText('Search Tool')).toBeInTheDocument() + }) + + test('should switch between info and setting tabs', async () => { + renderComponent() + await waitFor(() => { + expect(screen.getByTestId('mock-form')).toBeInTheDocument() + }) + + await userEvent.click(screen.getByText('tools.setBuiltInTools.parameters')) + expect(screen.getByText('Info Param')).toBeInTheDocument() + await userEvent.click(screen.getByText('tools.setBuiltInTools.setting')) + expect(screen.getByTestId('mock-form')).toBeInTheDocument() + }) + + test('should call onSave with updated values when save button clicked', async () => { + const { onSave } = renderComponent() + await waitFor(() => expect(screen.getByTestId('mock-form')).toBeInTheDocument()) + nextFormValue = { settingParam: 'updated' } + await userEvent.click(screen.getByRole('button', { name: 'update-form' })) + await userEvent.click(screen.getByRole('button', { name: 'common.operation.save' })) + expect(onSave).toHaveBeenCalledWith(expect.objectContaining({ settingParam: 'updated' })) + }) + + test('should keep save disabled until required field provided', async () => { + renderComponent({ + setting: {}, + }) + await waitFor(() => expect(screen.getByTestId('mock-form')).toBeInTheDocument()) + const saveButton = screen.getByRole('button', { name: 'common.operation.save' }) + expect(saveButton).toBeDisabled() + nextFormValue = { settingParam: 'filled' } + await userEvent.click(screen.getByRole('button', { name: 'update-form' })) + expect(saveButton).not.toBeDisabled() + }) + + test('should call onHide when cancel button is pressed', async () => { + const { onHide } = renderComponent() + await waitFor(() => expect(screen.getByTestId('mock-form')).toBeInTheDocument()) + await userEvent.click(screen.getByRole('button', { name: 'common.operation.cancel' })) + expect(onHide).toHaveBeenCalled() + }) + + test('should trigger authorization callback from plugin auth section', async () => { + const { onAuthorizationItemClick } = renderComponent() + await userEvent.click(screen.getByRole('button', { name: 'choose-plugin-credential' })) + expect(onAuthorizationItemClick).toHaveBeenCalledWith('credential-from-plugin') + }) + + test('should call onHide when back button is clicked', async () => { + const { onHide } = renderComponent({ + showBackButton: true, + }) + await userEvent.click(screen.getByText('plugin.detailPanel.operation.back')) + expect(onHide).toHaveBeenCalled() + }) + + test('should load workflow tools when workflow collection is provided', async () => { + fetchWorkflowToolList.mockResolvedValueOnce([createTool({ + name: 'workflow-tool', + })]) + renderComponent({ + collection: { + ...baseCollection, + type: CollectionType.workflow, + tools: [], + id: 'workflow-1', + } as any, + isBuiltIn: false, + isModel: false, + }) + + await waitFor(() => { + expect(fetchWorkflowToolList).toHaveBeenCalledWith('workflow-1') + }) + }) +}) diff --git a/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx b/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx index c5947495db..0627666b4c 100644 --- a/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx +++ b/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx @@ -22,7 +22,7 @@ import { CollectionType } from '@/app/components/tools/types' import { fetchBuiltInToolList, fetchCustomToolList, fetchModelToolList, fetchWorkflowToolList } from '@/service/tools' import I18n from '@/context/i18n' import { getLanguage } from '@/i18n-config/language' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import type { ToolWithProvider } from '@/app/components/workflow/types' import { AuthCategory, diff --git a/web/app/components/app/configuration/config/agent/prompt-editor.tsx b/web/app/components/app/configuration/config/agent/prompt-editor.tsx index 71a9304d0c..78d7eef029 100644 --- a/web/app/components/app/configuration/config/agent/prompt-editor.tsx +++ b/web/app/components/app/configuration/config/agent/prompt-editor.tsx @@ -4,7 +4,7 @@ import React from 'react' import copy from 'copy-to-clipboard' import { useContext } from 'use-context-selector' import { useTranslation } from 'react-i18next' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { Copy, CopyCheck, diff --git a/web/app/components/app/configuration/config/assistant-type-picker/index.spec.tsx b/web/app/components/app/configuration/config/assistant-type-picker/index.spec.tsx new file mode 100644 index 0000000000..e17da4e58e --- /dev/null +++ b/web/app/components/app/configuration/config/assistant-type-picker/index.spec.tsx @@ -0,0 +1,865 @@ +import React from 'react' +import { act, render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import AssistantTypePicker from './index' +import type { AgentConfig } from '@/models/debug' +import { AgentStrategy } from '@/types/app' + +// Test utilities +const defaultAgentConfig: AgentConfig = { + enabled: true, + max_iteration: 3, + strategy: AgentStrategy.functionCall, + tools: [], +} + +const defaultProps = { + value: 'chat', + disabled: false, + onChange: vi.fn(), + isFunctionCall: true, + isChatModel: true, + agentConfig: defaultAgentConfig, + onAgentSettingChange: vi.fn(), +} + +const renderComponent = (props: Partial> = {}) => { + const mergedProps = { ...defaultProps, ...props } + return render() +} + +// Helper to get option element by description (which is unique per option) +const getOptionByDescription = (descriptionRegex: RegExp) => { + const description = screen.getByText(descriptionRegex) + return description.parentElement as HTMLElement +} + +describe('AssistantTypePicker', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + // Rendering tests (REQUIRED) + describe('Rendering', () => { + it('should render without crashing', () => { + // Arrange & Act + renderComponent() + + // Assert + expect(screen.getByText(/chatAssistant.name/i)).toBeInTheDocument() + }) + + it('should render chat assistant by default when value is "chat"', () => { + // Arrange & Act + renderComponent({ value: 'chat' }) + + // Assert + expect(screen.getByText(/chatAssistant.name/i)).toBeInTheDocument() + }) + + it('should render agent assistant when value is "agent"', () => { + // Arrange & Act + renderComponent({ value: 'agent' }) + + // Assert + expect(screen.getByText(/agentAssistant.name/i)).toBeInTheDocument() + }) + }) + + // Props tests (REQUIRED) + describe('Props', () => { + it('should use provided value prop', () => { + // Arrange & Act + renderComponent({ value: 'agent' }) + + // Assert + expect(screen.getByText(/agentAssistant.name/i)).toBeInTheDocument() + }) + + it('should handle agentConfig prop', () => { + // Arrange + const customAgentConfig: AgentConfig = { + enabled: true, + max_iteration: 10, + strategy: AgentStrategy.react, + tools: [], + } + + // Act + expect(() => { + renderComponent({ agentConfig: customAgentConfig }) + }).not.toThrow() + + // Assert + expect(screen.getByText(/chatAssistant.name/i)).toBeInTheDocument() + }) + + it('should handle undefined agentConfig prop', () => { + // Arrange & Act + expect(() => { + renderComponent({ agentConfig: undefined }) + }).not.toThrow() + + // Assert + expect(screen.getByText(/chatAssistant.name/i)).toBeInTheDocument() + }) + }) + + // User Interactions + describe('User Interactions', () => { + it('should open dropdown when clicking trigger', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Act + const trigger = screen.getByText(/chatAssistant.name/i) + await user.click(trigger) + + // Assert - Both options should be visible + await waitFor(() => { + const chatOptions = screen.getAllByText(/chatAssistant.name/i) + const agentOptions = screen.getAllByText(/agentAssistant.name/i) + expect(chatOptions.length).toBeGreaterThan(1) + expect(agentOptions.length).toBeGreaterThan(0) + }) + }) + + it('should call onChange when selecting chat assistant', async () => { + // Arrange + const user = userEvent.setup() + const onChange = vi.fn() + renderComponent({ value: 'agent', onChange }) + + // Act - Open dropdown + const trigger = screen.getByText(/agentAssistant.name/i) + await user.click(trigger) + + // Wait for dropdown to open and find chat option + await waitFor(() => { + expect(screen.getByText(/chatAssistant.description/i)).toBeInTheDocument() + }) + + // Find and click the chat option by its unique description + const chatOption = getOptionByDescription(/chatAssistant.description/i) + await user.click(chatOption) + + // Assert + expect(onChange).toHaveBeenCalledWith('chat') + }) + + it('should call onChange when selecting agent assistant', async () => { + // Arrange + const user = userEvent.setup() + const onChange = vi.fn() + renderComponent({ value: 'chat', onChange }) + + // Act - Open dropdown + const trigger = screen.getByText(/chatAssistant.name/i) + await user.click(trigger) + + // Wait for dropdown to open and click agent option + await waitFor(() => { + expect(screen.getByText(/agentAssistant.description/i)).toBeInTheDocument() + }) + + const agentOption = getOptionByDescription(/agentAssistant.description/i) + await user.click(agentOption) + + // Assert + expect(onChange).toHaveBeenCalledWith('agent') + }) + + it('should close dropdown when selecting chat assistant', async () => { + // Arrange + const user = userEvent.setup() + renderComponent({ value: 'agent' }) + + // Act - Open dropdown + const trigger = screen.getByText(/agentAssistant.name/i) + await user.click(trigger) + + // Wait for dropdown and select chat + await waitFor(() => { + expect(screen.getByText(/chatAssistant.description/i)).toBeInTheDocument() + }) + + const chatOption = getOptionByDescription(/chatAssistant.description/i) + await user.click(chatOption) + + // Assert - Dropdown should close (descriptions should not be visible) + await waitFor(() => { + expect(screen.queryByText(/chatAssistant.description/i)).not.toBeInTheDocument() + }) + }) + + it('should not close dropdown when selecting agent assistant', async () => { + // Arrange + const user = userEvent.setup() + renderComponent({ value: 'chat' }) + + // Act - Open dropdown + const trigger = screen.getByText(/chatAssistant.name/i) + await user.click(trigger) + + // Wait for dropdown and select agent + await waitFor(() => { + const agentOptions = screen.getAllByText(/agentAssistant.name/i) + expect(agentOptions.length).toBeGreaterThan(0) + }) + + const agentOptions = screen.getAllByText(/agentAssistant.name/i) + await user.click(agentOptions[0]) + + // Assert - Dropdown should remain open (agent settings should be visible) + await waitFor(() => { + expect(screen.getByText(/agent.setting.name/i)).toBeInTheDocument() + }) + }) + + it('should not call onChange when clicking same value', async () => { + // Arrange + const user = userEvent.setup() + const onChange = vi.fn() + renderComponent({ value: 'chat', onChange }) + + // Act - Open dropdown + const trigger = screen.getByText(/chatAssistant.name/i) + await user.click(trigger) + + // Wait for dropdown and click same option + await waitFor(() => { + const chatOptions = screen.getAllByText(/chatAssistant.name/i) + expect(chatOptions.length).toBeGreaterThan(1) + }) + + const chatOptions = screen.getAllByText(/chatAssistant.name/i) + await user.click(chatOptions[1]) + + // Assert + expect(onChange).not.toHaveBeenCalled() + }) + }) + + // Disabled state + describe('Disabled State', () => { + it('should not respond to clicks when disabled', async () => { + // Arrange + const user = userEvent.setup() + const onChange = vi.fn() + renderComponent({ disabled: true, onChange }) + + // Act - Open dropdown (dropdown can still open when disabled) + const trigger = screen.getByText(/chatAssistant.name/i) + await user.click(trigger) + + // Wait for dropdown to open + await waitFor(() => { + expect(screen.getByText(/agentAssistant.description/i)).toBeInTheDocument() + }) + + // Act - Try to click an option + const agentOption = getOptionByDescription(/agentAssistant.description/i) + await user.click(agentOption) + + // Assert - onChange should not be called (options are disabled) + expect(onChange).not.toHaveBeenCalled() + }) + + it('should not show agent config UI when disabled', async () => { + // Arrange + const user = userEvent.setup() + renderComponent({ value: 'agent', disabled: true }) + + // Act - Open dropdown + const trigger = screen.getByText(/agentAssistant.name/i) + await user.click(trigger) + + // Assert - Agent settings option should not be visible + await waitFor(() => { + expect(screen.queryByText(/agent.setting.name/i)).not.toBeInTheDocument() + }) + }) + + it('should show agent config UI when not disabled', async () => { + // Arrange + const user = userEvent.setup() + renderComponent({ value: 'agent', disabled: false }) + + // Act - Open dropdown + const trigger = screen.getByText(/agentAssistant.name/i) + await user.click(trigger) + + // Assert - Agent settings option should be visible + await waitFor(() => { + expect(screen.getByText(/agent.setting.name/i)).toBeInTheDocument() + }) + }) + }) + + // Agent Settings Modal + describe('Agent Settings Modal', () => { + it('should open agent settings modal when clicking agent config UI', async () => { + // Arrange + const user = userEvent.setup() + renderComponent({ value: 'agent', disabled: false }) + + // Act - Open dropdown + const trigger = screen.getByText(/agentAssistant.name/i) + await user.click(trigger) + + // Click agent settings + await waitFor(() => { + expect(screen.getByText(/agent.setting.name/i)).toBeInTheDocument() + }) + + const agentSettingsTrigger = screen.getByText(/agent.setting.name/i) + await user.click(agentSettingsTrigger) + + // Assert + await waitFor(() => { + expect(screen.getByText(/common.operation.save/i)).toBeInTheDocument() + }) + }) + + it('should not open agent settings when value is not agent', async () => { + // Arrange + const user = userEvent.setup() + renderComponent({ value: 'chat', disabled: false }) + + // Act - Open dropdown + const trigger = screen.getByText(/chatAssistant.name/i) + await user.click(trigger) + + // Wait for dropdown to open + await waitFor(() => { + expect(screen.getByText(/chatAssistant.description/i)).toBeInTheDocument() + }) + + // Assert - Agent settings modal should not appear (value is 'chat') + expect(screen.queryByText(/common.operation.save/i)).not.toBeInTheDocument() + }) + + it('should call onAgentSettingChange when saving agent settings', async () => { + // Arrange + const user = userEvent.setup() + const onAgentSettingChange = vi.fn() + renderComponent({ value: 'agent', disabled: false, onAgentSettingChange }) + + // Act - Open dropdown and agent settings + const trigger = screen.getByText(/agentAssistant.name/i) + await user.click(trigger) + + await waitFor(() => { + expect(screen.getByText(/agent.setting.name/i)).toBeInTheDocument() + }) + + const agentSettingsTrigger = screen.getByText(/agent.setting.name/i) + await user.click(agentSettingsTrigger) + + // Wait for modal and click save + await waitFor(() => { + expect(screen.getByText(/common.operation.save/i)).toBeInTheDocument() + }) + + const saveButton = screen.getByText(/common.operation.save/i) + await user.click(saveButton) + + // Assert + expect(onAgentSettingChange).toHaveBeenCalledWith(defaultAgentConfig) + }) + + it('should close modal when saving agent settings', async () => { + // Arrange + const user = userEvent.setup() + renderComponent({ value: 'agent', disabled: false }) + + // Act - Open dropdown, agent settings, and save + const trigger = screen.getByText(/agentAssistant.name/i) + await user.click(trigger) + + await waitFor(() => { + expect(screen.getByText(/agent.setting.name/i)).toBeInTheDocument() + }) + + const agentSettingsTrigger = screen.getByText(/agent.setting.name/i) + await user.click(agentSettingsTrigger) + + await waitFor(() => { + expect(screen.getByText(/appDebug.agent.setting.name/i)).toBeInTheDocument() + }) + + const saveButton = screen.getByText(/common.operation.save/i) + await user.click(saveButton) + + // Assert + await waitFor(() => { + expect(screen.queryByText(/common.operation.save/i)).not.toBeInTheDocument() + }) + }) + + it('should close modal when canceling agent settings', async () => { + // Arrange + const user = userEvent.setup() + const onAgentSettingChange = vi.fn() + renderComponent({ value: 'agent', disabled: false, onAgentSettingChange }) + + // Act - Open dropdown, agent settings, and cancel + const trigger = screen.getByText(/agentAssistant.name/i) + await user.click(trigger) + + await waitFor(() => { + expect(screen.getByText(/agent.setting.name/i)).toBeInTheDocument() + }) + + const agentSettingsTrigger = screen.getByText(/agent.setting.name/i) + await user.click(agentSettingsTrigger) + + await waitFor(() => { + expect(screen.getByText(/common.operation.save/i)).toBeInTheDocument() + }) + + const cancelButton = screen.getByText(/common.operation.cancel/i) + await user.click(cancelButton) + + // Assert + await waitFor(() => { + expect(screen.queryByText(/common.operation.save/i)).not.toBeInTheDocument() + }) + expect(onAgentSettingChange).not.toHaveBeenCalled() + }) + + it('should close dropdown when opening agent settings', async () => { + // Arrange + const user = userEvent.setup() + renderComponent({ value: 'agent', disabled: false }) + + // Act - Open dropdown and agent settings + const trigger = screen.getByText(/agentAssistant.name/i) + await user.click(trigger) + + await waitFor(() => { + expect(screen.getByText(/agent.setting.name/i)).toBeInTheDocument() + }) + + const agentSettingsTrigger = screen.getByText(/agent.setting.name/i) + await user.click(agentSettingsTrigger) + + // Assert - Modal should be open and dropdown should close + await waitFor(() => { + expect(screen.getByText(/common.operation.save/i)).toBeInTheDocument() + }) + + // The dropdown should be closed (agent settings description should not be visible) + await waitFor(() => { + const descriptions = screen.queryAllByText(/agent.setting.description/i) + expect(descriptions.length).toBe(0) + }) + }) + }) + + // Edge Cases (REQUIRED) + describe('Edge Cases', () => { + it('should handle rapid toggle clicks', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Act + const trigger = screen.getByText(/chatAssistant.name/i) + await user.click(trigger) + await user.click(trigger) + await user.click(trigger) + + // Assert - Should not crash + expect(trigger).toBeInTheDocument() + }) + + it('should handle multiple rapid selection changes', async () => { + // Arrange + const user = userEvent.setup() + const onChange = vi.fn() + renderComponent({ value: 'chat', onChange }) + + // Act - Open and select agent + const trigger = screen.getByText(/chatAssistant.name/i) + await user.click(trigger) + + await waitFor(() => { + expect(screen.getByText(/agentAssistant.description/i)).toBeInTheDocument() + }) + + // Click agent option - this stays open because value is 'agent' + const agentOption = getOptionByDescription(/agentAssistant.description/i) + await user.click(agentOption) + + // Assert - onChange should have been called once to switch to agent + await waitFor(() => { + expect(onChange).toHaveBeenCalledTimes(1) + }) + expect(onChange).toHaveBeenCalledWith('agent') + }) + + it('should handle missing callback functions gracefully', async () => { + // Arrange + const user = userEvent.setup() + + // Act & Assert - Should not crash + expect(() => { + renderComponent({ + onChange: undefined!, + onAgentSettingChange: undefined!, + }) + }).not.toThrow() + + const trigger = screen.getByText(/chatAssistant.name/i) + await user.click(trigger) + }) + + it('should handle empty agentConfig', async () => { + // Arrange & Act + expect(() => { + renderComponent({ agentConfig: {} as AgentConfig }) + }).not.toThrow() + + // Assert + expect(screen.getByText(/chatAssistant.name/i)).toBeInTheDocument() + }) + + describe('should render with different prop combinations', () => { + const combinations = [ + { value: 'chat' as const, disabled: true, isFunctionCall: true, isChatModel: true }, + { value: 'agent' as const, disabled: false, isFunctionCall: false, isChatModel: false }, + { value: 'agent' as const, disabled: true, isFunctionCall: true, isChatModel: false }, + { value: 'chat' as const, disabled: false, isFunctionCall: false, isChatModel: true }, + ] + + it.each(combinations)( + 'value=$value, disabled=$disabled, isFunctionCall=$isFunctionCall, isChatModel=$isChatModel', + (combo) => { + // Arrange & Act + renderComponent(combo) + + // Assert + const expectedText = combo.value === 'agent' ? 'agentAssistant.name' : 'chatAssistant.name' + expect(screen.getByText(new RegExp(expectedText, 'i'))).toBeInTheDocument() + }, + ) + }) + }) + + // Accessibility + describe('Accessibility', () => { + it('should render interactive dropdown items', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Act - Open dropdown + const trigger = screen.getByText(/chatAssistant.name/i) + await user.click(trigger) + + // Assert - Both options should be visible and clickable + await waitFor(() => { + expect(screen.getByText(/chatAssistant.description/i)).toBeInTheDocument() + expect(screen.getByText(/agentAssistant.description/i)).toBeInTheDocument() + }) + + // Verify we can interact with option elements using helper function + const chatOption = getOptionByDescription(/chatAssistant.description/i) + const agentOption = getOptionByDescription(/agentAssistant.description/i) + expect(chatOption).toBeInTheDocument() + expect(agentOption).toBeInTheDocument() + }) + }) + + // SelectItem Component + describe('SelectItem Component', () => { + it('should show checked state for selected option', async () => { + // Arrange + const user = userEvent.setup() + renderComponent({ value: 'chat' }) + + // Act - Open dropdown + const trigger = screen.getByText(/chatAssistant.name/i) + await user.click(trigger) + + // Assert - Both options should be visible with radio components + await waitFor(() => { + expect(screen.getByText(/chatAssistant.description/i)).toBeInTheDocument() + expect(screen.getByText(/agentAssistant.description/i)).toBeInTheDocument() + }) + + // The SelectItem components render with different visual states + // based on isChecked prop - we verify both options are rendered + const chatOption = getOptionByDescription(/chatAssistant.description/i) + const agentOption = getOptionByDescription(/agentAssistant.description/i) + expect(chatOption).toBeInTheDocument() + expect(agentOption).toBeInTheDocument() + }) + + it('should render description text', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Act - Open dropdown + const trigger = screen.getByText(/chatAssistant.name/i) + await user.click(trigger) + + // Assert - Descriptions should be visible + await waitFor(() => { + expect(screen.getByText(/chatAssistant.description/i)).toBeInTheDocument() + expect(screen.getByText(/agentAssistant.description/i)).toBeInTheDocument() + }) + }) + + it('should show Radio component for each option', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Act - Open dropdown + const trigger = screen.getByText(/chatAssistant.name/i) + await user.click(trigger) + + // Assert - Radio components should be present (both options visible) + await waitFor(() => { + expect(screen.getByText(/chatAssistant.description/i)).toBeInTheDocument() + expect(screen.getByText(/agentAssistant.description/i)).toBeInTheDocument() + }) + }) + }) + + // Agent Setting Integration + describe('AgentSetting Integration', () => { + it('should show function call mode when isFunctionCall is true', async () => { + // Arrange + const user = userEvent.setup() + renderComponent({ value: 'agent', isFunctionCall: true, isChatModel: false }) + + // Act - Open dropdown and settings modal + const trigger = screen.getByText(/agentAssistant.name/i) + await user.click(trigger) + + await waitFor(() => { + expect(screen.getByText(/agent.setting.name/i)).toBeInTheDocument() + }) + + const agentSettingsTrigger = screen.getByText(/agent.setting.name/i) + await user.click(agentSettingsTrigger) + + // Assert + await waitFor(() => { + expect(screen.getByText(/common.operation.save/i)).toBeInTheDocument() + }) + expect(screen.getByText(/appDebug.agent.agentModeType.functionCall/i)).toBeInTheDocument() + }) + + it('should show built-in prompt when isFunctionCall is false', async () => { + // Arrange + const user = userEvent.setup() + renderComponent({ value: 'agent', isFunctionCall: false, isChatModel: true }) + + // Act - Open dropdown and settings modal + const trigger = screen.getByText(/agentAssistant.name/i) + await user.click(trigger) + + await waitFor(() => { + expect(screen.getByText(/agent.setting.name/i)).toBeInTheDocument() + }) + + const agentSettingsTrigger = screen.getByText(/agent.setting.name/i) + await user.click(agentSettingsTrigger) + + // Assert + await waitFor(() => { + expect(screen.getByText(/common.operation.save/i)).toBeInTheDocument() + }) + expect(screen.getByText(/tools.builtInPromptTitle/i)).toBeInTheDocument() + }) + + it('should initialize max iteration from agentConfig payload', async () => { + // Arrange + const user = userEvent.setup() + const customConfig: AgentConfig = { + enabled: true, + max_iteration: 10, + strategy: AgentStrategy.react, + tools: [], + } + + renderComponent({ value: 'agent', agentConfig: customConfig }) + + // Act - Open dropdown and settings modal + const trigger = screen.getByText(/agentAssistant.name/i) + await user.click(trigger) + + await waitFor(() => { + expect(screen.getByText(/agent.setting.name/i)).toBeInTheDocument() + }) + + const agentSettingsTrigger = screen.getByText(/agent.setting.name/i) + await user.click(agentSettingsTrigger) + + // Assert + await screen.findByText(/common.operation.save/i) + const maxIterationInput = await screen.findByRole('spinbutton') + expect(maxIterationInput).toHaveValue(10) + }) + }) + + // Keyboard Navigation + describe('Keyboard Navigation', () => { + it('should support closing dropdown with Escape key', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Act - Open dropdown + const trigger = screen.getByText(/chatAssistant.name/i) + await user.click(trigger) + + await waitFor(() => { + expect(screen.getByText(/chatAssistant.description/i)).toBeInTheDocument() + }) + + // Press Escape + await user.keyboard('{Escape}') + + // Assert - Dropdown should close + await waitFor(() => { + expect(screen.queryByText(/chatAssistant.description/i)).not.toBeInTheDocument() + }) + }) + + it('should allow keyboard focus on trigger element', () => { + // Arrange + renderComponent() + + // Act - Get trigger and verify it can receive focus + const trigger = screen.getByText(/chatAssistant.name/i) + + // Assert - Element should be focusable + expect(trigger).toBeInTheDocument() + expect(trigger.parentElement).toBeInTheDocument() + }) + + it('should allow keyboard focus on dropdown options', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Act - Open dropdown + const trigger = screen.getByText(/chatAssistant.name/i) + await user.click(trigger) + + await waitFor(() => { + expect(screen.getByText(/chatAssistant.description/i)).toBeInTheDocument() + }) + + // Get options + const chatOption = getOptionByDescription(/chatAssistant.description/i) + const agentOption = getOptionByDescription(/agentAssistant.description/i) + + // Assert - Options should be focusable + expect(chatOption).toBeInTheDocument() + expect(agentOption).toBeInTheDocument() + + // Verify options exist and can receive focus programmatically + // Note: focus() doesn't always update document.activeElement in JSDOM + // so we just verify the elements are interactive + act(() => { + chatOption.focus() + }) + // The element should have received the focus call even if activeElement isn't updated + expect(chatOption.tabIndex).toBeDefined() + }) + + it('should maintain keyboard accessibility for all interactive elements', async () => { + // Arrange + const user = userEvent.setup() + renderComponent({ value: 'agent' }) + + // Act - Open dropdown + const trigger = screen.getByText(/agentAssistant.name/i) + await user.click(trigger) + + // Assert - Agent settings button should be focusable + await waitFor(() => { + expect(screen.getByText(/agent.setting.name/i)).toBeInTheDocument() + }) + + const agentSettings = screen.getByText(/agent.setting.name/i) + expect(agentSettings).toBeInTheDocument() + }) + }) + + // ARIA Attributes + describe('ARIA Attributes', () => { + it('should have proper ARIA state for dropdown', async () => { + // Arrange + const user = userEvent.setup() + const { container } = renderComponent() + + // Act - Check initial state + const portalContainer = container.querySelector('[data-state]') + expect(portalContainer).toHaveAttribute('data-state', 'closed') + + // Open dropdown + const trigger = screen.getByText(/chatAssistant.name/i) + await user.click(trigger) + + // Assert - State should change to open + await waitFor(() => { + const openPortal = container.querySelector('[data-state="open"]') + expect(openPortal).toBeInTheDocument() + }) + }) + + it('should have proper data-state attribute', () => { + // Arrange & Act + const { container } = renderComponent() + + // Assert - Portal should have data-state for accessibility + const portalContainer = container.querySelector('[data-state]') + expect(portalContainer).toBeInTheDocument() + expect(portalContainer).toHaveAttribute('data-state') + + // Should start in closed state + expect(portalContainer).toHaveAttribute('data-state', 'closed') + }) + + it('should maintain accessible structure for screen readers', () => { + // Arrange & Act + renderComponent({ value: 'chat' }) + + // Assert - Text content should be accessible + expect(screen.getByText(/chatAssistant.name/i)).toBeInTheDocument() + + // Icons should have proper structure + const { container } = renderComponent() + const icons = container.querySelectorAll('svg') + expect(icons.length).toBeGreaterThan(0) + }) + + it('should provide context through text labels', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Act - Open dropdown + const trigger = screen.getByText(/chatAssistant.name/i) + await user.click(trigger) + + // Assert - All options should have descriptive text + await waitFor(() => { + expect(screen.getByText(/chatAssistant.description/i)).toBeInTheDocument() + expect(screen.getByText(/agentAssistant.description/i)).toBeInTheDocument() + }) + + // Title text should be visible + expect(screen.getByText(/assistantType.name/i)).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/app/configuration/config/assistant-type-picker/index.tsx b/web/app/components/app/configuration/config/assistant-type-picker/index.tsx index 3597a6e292..50f16f957a 100644 --- a/web/app/components/app/configuration/config/assistant-type-picker/index.tsx +++ b/web/app/components/app/configuration/config/assistant-type-picker/index.tsx @@ -4,7 +4,7 @@ import React, { useState } from 'react' import { useTranslation } from 'react-i18next' import { RiArrowDownSLine } from '@remixicon/react' import AgentSetting from '../agent/agent-setting' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { PortalToFollowElem, PortalToFollowElemContent, diff --git a/web/app/components/app/configuration/config/automatic/idea-output.tsx b/web/app/components/app/configuration/config/automatic/idea-output.tsx index df4f76c92b..895f74baa3 100644 --- a/web/app/components/app/configuration/config/automatic/idea-output.tsx +++ b/web/app/components/app/configuration/config/automatic/idea-output.tsx @@ -3,7 +3,7 @@ import { ArrowDownRoundFill } from '@/app/components/base/icons/src/vender/solid import { useBoolean } from 'ahooks' import type { FC } from 'react' import React from 'react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import Textarea from '@/app/components/base/textarea' import { useTranslation } from 'react-i18next' diff --git a/web/app/components/app/configuration/config/automatic/instruction-editor.tsx b/web/app/components/app/configuration/config/automatic/instruction-editor.tsx index b14ee93313..409f335232 100644 --- a/web/app/components/app/configuration/config/automatic/instruction-editor.tsx +++ b/web/app/components/app/configuration/config/automatic/instruction-editor.tsx @@ -3,7 +3,7 @@ import type { FC } from 'react' import React from 'react' import PromptEditor from '@/app/components/base/prompt-editor' import type { GeneratorType } from './types' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import type { Node, NodeOutPutVar, ValueSelector } from '@/app/components/workflow/types' import { BlockEnum } from '@/app/components/workflow/types' import { useTranslation } from 'react-i18next' diff --git a/web/app/components/app/configuration/config/automatic/prompt-toast.tsx b/web/app/components/app/configuration/config/automatic/prompt-toast.tsx index 2826cc97c8..c9169f0ad7 100644 --- a/web/app/components/app/configuration/config/automatic/prompt-toast.tsx +++ b/web/app/components/app/configuration/config/automatic/prompt-toast.tsx @@ -1,7 +1,7 @@ import { RiArrowDownSLine, RiSparklingFill } from '@remixicon/react' import { useBoolean } from 'ahooks' import React from 'react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { Markdown } from '@/app/components/base/markdown' import { useTranslation } from 'react-i18next' import s from './style.module.css' diff --git a/web/app/components/app/configuration/config/automatic/version-selector.tsx b/web/app/components/app/configuration/config/automatic/version-selector.tsx index c3d3e1d91c..715c1f3c80 100644 --- a/web/app/components/app/configuration/config/automatic/version-selector.tsx +++ b/web/app/components/app/configuration/config/automatic/version-selector.tsx @@ -1,7 +1,7 @@ import React, { useCallback } from 'react' import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '@/app/components/base/portal-to-follow-elem' import { useBoolean } from 'ahooks' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { RiArrowDownSLine, RiCheckLine } from '@remixicon/react' import { useTranslation } from 'react-i18next' diff --git a/web/app/components/app/configuration/config/config-audio.spec.tsx b/web/app/components/app/configuration/config/config-audio.spec.tsx new file mode 100644 index 0000000000..132ada95d0 --- /dev/null +++ b/web/app/components/app/configuration/config/config-audio.spec.tsx @@ -0,0 +1,124 @@ +import type { Mock } from 'vitest' +import React from 'react' +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import ConfigAudio from './config-audio' +import type { FeatureStoreState } from '@/app/components/base/features/store' +import { SupportUploadFileTypes } from '@/app/components/workflow/types' + +const mockUseContext = vi.fn() +vi.mock('use-context-selector', async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + useContext: (context: unknown) => mockUseContext(context), + } +}) + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +const mockUseFeatures = vi.fn() +const mockUseFeaturesStore = vi.fn() +vi.mock('@/app/components/base/features/hooks', () => ({ + useFeatures: (selector: (state: FeatureStoreState) => any) => mockUseFeatures(selector), + useFeaturesStore: () => mockUseFeaturesStore(), +})) + +type SetupOptions = { + isVisible?: boolean + allowedTypes?: SupportUploadFileTypes[] +} + +let mockFeatureStoreState: FeatureStoreState +let mockSetFeatures: Mock +const mockStore = { + getState: vi.fn<() => FeatureStoreState>(() => mockFeatureStoreState), +} + +const setupFeatureStore = (allowedTypes: SupportUploadFileTypes[] = []) => { + mockSetFeatures = vi.fn() + mockFeatureStoreState = { + features: { + file: { + allowed_file_types: allowedTypes, + enabled: allowedTypes.length > 0, + }, + }, + setFeatures: mockSetFeatures, + showFeaturesModal: false, + setShowFeaturesModal: vi.fn(), + } + mockStore.getState.mockImplementation(() => mockFeatureStoreState) + mockUseFeaturesStore.mockReturnValue(mockStore) + mockUseFeatures.mockImplementation(selector => selector(mockFeatureStoreState)) +} + +const renderConfigAudio = (options: SetupOptions = {}) => { + const { + isVisible = true, + allowedTypes = [], + } = options + setupFeatureStore(allowedTypes) + mockUseContext.mockReturnValue({ + isShowAudioConfig: isVisible, + }) + const user = userEvent.setup() + render() + return { + user, + setFeatures: mockSetFeatures, + } +} + +beforeEach(() => { + vi.clearAllMocks() +}) + +describe('ConfigAudio', () => { + it('should not render when the audio configuration is hidden', () => { + renderConfigAudio({ isVisible: false }) + + expect(screen.queryByText('appDebug.feature.audioUpload.title')).not.toBeInTheDocument() + }) + + it('should display the audio toggle state based on feature store data', () => { + renderConfigAudio({ allowedTypes: [SupportUploadFileTypes.audio] }) + + expect(screen.getByText('appDebug.feature.audioUpload.title')).toBeInTheDocument() + expect(screen.getByRole('switch')).toHaveAttribute('aria-checked', 'true') + }) + + it('should enable audio uploads when toggled on', async () => { + const { user, setFeatures } = renderConfigAudio() + const toggle = screen.getByRole('switch') + + expect(toggle).toHaveAttribute('aria-checked', 'false') + await user.click(toggle) + + expect(setFeatures).toHaveBeenCalledWith(expect.objectContaining({ + file: expect.objectContaining({ + allowed_file_types: [SupportUploadFileTypes.audio], + enabled: true, + }), + })) + }) + + it('should disable audio uploads and turn off file feature when last type is removed', async () => { + const { user, setFeatures } = renderConfigAudio({ allowedTypes: [SupportUploadFileTypes.audio] }) + const toggle = screen.getByRole('switch') + + expect(toggle).toHaveAttribute('aria-checked', 'true') + await user.click(toggle) + + expect(setFeatures).toHaveBeenCalledWith(expect.objectContaining({ + file: expect.objectContaining({ + allowed_file_types: [], + enabled: false, + }), + })) + }) +}) diff --git a/web/app/components/app/configuration/config/config-document.spec.tsx b/web/app/components/app/configuration/config/config-document.spec.tsx new file mode 100644 index 0000000000..c351b5f6cf --- /dev/null +++ b/web/app/components/app/configuration/config/config-document.spec.tsx @@ -0,0 +1,120 @@ +import type { Mock } from 'vitest' +import React from 'react' +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import ConfigDocument from './config-document' +import type { FeatureStoreState } from '@/app/components/base/features/store' +import { SupportUploadFileTypes } from '@/app/components/workflow/types' + +const mockUseContext = vi.fn() +vi.mock('use-context-selector', async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + useContext: (context: unknown) => mockUseContext(context), + } +}) + +const mockUseFeatures = vi.fn() +const mockUseFeaturesStore = vi.fn() +vi.mock('@/app/components/base/features/hooks', () => ({ + useFeatures: (selector: (state: FeatureStoreState) => any) => mockUseFeatures(selector), + useFeaturesStore: () => mockUseFeaturesStore(), +})) + +type SetupOptions = { + isVisible?: boolean + allowedTypes?: SupportUploadFileTypes[] +} + +let mockFeatureStoreState: FeatureStoreState +let mockSetFeatures: Mock +const mockStore = { + getState: vi.fn<() => FeatureStoreState>(() => mockFeatureStoreState), +} + +const setupFeatureStore = (allowedTypes: SupportUploadFileTypes[] = []) => { + mockSetFeatures = vi.fn() + mockFeatureStoreState = { + features: { + file: { + allowed_file_types: allowedTypes, + enabled: allowedTypes.length > 0, + }, + }, + setFeatures: mockSetFeatures, + showFeaturesModal: false, + setShowFeaturesModal: vi.fn(), + } + mockStore.getState.mockImplementation(() => mockFeatureStoreState) + mockUseFeaturesStore.mockReturnValue(mockStore) + mockUseFeatures.mockImplementation(selector => selector(mockFeatureStoreState)) +} + +const renderConfigDocument = (options: SetupOptions = {}) => { + const { + isVisible = true, + allowedTypes = [], + } = options + setupFeatureStore(allowedTypes) + mockUseContext.mockReturnValue({ + isShowDocumentConfig: isVisible, + }) + const user = userEvent.setup() + render() + return { + user, + setFeatures: mockSetFeatures, + } +} + +beforeEach(() => { + vi.clearAllMocks() +}) + +describe('ConfigDocument', () => { + it('should not render when the document configuration is hidden', () => { + renderConfigDocument({ isVisible: false }) + + expect(screen.queryByText('appDebug.feature.documentUpload.title')).not.toBeInTheDocument() + }) + + it('should show document toggle badge when configuration is visible', () => { + renderConfigDocument({ allowedTypes: [SupportUploadFileTypes.document] }) + + expect(screen.getByText('appDebug.feature.documentUpload.title')).toBeInTheDocument() + expect(screen.getByRole('switch')).toHaveAttribute('aria-checked', 'true') + }) + + it('should add document type to allowed list when toggled on', async () => { + const { user, setFeatures } = renderConfigDocument({ allowedTypes: [SupportUploadFileTypes.audio] }) + const toggle = screen.getByRole('switch') + + expect(toggle).toHaveAttribute('aria-checked', 'false') + await user.click(toggle) + + expect(setFeatures).toHaveBeenCalledWith(expect.objectContaining({ + file: expect.objectContaining({ + allowed_file_types: [SupportUploadFileTypes.audio, SupportUploadFileTypes.document], + enabled: true, + }), + })) + }) + + it('should remove document type but keep file feature enabled when other types remain', async () => { + const { user, setFeatures } = renderConfigDocument({ + allowedTypes: [SupportUploadFileTypes.document, SupportUploadFileTypes.audio], + }) + const toggle = screen.getByRole('switch') + + expect(toggle).toHaveAttribute('aria-checked', 'true') + await user.click(toggle) + + expect(setFeatures).toHaveBeenCalledWith(expect.objectContaining({ + file: expect.objectContaining({ + allowed_file_types: [SupportUploadFileTypes.audio], + enabled: true, + }), + })) + }) +}) diff --git a/web/app/components/app/configuration/config/index.spec.tsx b/web/app/components/app/configuration/config/index.spec.tsx new file mode 100644 index 0000000000..fc73a52cbd --- /dev/null +++ b/web/app/components/app/configuration/config/index.spec.tsx @@ -0,0 +1,255 @@ +import type { Mock } from 'vitest' +import React from 'react' +import { render, screen } from '@testing-library/react' +import Config from './index' +import type { ModelConfig, PromptVariable } from '@/models/debug' +import * as useContextSelector from 'use-context-selector' +import type { ToolItem } from '@/types/app' +import { AgentStrategy, AppModeEnum, ModelModeType } from '@/types/app' + +vi.mock('use-context-selector', async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + useContext: vi.fn(), + } +}) + +const mockFormattingDispatcher = vi.fn() +vi.mock('../debug/hooks', () => ({ + __esModule: true, + useFormattingChangedDispatcher: () => mockFormattingDispatcher, +})) + +let latestConfigPromptProps: any +vi.mock('@/app/components/app/configuration/config-prompt', () => ({ + __esModule: true, + default: (props: any) => { + latestConfigPromptProps = props + return
+ }, +})) + +let latestConfigVarProps: any +vi.mock('@/app/components/app/configuration/config-var', () => ({ + __esModule: true, + default: (props: any) => { + latestConfigVarProps = props + return
+ }, +})) + +vi.mock('../dataset-config', () => ({ + __esModule: true, + default: () =>
, +})) + +vi.mock('./agent/agent-tools', () => ({ + __esModule: true, + default: () =>
, +})) + +vi.mock('../config-vision', () => ({ + __esModule: true, + default: () =>
, +})) + +vi.mock('./config-document', () => ({ + __esModule: true, + default: () =>
, +})) + +vi.mock('./config-audio', () => ({ + __esModule: true, + default: () =>
, +})) + +let latestHistoryPanelProps: any +vi.mock('../config-prompt/conversation-history/history-panel', () => ({ + __esModule: true, + default: (props: any) => { + latestHistoryPanelProps = props + return
+ }, +})) + +type MockContext = { + mode: AppModeEnum + isAdvancedMode: boolean + modelModeType: ModelModeType + isAgent: boolean + hasSetBlockStatus: { + context: boolean + history: boolean + query: boolean + } + showHistoryModal: Mock + modelConfig: ModelConfig + setModelConfig: Mock + setPrevPromptConfig: Mock +} + +const createPromptVariable = (overrides: Partial = {}): PromptVariable => ({ + key: 'variable', + name: 'Variable', + type: 'string', + ...overrides, +}) + +const createModelConfig = (overrides: Partial = {}): ModelConfig => ({ + provider: 'openai', + model_id: 'gpt-4', + mode: ModelModeType.chat, + configs: { + prompt_template: 'Hello {{variable}}', + prompt_variables: [createPromptVariable({ key: 'existing' })], + }, + chat_prompt_config: null, + completion_prompt_config: null, + opening_statement: null, + more_like_this: null, + suggested_questions: null, + suggested_questions_after_answer: null, + speech_to_text: null, + text_to_speech: null, + file_upload: null, + retriever_resource: null, + sensitive_word_avoidance: null, + annotation_reply: null, + external_data_tools: null, + system_parameters: { + audio_file_size_limit: 1, + file_size_limit: 1, + image_file_size_limit: 1, + video_file_size_limit: 1, + workflow_file_upload_limit: 1, + }, + dataSets: [], + agentConfig: { + enabled: false, + strategy: AgentStrategy.react, + max_iteration: 1, + tools: [] as ToolItem[], + }, + ...overrides, +}) + +const createContextValue = (overrides: Partial = {}): MockContext => ({ + mode: AppModeEnum.CHAT, + isAdvancedMode: false, + modelModeType: ModelModeType.chat, + isAgent: false, + hasSetBlockStatus: { + context: false, + history: true, + query: false, + }, + showHistoryModal: vi.fn(), + modelConfig: createModelConfig(), + setModelConfig: vi.fn(), + setPrevPromptConfig: vi.fn(), + ...overrides, +}) + +const mockUseContext = useContextSelector.useContext as Mock + +const renderConfig = (contextOverrides: Partial = {}) => { + const contextValue = createContextValue(contextOverrides) + mockUseContext.mockReturnValue(contextValue) + return { + contextValue, + ...render(), + } +} + +beforeEach(() => { + vi.clearAllMocks() + latestConfigPromptProps = undefined + latestConfigVarProps = undefined + latestHistoryPanelProps = undefined +}) + +// Rendering scenarios ensure the layout toggles agent/history specific sections correctly. +describe('Config - Rendering', () => { + it('should render baseline sections without agent specific panels', () => { + renderConfig() + + expect(screen.getByTestId('config-prompt')).toBeInTheDocument() + expect(screen.getByTestId('config-var')).toBeInTheDocument() + expect(screen.getByTestId('dataset-config')).toBeInTheDocument() + expect(screen.getByTestId('config-vision')).toBeInTheDocument() + expect(screen.getByTestId('config-document')).toBeInTheDocument() + expect(screen.getByTestId('config-audio')).toBeInTheDocument() + expect(screen.queryByTestId('agent-tools')).not.toBeInTheDocument() + expect(screen.queryByTestId('history-panel')).not.toBeInTheDocument() + }) + + it('should show AgentTools when app runs in agent mode', () => { + renderConfig({ isAgent: true }) + + expect(screen.getByTestId('agent-tools')).toBeInTheDocument() + }) + + it('should display HistoryPanel only when advanced chat completion values apply', () => { + const showHistoryModal = vi.fn() + renderConfig({ + isAdvancedMode: true, + mode: AppModeEnum.ADVANCED_CHAT, + modelModeType: ModelModeType.completion, + hasSetBlockStatus: { + context: false, + history: false, + query: false, + }, + showHistoryModal, + }) + + expect(screen.getByTestId('history-panel')).toBeInTheDocument() + expect(latestHistoryPanelProps.showWarning).toBe(true) + expect(latestHistoryPanelProps.onShowEditModal).toBe(showHistoryModal) + }) +}) + +// Prompt handling scenarios validate integration between Config and prompt children. +describe('Config - Prompt Handling', () => { + it('should update prompt template and dispatch formatting event when text changes', () => { + const { contextValue } = renderConfig() + const previousVariables = contextValue.modelConfig.configs.prompt_variables + const additions = [createPromptVariable({ key: 'new', name: 'New' })] + + latestConfigPromptProps.onChange('Updated template', additions) + + expect(contextValue.setPrevPromptConfig).toHaveBeenCalledWith(contextValue.modelConfig.configs) + expect(contextValue.setModelConfig).toHaveBeenCalledWith(expect.objectContaining({ + configs: expect.objectContaining({ + prompt_template: 'Updated template', + prompt_variables: [...previousVariables, ...additions], + }), + })) + expect(mockFormattingDispatcher).toHaveBeenCalledTimes(1) + }) + + it('should skip formatting dispatcher when template remains identical', () => { + const { contextValue } = renderConfig() + const unchangedTemplate = contextValue.modelConfig.configs.prompt_template + + latestConfigPromptProps.onChange(unchangedTemplate, [createPromptVariable({ key: 'added' })]) + + expect(contextValue.setPrevPromptConfig).toHaveBeenCalledWith(contextValue.modelConfig.configs) + expect(mockFormattingDispatcher).not.toHaveBeenCalled() + }) + + it('should replace prompt variables when ConfigVar reports updates', () => { + const { contextValue } = renderConfig() + const replacementVariables = [createPromptVariable({ key: 'replacement' })] + + latestConfigVarProps.onPromptVariablesChange(replacementVariables) + + expect(contextValue.setPrevPromptConfig).toHaveBeenCalledWith(contextValue.modelConfig.configs) + expect(contextValue.setModelConfig).toHaveBeenCalledWith(expect.objectContaining({ + configs: expect.objectContaining({ + prompt_variables: replacementVariables, + }), + })) + }) +}) diff --git a/web/app/components/app/configuration/ctrl-btn-group/index.spec.tsx b/web/app/components/app/configuration/ctrl-btn-group/index.spec.tsx index 11cf438974..62c2fe7f45 100644 --- a/web/app/components/app/configuration/ctrl-btn-group/index.spec.tsx +++ b/web/app/components/app/configuration/ctrl-btn-group/index.spec.tsx @@ -3,15 +3,15 @@ import ContrlBtnGroup from './index' describe('ContrlBtnGroup', () => { beforeEach(() => { - jest.clearAllMocks() + vi.clearAllMocks() }) // Rendering fixed action buttons describe('Rendering', () => { it('should render buttons when rendered', () => { // Arrange - const onSave = jest.fn() - const onReset = jest.fn() + const onSave = vi.fn() + const onReset = vi.fn() // Act render() @@ -26,8 +26,8 @@ describe('ContrlBtnGroup', () => { describe('Interactions', () => { it('should invoke callbacks when buttons are clicked', () => { // Arrange - const onSave = jest.fn() - const onReset = jest.fn() + const onSave = vi.fn() + const onReset = vi.fn() render() // Act diff --git a/web/app/components/app/configuration/dataset-config/card-item/index.spec.tsx b/web/app/components/app/configuration/dataset-config/card-item/index.spec.tsx index 4d92ae4080..9ae664da1c 100644 --- a/web/app/components/app/configuration/dataset-config/card-item/index.spec.tsx +++ b/web/app/components/app/configuration/dataset-config/card-item/index.spec.tsx @@ -1,3 +1,4 @@ +import type { MockedFunction } from 'vitest' import { fireEvent, render, screen, waitFor, within } from '@testing-library/react' import userEvent from '@testing-library/user-event' import Item from './index' @@ -9,7 +10,7 @@ import type { RetrievalConfig } from '@/types/app' import { RETRIEVE_METHOD } from '@/types/app' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' -jest.mock('../settings-modal', () => ({ +vi.mock('../settings-modal', () => ({ __esModule: true, default: ({ onSave, onCancel, currentDataset }: any) => (
@@ -20,16 +21,16 @@ jest.mock('../settings-modal', () => ({ ), })) -jest.mock('@/hooks/use-breakpoints', () => { - const actual = jest.requireActual('@/hooks/use-breakpoints') +vi.mock('@/hooks/use-breakpoints', async (importOriginal) => { + const actual = await importOriginal() return { __esModule: true, ...actual, - default: jest.fn(() => actual.MediaType.pc), + default: vi.fn(() => actual.MediaType.pc), } }) -const mockedUseBreakpoints = useBreakpoints as jest.MockedFunction +const mockedUseBreakpoints = useBreakpoints as MockedFunction const baseRetrievalConfig: RetrievalConfig = { search_method: RETRIEVE_METHOD.semantic, @@ -123,8 +124,8 @@ const createDataset = (overrides: Partial = {}): DataSet => { } const renderItem = (config: DataSet, props?: Partial>) => { - const onSave = jest.fn() - const onRemove = jest.fn() + const onSave = vi.fn() + const onRemove = vi.fn() render( { beforeEach(() => { - jest.clearAllMocks() + vi.clearAllMocks() mockedUseBreakpoints.mockReturnValue(MediaType.pc) }) diff --git a/web/app/components/app/configuration/dataset-config/card-item/index.tsx b/web/app/components/app/configuration/dataset-config/card-item/index.tsx index 85d46122a3..7fd7011a56 100644 --- a/web/app/components/app/configuration/dataset-config/card-item/index.tsx +++ b/web/app/components/app/configuration/dataset-config/card-item/index.tsx @@ -13,7 +13,7 @@ import Drawer from '@/app/components/base/drawer' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import Badge from '@/app/components/base/badge' import { useKnowledge } from '@/hooks/use-knowledge' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import AppIcon from '@/app/components/base/app-icon' type ItemProps = { diff --git a/web/app/components/app/configuration/dataset-config/context-var/index.spec.tsx b/web/app/components/app/configuration/dataset-config/context-var/index.spec.tsx index 69378fbb32..189b4ecaf0 100644 --- a/web/app/components/app/configuration/dataset-config/context-var/index.spec.tsx +++ b/web/app/components/app/configuration/dataset-config/context-var/index.spec.tsx @@ -5,8 +5,8 @@ import ContextVar from './index' import type { Props } from './var-picker' // Mock external dependencies only -jest.mock('next/navigation', () => ({ - useRouter: () => ({ push: jest.fn() }), +vi.mock('next/navigation', () => ({ + useRouter: () => ({ push: vi.fn() }), usePathname: () => '/test', })) @@ -18,7 +18,7 @@ type PortalToFollowElemProps = { type PortalToFollowElemTriggerProps = React.HTMLAttributes & { children?: React.ReactNode; asChild?: boolean } type PortalToFollowElemContentProps = React.HTMLAttributes & { children?: React.ReactNode } -jest.mock('@/app/components/base/portal-to-follow-elem', () => { +vi.mock('@/app/components/base/portal-to-follow-elem', () => { const PortalContext = React.createContext({ open: false }) const PortalToFollowElem = ({ children, open }: PortalToFollowElemProps) => { @@ -69,11 +69,11 @@ describe('ContextVar', () => { const defaultProps: Props = { value: 'var1', options: mockOptions, - onChange: jest.fn(), + onChange: vi.fn(), } beforeEach(() => { - jest.clearAllMocks() + vi.clearAllMocks() }) // Rendering tests (REQUIRED) @@ -165,7 +165,7 @@ describe('ContextVar', () => { describe('User Interactions', () => { it('should call onChange when user selects a different variable', async () => { // Arrange - const onChange = jest.fn() + const onChange = vi.fn() const props = { ...defaultProps, onChange } const user = userEvent.setup() diff --git a/web/app/components/app/configuration/dataset-config/context-var/index.tsx b/web/app/components/app/configuration/dataset-config/context-var/index.tsx index ebba9c51cb..80cc50acdf 100644 --- a/web/app/components/app/configuration/dataset-config/context-var/index.tsx +++ b/web/app/components/app/configuration/dataset-config/context-var/index.tsx @@ -4,7 +4,7 @@ import React from 'react' import { useTranslation } from 'react-i18next' import type { Props } from './var-picker' import VarPicker from './var-picker' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { BracketsX } from '@/app/components/base/icons/src/vender/line/development' import Tooltip from '@/app/components/base/tooltip' diff --git a/web/app/components/app/configuration/dataset-config/context-var/var-picker.spec.tsx b/web/app/components/app/configuration/dataset-config/context-var/var-picker.spec.tsx index cb46ce9788..cf52701008 100644 --- a/web/app/components/app/configuration/dataset-config/context-var/var-picker.spec.tsx +++ b/web/app/components/app/configuration/dataset-config/context-var/var-picker.spec.tsx @@ -4,8 +4,8 @@ import userEvent from '@testing-library/user-event' import VarPicker, { type Props } from './var-picker' // Mock external dependencies only -jest.mock('next/navigation', () => ({ - useRouter: () => ({ push: jest.fn() }), +vi.mock('next/navigation', () => ({ + useRouter: () => ({ push: vi.fn() }), usePathname: () => '/test', })) @@ -17,7 +17,7 @@ type PortalToFollowElemProps = { type PortalToFollowElemTriggerProps = React.HTMLAttributes & { children?: React.ReactNode; asChild?: boolean } type PortalToFollowElemContentProps = React.HTMLAttributes & { children?: React.ReactNode } -jest.mock('@/app/components/base/portal-to-follow-elem', () => { +vi.mock('@/app/components/base/portal-to-follow-elem', () => { const PortalContext = React.createContext({ open: false }) const PortalToFollowElem = ({ children, open }: PortalToFollowElemProps) => { @@ -69,11 +69,11 @@ describe('VarPicker', () => { const defaultProps: Props = { value: 'var1', options: mockOptions, - onChange: jest.fn(), + onChange: vi.fn(), } beforeEach(() => { - jest.clearAllMocks() + vi.clearAllMocks() }) // Rendering tests (REQUIRED) @@ -201,7 +201,7 @@ describe('VarPicker', () => { describe('User Interactions', () => { it('should open dropdown when clicking the trigger button', async () => { // Arrange - const onChange = jest.fn() + const onChange = vi.fn() const props = { ...defaultProps, onChange } const user = userEvent.setup() @@ -215,7 +215,7 @@ describe('VarPicker', () => { it('should call onChange and close dropdown when selecting an option', async () => { // Arrange - const onChange = jest.fn() + const onChange = vi.fn() const props = { ...defaultProps, onChange } const user = userEvent.setup() diff --git a/web/app/components/app/configuration/dataset-config/context-var/var-picker.tsx b/web/app/components/app/configuration/dataset-config/context-var/var-picker.tsx index c443ea0b1f..f5ea2eaa27 100644 --- a/web/app/components/app/configuration/dataset-config/context-var/var-picker.tsx +++ b/web/app/components/app/configuration/dataset-config/context-var/var-picker.tsx @@ -3,7 +3,7 @@ import type { FC } from 'react' import React, { useState } from 'react' import { useTranslation } from 'react-i18next' import { ChevronDownIcon } from '@heroicons/react/24/outline' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { PortalToFollowElem, PortalToFollowElemContent, diff --git a/web/app/components/app/configuration/dataset-config/index.spec.tsx b/web/app/components/app/configuration/dataset-config/index.spec.tsx new file mode 100644 index 0000000000..3e10ed82d7 --- /dev/null +++ b/web/app/components/app/configuration/dataset-config/index.spec.tsx @@ -0,0 +1,1049 @@ +import { render, screen, within } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import DatasetConfig from './index' +import type { DataSet } from '@/models/datasets' +import { DataSourceType, DatasetPermission } from '@/models/datasets' +import { AppModeEnum } from '@/types/app' +import { ModelModeType } from '@/types/app' +import { RETRIEVE_TYPE } from '@/types/app' +import { ComparisonOperator, LogicalOperator } from '@/app/components/workflow/nodes/knowledge-retrieval/types' +import type { DatasetConfigs } from '@/models/debug' +import { useContext } from 'use-context-selector' +import { hasEditPermissionForDataset } from '@/utils/permission' +import { getSelectedDatasetsMode } from '@/app/components/workflow/nodes/knowledge-retrieval/utils' + +// Mock external dependencies +vi.mock('@/app/components/workflow/nodes/knowledge-retrieval/utils', () => ({ + getMultipleRetrievalConfig: vi.fn(() => ({ + top_k: 4, + score_threshold: 0.7, + reranking_enable: false, + reranking_model: undefined, + reranking_mode: 'reranking_model', + weights: { weight1: 1.0 }, + })), + getSelectedDatasetsMode: vi.fn(() => ({ + allInternal: true, + allExternal: false, + mixtureInternalAndExternal: false, + mixtureHighQualityAndEconomic: false, + inconsistentEmbeddingModel: false, + })), +})) + +vi.mock('@/app/components/header/account-setting/model-provider-page/hooks', () => ({ + useModelListAndDefaultModelAndCurrentProviderAndModel: vi.fn(() => ({ + currentModel: { model: 'rerank-model' }, + currentProvider: { provider: 'openai' }, + })), +})) + +vi.mock('@/context/app-context', () => ({ + useSelector: vi.fn((fn: any) => fn({ + userProfile: { + id: 'user-123', + }, + })), +})) + +vi.mock('@/utils/permission', () => ({ + hasEditPermissionForDataset: vi.fn(() => true), +})) + +vi.mock('../debug/hooks', () => ({ + useFormattingChangedDispatcher: vi.fn(() => vi.fn()), +})) + +vi.mock('lodash-es', () => ({ + intersectionBy: vi.fn((...arrays) => { + // Mock realistic intersection behavior based on metadata name + const validArrays = arrays.filter(Array.isArray) + if (validArrays.length === 0) return [] + + // Start with first array and filter down + return validArrays[0].filter((item: any) => { + if (!item || !item.name) return false + + // Only return items that exist in all arrays + return validArrays.every(array => + array.some((otherItem: any) => + otherItem && otherItem.name === item.name, + ), + ) + }) + }), +})) + +vi.mock('uuid', () => ({ + v4: vi.fn(() => 'mock-uuid'), +})) + +// Mock child components +vi.mock('./card-item', () => ({ + __esModule: true, + default: ({ config, onRemove, onSave, editable }: any) => ( +
+ {config.name} + {editable && } + +
+ ), +})) + +vi.mock('./params-config', () => ({ + __esModule: true, + default: ({ disabled, selectedDatasets }: any) => ( + + ), +})) + +vi.mock('./context-var', () => ({ + __esModule: true, + default: ({ value, options, onChange }: any) => ( + + ), +})) + +vi.mock('@/app/components/workflow/nodes/knowledge-retrieval/components/metadata/metadata-filter', () => ({ + __esModule: true, + default: ({ + metadataList, + metadataFilterMode, + handleMetadataFilterModeChange, + handleAddCondition, + handleRemoveCondition, + handleUpdateCondition, + handleToggleConditionLogicalOperator, + }: any) => ( +
+ {metadataList.length} + + + + + +
+ ), +})) + +// Mock context +const mockConfigContext: any = { + mode: AppModeEnum.CHAT, + modelModeType: ModelModeType.chat, + isAgent: false, + dataSets: [], + setDataSets: vi.fn(), + modelConfig: { + configs: { + prompt_variables: [], + }, + }, + setModelConfig: vi.fn(), + showSelectDataSet: vi.fn(), + datasetConfigs: { + retrieval_model: RETRIEVE_TYPE.multiWay, + reranking_model: { + reranking_provider_name: '', + reranking_model_name: '', + }, + top_k: 4, + score_threshold_enabled: false, + score_threshold: 0.7, + metadata_filtering_mode: 'disabled' as any, + metadata_filtering_conditions: undefined, + datasets: { + datasets: [], + }, + } as DatasetConfigs, + datasetConfigsRef: { + current: { + retrieval_model: RETRIEVE_TYPE.multiWay, + reranking_model: { + reranking_provider_name: '', + reranking_model_name: '', + }, + top_k: 4, + score_threshold_enabled: false, + score_threshold: 0.7, + metadata_filtering_mode: 'disabled' as any, + metadata_filtering_conditions: undefined, + datasets: { + datasets: [], + }, + } as DatasetConfigs, + }, + setDatasetConfigs: vi.fn(), + setRerankSettingModalOpen: vi.fn(), +} + +vi.mock('@/context/debug-configuration', () => ({ + __esModule: true, + default: ({ children }: any) => ( +
+ {children} +
+ ), +})) + +vi.mock('use-context-selector', () => ({ + useContext: vi.fn(() => mockConfigContext), +})) + +const createMockDataset = (overrides: Partial = {}): DataSet => { + const defaultDataset: DataSet = { + id: 'dataset-1', + name: 'Test Dataset', + indexing_status: 'completed' as any, + icon_info: { + icon: '📘', + icon_type: 'emoji', + icon_background: '#FFEAD5', + icon_url: '', + }, + description: 'Test dataset description', + permission: DatasetPermission.onlyMe, + data_source_type: DataSourceType.FILE, + indexing_technique: 'high_quality' as any, + author_name: 'Test Author', + created_by: 'user-123', + updated_by: 'user-123', + updated_at: Date.now(), + app_count: 0, + doc_form: 'text' as any, + document_count: 10, + total_document_count: 10, + total_available_documents: 10, + word_count: 1000, + provider: 'dify', + embedding_model: 'text-embedding-ada-002', + embedding_model_provider: 'openai', + embedding_available: true, + retrieval_model_dict: { + search_method: 'semantic_search' as any, + reranking_enable: false, + reranking_model: { + reranking_provider_name: '', + reranking_model_name: '', + }, + top_k: 4, + score_threshold_enabled: false, + score_threshold: 0.7, + }, + retrieval_model: { + search_method: 'semantic_search' as any, + reranking_enable: false, + reranking_model: { + reranking_provider_name: '', + reranking_model_name: '', + }, + top_k: 4, + score_threshold_enabled: false, + score_threshold: 0.7, + }, + tags: [], + external_knowledge_info: { + external_knowledge_id: '', + external_knowledge_api_id: '', + external_knowledge_api_name: '', + external_knowledge_api_endpoint: '', + }, + external_retrieval_model: { + top_k: 2, + score_threshold: 0.5, + score_threshold_enabled: true, + }, + built_in_field_enabled: true, + doc_metadata: [ + { name: 'category', type: 'string' } as any, + { name: 'priority', type: 'number' } as any, + ], + keyword_number: 3, + pipeline_id: 'pipeline-123', + is_published: true, + runtime_mode: 'general', + enable_api: true, + is_multimodal: false, + ...overrides, + } + return defaultDataset +} + +const renderDatasetConfig = (contextOverrides: Partial = {}) => { + const mergedContext = { ...mockConfigContext, ...contextOverrides } + vi.mocked(useContext).mockReturnValue(mergedContext) + + return render() +} + +describe('DatasetConfig', () => { + beforeEach(() => { + vi.clearAllMocks() + mockConfigContext.dataSets = [] + mockConfigContext.setDataSets = vi.fn() + mockConfigContext.setModelConfig = vi.fn() + mockConfigContext.setDatasetConfigs = vi.fn() + mockConfigContext.setRerankSettingModalOpen = vi.fn() + }) + + describe('Rendering', () => { + it('should render dataset configuration panel when component mounts', () => { + renderDatasetConfig() + + expect(screen.getByText('appDebug.feature.dataSet.title')).toBeInTheDocument() + }) + + it('should display empty state message when no datasets are configured', () => { + renderDatasetConfig() + + expect(screen.getByText(/no.*data/i)).toBeInTheDocument() + expect(screen.getByTestId('params-config')).toBeDisabled() + }) + + it('should render dataset cards and enable parameters when datasets exist', () => { + const dataset = createMockDataset() + renderDatasetConfig({ + dataSets: [dataset], + }) + + expect(screen.getByTestId(`card-item-${dataset.id}`)).toBeInTheDocument() + expect(screen.getByText(dataset.name)).toBeInTheDocument() + expect(screen.getByTestId('params-config')).not.toBeDisabled() + }) + + it('should show configuration title and add dataset button in header', () => { + renderDatasetConfig() + + expect(screen.getByText('appDebug.feature.dataSet.title')).toBeInTheDocument() + expect(screen.getByText('common.operation.add')).toBeInTheDocument() + }) + + it('should hide parameters configuration when in agent mode', () => { + renderDatasetConfig({ + isAgent: true, + }) + + expect(screen.queryByTestId('params-config')).not.toBeInTheDocument() + }) + }) + + describe('Dataset Management', () => { + it('should open dataset selection modal when add button is clicked', async () => { + const user = userEvent.setup() + renderDatasetConfig() + + const addButton = screen.getByText('common.operation.add') + await user.click(addButton) + + expect(mockConfigContext.showSelectDataSet).toHaveBeenCalledTimes(1) + }) + + it('should remove dataset and update configuration when remove button is clicked', async () => { + const user = userEvent.setup() + const dataset = createMockDataset() + renderDatasetConfig({ + dataSets: [dataset], + }) + + const removeButton = screen.getByText('Remove') + await user.click(removeButton) + + expect(mockConfigContext.setDataSets).toHaveBeenCalledWith([]) + // Note: setDatasetConfigs is also called but its exact parameters depend on + // the retrieval config calculation which involves complex mocked utilities + }) + + it('should trigger rerank setting modal when removing dataset requires rerank configuration', async () => { + const user = userEvent.setup() + + // Mock scenario that triggers rerank modal + // @ts-expect-error - same as above + vi.mocked(getSelectedDatasetsMode).mockReturnValue({ + allInternal: false, + allExternal: true, + mixtureInternalAndExternal: false, + mixtureHighQualityAndEconomic: false, + inconsistentEmbeddingModel: false, + }) + + const dataset = createMockDataset() + renderDatasetConfig({ + dataSets: [dataset], + }) + + const removeButton = screen.getByText('Remove') + await user.click(removeButton) + + expect(mockConfigContext.setRerankSettingModalOpen).toHaveBeenCalledWith(true) + }) + + it('should handle dataset save', async () => { + const user = userEvent.setup() + const dataset = createMockDataset() + + renderDatasetConfig({ + dataSets: [dataset], + }) + + // Mock the onSave in card-item component - it will pass the original dataset + const editButton = screen.getByText('Edit') + await user.click(editButton) + + expect(mockConfigContext.setDataSets).toHaveBeenCalledWith( + expect.arrayContaining([ + expect.objectContaining({ + id: dataset.id, + name: dataset.name, + editable: true, + }), + ]), + ) + }) + + it('should format datasets with edit permission', () => { + const dataset = createMockDataset({ + created_by: 'user-123', + }) + + renderDatasetConfig({ + dataSets: [dataset], + }) + + expect(screen.getByTestId(`card-item-${dataset.id}`)).toBeInTheDocument() + }) + }) + + describe('Context Variables', () => { + it('should show context variable selector in completion mode with datasets', () => { + const dataset = createMockDataset() + renderDatasetConfig({ + mode: AppModeEnum.COMPLETION, + dataSets: [dataset], + modelConfig: { + configs: { + prompt_variables: [ + { key: 'query', name: 'Query', type: 'string', is_context_var: false }, + { key: 'context', name: 'Context', type: 'string', is_context_var: true }, + ], + }, + }, + }) + + expect(screen.getByTestId('context-var')).toBeInTheDocument() + // Should find the selected context variable in the options + expect(screen.getByText('Select context variable')).toBeInTheDocument() + }) + + it('should not show context variable selector in chat mode', () => { + const dataset = createMockDataset() + renderDatasetConfig({ + mode: AppModeEnum.CHAT, + dataSets: [dataset], + modelConfig: { + configs: { + prompt_variables: [ + { key: 'query', name: 'Query', type: 'string', is_context_var: false }, + ], + }, + }, + }) + + expect(screen.queryByTestId('context-var')).not.toBeInTheDocument() + }) + + it('should handle context variable selection', async () => { + const user = userEvent.setup() + const dataset = createMockDataset() + renderDatasetConfig({ + mode: AppModeEnum.COMPLETION, + dataSets: [dataset], + modelConfig: { + configs: { + prompt_variables: [ + { key: 'query', name: 'Query', type: 'string', is_context_var: false }, + { key: 'context', name: 'Context', type: 'string', is_context_var: true }, + ], + }, + }, + }) + + const select = screen.getByTestId('context-var') + await user.selectOptions(select, 'query') + + expect(mockConfigContext.setModelConfig).toHaveBeenCalled() + }) + }) + + describe('Metadata Filtering', () => { + it('should render metadata filter component', () => { + const dataset = createMockDataset({ + doc_metadata: [ + { name: 'category', type: 'string' } as any, + { name: 'priority', type: 'number' } as any, + ], + }) + + renderDatasetConfig({ + dataSets: [dataset], + }) + + expect(screen.getByTestId('metadata-filter')).toBeInTheDocument() + expect(screen.getByTestId('metadata-list-count')).toHaveTextContent('2') // both 'category' and 'priority' + }) + + it('should handle metadata filter mode change', async () => { + const user = userEvent.setup() + const dataset = createMockDataset() + const updatedDatasetConfigs = { + ...mockConfigContext.datasetConfigs, + metadata_filtering_mode: 'disabled' as any, + } + + renderDatasetConfig({ + dataSets: [dataset], + datasetConfigs: updatedDatasetConfigs, + }) + + // Update the ref to match + mockConfigContext.datasetConfigsRef.current = updatedDatasetConfigs + + const select = within(screen.getByTestId('metadata-filter')).getByDisplayValue('Disabled') + await user.selectOptions(select, 'automatic') + + expect(mockConfigContext.setDatasetConfigs).toHaveBeenCalledWith( + expect.objectContaining({ + metadata_filtering_mode: 'automatic', + }), + ) + }) + + it('should handle adding metadata conditions', async () => { + const user = userEvent.setup() + const dataset = createMockDataset() + const baseDatasetConfigs = { + ...mockConfigContext.datasetConfigs, + } + + renderDatasetConfig({ + dataSets: [dataset], + datasetConfigs: baseDatasetConfigs, + }) + + // Update the ref to match + mockConfigContext.datasetConfigsRef.current = baseDatasetConfigs + + const addButton = within(screen.getByTestId('metadata-filter')).getByText('Add Condition') + await user.click(addButton) + + expect(mockConfigContext.setDatasetConfigs).toHaveBeenCalledWith( + expect.objectContaining({ + metadata_filtering_conditions: expect.objectContaining({ + logical_operator: LogicalOperator.and, + conditions: expect.arrayContaining([ + expect.objectContaining({ + id: 'mock-uuid', + name: 'test', + comparison_operator: ComparisonOperator.is, + }), + ]), + }), + }), + ) + }) + + it('should handle removing metadata conditions', async () => { + const user = userEvent.setup() + const dataset = createMockDataset() + + const datasetConfigsWithConditions = { + ...mockConfigContext.datasetConfigs, + metadata_filtering_conditions: { + logical_operator: LogicalOperator.and, + conditions: [ + { id: 'condition-id', name: 'test', comparison_operator: ComparisonOperator.is }, + ], + }, + } + + renderDatasetConfig({ + dataSets: [dataset], + datasetConfigs: datasetConfigsWithConditions, + }) + + // Update ref to match datasetConfigs + mockConfigContext.datasetConfigsRef.current = datasetConfigsWithConditions + + const removeButton = within(screen.getByTestId('metadata-filter')).getByText('Remove Condition') + await user.click(removeButton) + + expect(mockConfigContext.setDatasetConfigs).toHaveBeenCalledWith( + expect.objectContaining({ + metadata_filtering_conditions: expect.objectContaining({ + conditions: [], + }), + }), + ) + }) + + it('should handle updating metadata conditions', async () => { + const user = userEvent.setup() + const dataset = createMockDataset() + + const datasetConfigsWithConditions = { + ...mockConfigContext.datasetConfigs, + metadata_filtering_conditions: { + logical_operator: LogicalOperator.and, + conditions: [ + { id: 'condition-id', name: 'test', comparison_operator: ComparisonOperator.is }, + ], + }, + } + + renderDatasetConfig({ + dataSets: [dataset], + datasetConfigs: datasetConfigsWithConditions, + }) + + mockConfigContext.datasetConfigsRef.current = datasetConfigsWithConditions + + const updateButton = within(screen.getByTestId('metadata-filter')).getByText('Update Condition') + await user.click(updateButton) + + expect(mockConfigContext.setDatasetConfigs).toHaveBeenCalledWith( + expect.objectContaining({ + metadata_filtering_conditions: expect.objectContaining({ + conditions: expect.arrayContaining([ + expect.objectContaining({ + name: 'updated', + }), + ]), + }), + }), + ) + }) + + it('should handle toggling logical operator', async () => { + const user = userEvent.setup() + const dataset = createMockDataset() + + const datasetConfigsWithConditions = { + ...mockConfigContext.datasetConfigs, + metadata_filtering_conditions: { + logical_operator: LogicalOperator.and, + conditions: [ + { id: 'condition-id', name: 'test', comparison_operator: ComparisonOperator.is }, + ], + }, + } + + renderDatasetConfig({ + dataSets: [dataset], + datasetConfigs: datasetConfigsWithConditions, + }) + + mockConfigContext.datasetConfigsRef.current = datasetConfigsWithConditions + + const toggleButton = within(screen.getByTestId('metadata-filter')).getByText('Toggle Operator') + await user.click(toggleButton) + + expect(mockConfigContext.setDatasetConfigs).toHaveBeenCalledWith( + expect.objectContaining({ + metadata_filtering_conditions: expect.objectContaining({ + logical_operator: LogicalOperator.or, + }), + }), + ) + }) + }) + + describe('Edge Cases', () => { + it('should handle null doc_metadata gracefully', () => { + const dataset = createMockDataset({ + doc_metadata: undefined, + }) + + renderDatasetConfig({ + dataSets: [dataset], + }) + + expect(screen.getByTestId('metadata-filter')).toBeInTheDocument() + expect(screen.getByTestId('metadata-list-count')).toHaveTextContent('0') + }) + + it('should handle empty doc_metadata array', () => { + const dataset = createMockDataset({ + doc_metadata: [], + }) + + renderDatasetConfig({ + dataSets: [dataset], + }) + + expect(screen.getByTestId('metadata-filter')).toBeInTheDocument() + expect(screen.getByTestId('metadata-list-count')).toHaveTextContent('0') + }) + + it('should handle missing userProfile', () => { + vi.mocked(useContext).mockReturnValue({ + ...mockConfigContext, + userProfile: null, + }) + + const dataset = createMockDataset() + + renderDatasetConfig({ + dataSets: [dataset], + }) + + expect(screen.getByTestId(`card-item-${dataset.id}`)).toBeInTheDocument() + }) + + it('should handle missing datasetConfigsRef gracefully', () => { + const dataset = createMockDataset() + + // Test with undefined datasetConfigsRef - component renders without immediate error + // The component will fail on interaction due to non-null assertions in handlers + expect(() => { + renderDatasetConfig({ + dataSets: [dataset], + datasetConfigsRef: undefined as any, + }) + }).not.toThrow() + + // The component currently expects datasetConfigsRef to exist for interactions + // This test documents the current behavior and requirements + }) + + it('should handle missing prompt_variables', () => { + // Context var is only shown when datasets exist AND there are prompt_variables + // Test with no datasets to ensure context var is not shown + renderDatasetConfig({ + mode: AppModeEnum.COMPLETION, + dataSets: [], + modelConfig: { + configs: { + prompt_variables: [], + }, + }, + }) + + expect(screen.queryByTestId('context-var')).not.toBeInTheDocument() + }) + }) + + describe('Component Integration', () => { + it('should integrate with card item component', () => { + const datasets = [ + createMockDataset({ id: 'ds1', name: 'Dataset 1' }), + createMockDataset({ id: 'ds2', name: 'Dataset 2' }), + ] + + renderDatasetConfig({ + dataSets: datasets, + }) + + expect(screen.getByTestId('card-item-ds1')).toBeInTheDocument() + expect(screen.getByTestId('card-item-ds2')).toBeInTheDocument() + expect(screen.getByText('Dataset 1')).toBeInTheDocument() + expect(screen.getByText('Dataset 2')).toBeInTheDocument() + }) + + it('should integrate with params config component', () => { + const datasets = [ + createMockDataset(), + createMockDataset({ id: 'ds2' }), + ] + + renderDatasetConfig({ + dataSets: datasets, + }) + + const paramsConfig = screen.getByTestId('params-config') + expect(paramsConfig).toBeInTheDocument() + expect(paramsConfig).toHaveTextContent('Params (2)') + expect(paramsConfig).not.toBeDisabled() + }) + + it('should integrate with metadata filter component', () => { + const datasets = [ + createMockDataset({ + doc_metadata: [ + { name: 'category', type: 'string' } as any, + { name: 'tags', type: 'string' } as any, + ], + }), + createMockDataset({ + id: 'ds2', + doc_metadata: [ + { name: 'category', type: 'string' } as any, + { name: 'priority', type: 'number' } as any, + ], + }), + ] + + renderDatasetConfig({ + dataSets: datasets, + }) + + const metadataFilter = screen.getByTestId('metadata-filter') + expect(metadataFilter).toBeInTheDocument() + // Should show intersection (only 'category') + expect(screen.getByTestId('metadata-list-count')).toHaveTextContent('1') + }) + }) + + describe('Model Configuration', () => { + it('should handle metadata model change', () => { + const dataset = createMockDataset() + + renderDatasetConfig({ + dataSets: [dataset], + datasetConfigs: { + ...mockConfigContext.datasetConfigs, + metadata_model_config: { + provider: 'openai', + name: 'gpt-3.5-turbo', + mode: AppModeEnum.CHAT, + completion_params: { temperature: 0.7 }, + }, + }, + }) + + // The component would need to expose this functionality through the metadata filter + expect(screen.getByTestId('metadata-filter')).toBeInTheDocument() + }) + + it('should handle metadata completion params change', () => { + const dataset = createMockDataset() + + renderDatasetConfig({ + dataSets: [dataset], + datasetConfigs: { + ...mockConfigContext.datasetConfigs, + metadata_model_config: { + provider: 'openai', + name: 'gpt-3.5-turbo', + mode: AppModeEnum.CHAT, + completion_params: { temperature: 0.5, max_tokens: 100 }, + }, + }, + }) + + expect(screen.getByTestId('metadata-filter')).toBeInTheDocument() + }) + }) + + describe('Permission Handling', () => { + it('should hide edit options when user lacks permission', () => { + vi.mocked(hasEditPermissionForDataset).mockReturnValue(false) + + const dataset = createMockDataset({ + created_by: 'other-user', + permission: DatasetPermission.onlyMe, + }) + + renderDatasetConfig({ + dataSets: [dataset], + }) + + // The editable property should be false when no permission + expect(screen.getByTestId(`card-item-${dataset.id}`)).toBeInTheDocument() + }) + + it('should show readonly state for non-editable datasets', () => { + vi.mocked(hasEditPermissionForDataset).mockReturnValue(false) + + const dataset = createMockDataset({ + created_by: 'admin', + permission: DatasetPermission.allTeamMembers, + }) + + renderDatasetConfig({ + dataSets: [dataset], + }) + + expect(screen.getByTestId(`card-item-${dataset.id}`)).toBeInTheDocument() + }) + + it('should allow editing when user has partial member permission', () => { + vi.mocked(hasEditPermissionForDataset).mockReturnValue(true) + + const dataset = createMockDataset({ + created_by: 'admin', + permission: DatasetPermission.partialMembers, + partial_member_list: ['user-123'], + }) + + renderDatasetConfig({ + dataSets: [dataset], + }) + + expect(screen.getByTestId(`card-item-${dataset.id}`)).toBeInTheDocument() + }) + }) + + describe('Dataset Reordering and Management', () => { + it('should maintain dataset order after updates', () => { + const datasets = [ + createMockDataset({ id: 'ds1', name: 'Dataset 1' }), + createMockDataset({ id: 'ds2', name: 'Dataset 2' }), + createMockDataset({ id: 'ds3', name: 'Dataset 3' }), + ] + + renderDatasetConfig({ + dataSets: datasets, + }) + + // Verify order is maintained + expect(screen.getByText('Dataset 1')).toBeInTheDocument() + expect(screen.getByText('Dataset 2')).toBeInTheDocument() + expect(screen.getByText('Dataset 3')).toBeInTheDocument() + }) + + it('should handle multiple dataset operations correctly', async () => { + const user = userEvent.setup() + const datasets = [ + createMockDataset({ id: 'ds1', name: 'Dataset 1' }), + createMockDataset({ id: 'ds2', name: 'Dataset 2' }), + ] + + renderDatasetConfig({ + dataSets: datasets, + }) + + // Remove first dataset + const removeButton1 = screen.getAllByText('Remove')[0] + await user.click(removeButton1) + + expect(mockConfigContext.setDataSets).toHaveBeenCalledWith([datasets[1]]) + }) + }) + + describe('Complex Configuration Scenarios', () => { + it('should handle multiple retrieval methods in configuration', () => { + const datasets = [ + createMockDataset({ + id: 'ds1', + retrieval_model: { + search_method: 'semantic_search' as any, + reranking_enable: true, + reranking_model: { + reranking_provider_name: 'cohere', + reranking_model_name: 'rerank-v3.5', + }, + top_k: 5, + score_threshold_enabled: true, + score_threshold: 0.8, + }, + }), + createMockDataset({ + id: 'ds2', + retrieval_model: { + search_method: 'full_text_search' as any, + reranking_enable: false, + reranking_model: { + reranking_provider_name: '', + reranking_model_name: '', + }, + top_k: 3, + score_threshold_enabled: false, + score_threshold: 0.5, + }, + }), + ] + + renderDatasetConfig({ + dataSets: datasets, + }) + + expect(screen.getByTestId('params-config')).toHaveTextContent('Params (2)') + }) + + it('should handle external knowledge base integration', () => { + const externalDataset = createMockDataset({ + provider: 'notion', + external_knowledge_info: { + external_knowledge_id: 'notion-123', + external_knowledge_api_id: 'api-456', + external_knowledge_api_name: 'Notion Integration', + external_knowledge_api_endpoint: 'https://api.notion.com', + }, + }) + + renderDatasetConfig({ + dataSets: [externalDataset], + }) + + expect(screen.getByTestId(`card-item-${externalDataset.id}`)).toBeInTheDocument() + expect(screen.getByText(externalDataset.name)).toBeInTheDocument() + }) + }) + + describe('Performance and Error Handling', () => { + it('should handle large dataset lists efficiently', () => { + // Create many datasets to test performance + const manyDatasets = Array.from({ length: 50 }, (_, i) => + createMockDataset({ + id: `ds-${i}`, + name: `Dataset ${i}`, + doc_metadata: [ + { name: 'category', type: 'string' } as any, + { name: 'priority', type: 'number' } as any, + ], + }), + ) + + renderDatasetConfig({ + dataSets: manyDatasets, + }) + + expect(screen.getByTestId('params-config')).toHaveTextContent('Params (50)') + }) + + it('should handle metadata intersection calculation efficiently', () => { + const datasets = [ + createMockDataset({ + id: 'ds1', + doc_metadata: [ + { name: 'category', type: 'string' } as any, + { name: 'tags', type: 'string' } as any, + { name: 'priority', type: 'number' } as any, + ], + }), + createMockDataset({ + id: 'ds2', + doc_metadata: [ + { name: 'category', type: 'string' } as any, + { name: 'status', type: 'string' } as any, + { name: 'priority', type: 'number' } as any, + ], + }), + ] + + renderDatasetConfig({ + dataSets: datasets, + }) + + // Should calculate intersection correctly + expect(screen.getByTestId('metadata-filter')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/app/configuration/dataset-config/params-config/config-content.spec.tsx b/web/app/components/app/configuration/dataset-config/params-config/config-content.spec.tsx new file mode 100644 index 0000000000..58cc2ac81c --- /dev/null +++ b/web/app/components/app/configuration/dataset-config/params-config/config-content.spec.tsx @@ -0,0 +1,391 @@ +import type { MockInstance, MockedFunction } from 'vitest' +import { render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import ConfigContent from './config-content' +import type { DataSet } from '@/models/datasets' +import { ChunkingMode, DataSourceType, DatasetPermission, RerankingModeEnum, WeightedScoreEnum } from '@/models/datasets' +import type { DatasetConfigs } from '@/models/debug' +import { RETRIEVE_METHOD, RETRIEVE_TYPE } from '@/types/app' +import type { RetrievalConfig } from '@/types/app' +import Toast from '@/app/components/base/toast' +import type { IndexingType } from '@/app/components/datasets/create/step-two' +import { + useCurrentProviderAndModel, + useModelListAndDefaultModelAndCurrentProviderAndModel, +} from '@/app/components/header/account-setting/model-provider-page/hooks' + +vi.mock('@/app/components/header/account-setting/model-provider-page/model-selector', () => { + type Props = { + defaultModel?: { provider: string; model: string } + onSelect?: (model: { provider: string; model: string }) => void + } + + const MockModelSelector = ({ defaultModel, onSelect }: Props) => ( + + ) + + return { + __esModule: true, + default: MockModelSelector, + } +}) + +vi.mock('@/app/components/header/account-setting/model-provider-page/model-parameter-modal', () => ({ + __esModule: true, + default: () =>
, +})) + +vi.mock('@/app/components/header/account-setting/model-provider-page/hooks', () => ({ + useModelListAndDefaultModelAndCurrentProviderAndModel: vi.fn(), + useCurrentProviderAndModel: vi.fn(), +})) + +const mockedUseModelListAndDefaultModelAndCurrentProviderAndModel = useModelListAndDefaultModelAndCurrentProviderAndModel as MockedFunction +const mockedUseCurrentProviderAndModel = useCurrentProviderAndModel as MockedFunction + +let toastNotifySpy: MockInstance + +const baseRetrievalConfig: RetrievalConfig = { + search_method: RETRIEVE_METHOD.semantic, + reranking_enable: false, + reranking_model: { + reranking_provider_name: 'provider', + reranking_model_name: 'rerank-model', + }, + top_k: 4, + score_threshold_enabled: false, + score_threshold: 0, +} + +const defaultIndexingTechnique: IndexingType = 'high_quality' as IndexingType + +const createDataset = (overrides: Partial = {}): DataSet => { + const { + retrieval_model, + retrieval_model_dict, + icon_info, + ...restOverrides + } = overrides + + const resolvedRetrievalModelDict = { + ...baseRetrievalConfig, + ...retrieval_model_dict, + } + const resolvedRetrievalModel = { + ...baseRetrievalConfig, + ...(retrieval_model ?? retrieval_model_dict), + } + + const defaultIconInfo = { + icon: '📘', + icon_type: 'emoji', + icon_background: '#FFEAD5', + icon_url: '', + } + + const resolvedIconInfo = ('icon_info' in overrides) + ? icon_info + : defaultIconInfo + + return { + id: 'dataset-id', + name: 'Dataset Name', + indexing_status: 'completed', + icon_info: resolvedIconInfo as DataSet['icon_info'], + description: 'A test dataset', + permission: DatasetPermission.onlyMe, + data_source_type: DataSourceType.FILE, + indexing_technique: defaultIndexingTechnique, + author_name: 'author', + created_by: 'creator', + updated_by: 'updater', + updated_at: 0, + app_count: 0, + doc_form: ChunkingMode.text, + document_count: 0, + total_document_count: 0, + total_available_documents: 0, + word_count: 0, + provider: 'dify', + embedding_model: 'text-embedding', + embedding_model_provider: 'openai', + embedding_available: true, + retrieval_model_dict: resolvedRetrievalModelDict, + retrieval_model: resolvedRetrievalModel, + tags: [], + external_knowledge_info: { + external_knowledge_id: 'external-id', + external_knowledge_api_id: 'api-id', + external_knowledge_api_name: 'api-name', + external_knowledge_api_endpoint: 'https://endpoint', + }, + external_retrieval_model: { + top_k: 2, + score_threshold: 0.5, + score_threshold_enabled: true, + }, + built_in_field_enabled: true, + doc_metadata: [], + keyword_number: 3, + pipeline_id: 'pipeline-id', + is_published: true, + runtime_mode: 'general', + enable_api: true, + is_multimodal: false, + ...restOverrides, + } +} + +const createDatasetConfigs = (overrides: Partial = {}): DatasetConfigs => { + return { + retrieval_model: RETRIEVE_TYPE.multiWay, + reranking_model: { + reranking_provider_name: '', + reranking_model_name: '', + }, + top_k: 4, + score_threshold_enabled: false, + score_threshold: 0, + datasets: { + datasets: [], + }, + reranking_mode: RerankingModeEnum.WeightedScore, + weights: { + weight_type: WeightedScoreEnum.Customized, + vector_setting: { + vector_weight: 0.5, + embedding_provider_name: 'openai', + embedding_model_name: 'text-embedding', + }, + keyword_setting: { + keyword_weight: 0.5, + }, + }, + reranking_enable: false, + ...overrides, + } +} + +describe('ConfigContent', () => { + beforeEach(() => { + vi.clearAllMocks() + toastNotifySpy = vi.spyOn(Toast, 'notify').mockImplementation(() => ({})) + mockedUseModelListAndDefaultModelAndCurrentProviderAndModel.mockReturnValue({ + modelList: [], + defaultModel: undefined, + currentProvider: undefined, + currentModel: undefined, + }) + mockedUseCurrentProviderAndModel.mockReturnValue({ + currentProvider: undefined, + currentModel: undefined, + }) + }) + + afterEach(() => { + toastNotifySpy.mockRestore() + }) + + // State management + describe('Effects', () => { + it('should normalize oneWay retrieval mode to multiWay', async () => { + // Arrange + const onChange = vi.fn<(configs: DatasetConfigs, isRetrievalModeChange?: boolean) => void>() + const datasetConfigs = createDatasetConfigs({ retrieval_model: RETRIEVE_TYPE.oneWay }) + + // Act + render() + + // Assert + await waitFor(() => { + expect(onChange).toHaveBeenCalled() + }) + const [nextConfigs] = onChange.mock.calls[0] + expect(nextConfigs.retrieval_model).toBe(RETRIEVE_TYPE.multiWay) + }) + }) + + // Rendering tests (REQUIRED) + describe('Rendering', () => { + it('should render weighted score panel when datasets are high-quality and consistent', () => { + // Arrange + const onChange = vi.fn<(configs: DatasetConfigs, isRetrievalModeChange?: boolean) => void>() + const datasetConfigs = createDatasetConfigs({ + reranking_mode: RerankingModeEnum.WeightedScore, + }) + const selectedDatasets: DataSet[] = [ + createDataset({ + indexing_technique: 'high_quality' as IndexingType, + provider: 'dify', + embedding_model: 'text-embedding', + embedding_model_provider: 'openai', + retrieval_model_dict: { + ...baseRetrievalConfig, + search_method: RETRIEVE_METHOD.semantic, + }, + }), + ] + + // Act + render( + , + ) + + // Assert + expect(screen.getByText('dataset.weightedScore.title')).toBeInTheDocument() + expect(screen.getByText('common.modelProvider.rerankModel.key')).toBeInTheDocument() + expect(screen.getByText('dataset.weightedScore.semantic')).toBeInTheDocument() + expect(screen.getByText('dataset.weightedScore.keyword')).toBeInTheDocument() + }) + }) + + // User interactions + describe('User Interactions', () => { + it('should update weights when user changes weighted score slider', async () => { + // Arrange + const user = userEvent.setup() + const onChange = vi.fn<(configs: DatasetConfigs, isRetrievalModeChange?: boolean) => void>() + const datasetConfigs = createDatasetConfigs({ + reranking_mode: RerankingModeEnum.WeightedScore, + weights: { + weight_type: WeightedScoreEnum.Customized, + vector_setting: { + vector_weight: 0.5, + embedding_provider_name: 'openai', + embedding_model_name: 'text-embedding', + }, + keyword_setting: { + keyword_weight: 0.5, + }, + }, + }) + const selectedDatasets: DataSet[] = [ + createDataset({ + indexing_technique: 'high_quality' as IndexingType, + provider: 'dify', + embedding_model: 'text-embedding', + embedding_model_provider: 'openai', + retrieval_model_dict: { + ...baseRetrievalConfig, + search_method: RETRIEVE_METHOD.semantic, + }, + }), + ] + + // Act + render( + , + ) + + const weightedScoreSlider = screen.getAllByRole('slider') + .find(slider => slider.getAttribute('aria-valuemax') === '1') + expect(weightedScoreSlider).toBeDefined() + await user.click(weightedScoreSlider!) + const callsBefore = onChange.mock.calls.length + await user.keyboard('{ArrowRight}') + + // Assert + expect(onChange.mock.calls.length).toBeGreaterThan(callsBefore) + const [nextConfigs] = onChange.mock.calls.at(-1) ?? [] + expect(nextConfigs?.weights?.vector_setting.vector_weight).toBeCloseTo(0.6, 5) + expect(nextConfigs?.weights?.keyword_setting.keyword_weight).toBeCloseTo(0.4, 5) + }) + + it('should warn when switching to rerank model mode without a valid model', async () => { + // Arrange + const user = userEvent.setup() + const onChange = vi.fn<(configs: DatasetConfigs, isRetrievalModeChange?: boolean) => void>() + const datasetConfigs = createDatasetConfigs({ + reranking_mode: RerankingModeEnum.WeightedScore, + }) + const selectedDatasets: DataSet[] = [ + createDataset({ + indexing_technique: 'high_quality' as IndexingType, + provider: 'dify', + embedding_model: 'text-embedding', + embedding_model_provider: 'openai', + retrieval_model_dict: { + ...baseRetrievalConfig, + search_method: RETRIEVE_METHOD.semantic, + }, + }), + ] + + // Act + render( + , + ) + await user.click(screen.getByText('common.modelProvider.rerankModel.key')) + + // Assert + expect(toastNotifySpy).toHaveBeenCalledWith({ + type: 'error', + message: 'workflow.errorMsg.rerankModelRequired', + }) + expect(onChange).toHaveBeenCalledWith( + expect.objectContaining({ + reranking_mode: RerankingModeEnum.RerankingModel, + }), + ) + }) + + it('should warn when enabling rerank without a valid model in manual toggle mode', async () => { + // Arrange + const user = userEvent.setup() + const onChange = vi.fn<(configs: DatasetConfigs, isRetrievalModeChange?: boolean) => void>() + const datasetConfigs = createDatasetConfigs({ + reranking_enable: false, + }) + const selectedDatasets: DataSet[] = [ + createDataset({ + indexing_technique: 'economy' as IndexingType, + provider: 'dify', + embedding_model: 'text-embedding', + embedding_model_provider: 'openai', + retrieval_model_dict: { + ...baseRetrievalConfig, + search_method: RETRIEVE_METHOD.semantic, + }, + }), + ] + + // Act + render( + , + ) + await user.click(screen.getByRole('switch')) + + // Assert + expect(toastNotifySpy).toHaveBeenCalledWith({ + type: 'error', + message: 'workflow.errorMsg.rerankModelRequired', + }) + expect(onChange).toHaveBeenCalledWith( + expect.objectContaining({ + reranking_enable: true, + }), + ) + }) + }) +}) diff --git a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx index 8e06d6c901..c7a43fbfbd 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx @@ -20,7 +20,7 @@ import type { DataSet, } from '@/models/datasets' import { RerankingModeEnum } from '@/models/datasets' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { useSelectedDatasetsMode } from '@/app/components/workflow/nodes/knowledge-retrieval/hooks' import Switch from '@/app/components/base/switch' import Toast from '@/app/components/base/toast' diff --git a/web/app/components/app/configuration/dataset-config/params-config/index.spec.tsx b/web/app/components/app/configuration/dataset-config/params-config/index.spec.tsx new file mode 100644 index 0000000000..cd4d3c6006 --- /dev/null +++ b/web/app/components/app/configuration/dataset-config/params-config/index.spec.tsx @@ -0,0 +1,266 @@ +import type { MockInstance, MockedFunction } from 'vitest' +import * as React from 'react' +import { render, screen, waitFor, within } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import ParamsConfig from './index' +import ConfigContext from '@/context/debug-configuration' +import type { DatasetConfigs } from '@/models/debug' +import { RerankingModeEnum } from '@/models/datasets' +import { RETRIEVE_TYPE } from '@/types/app' +import Toast from '@/app/components/base/toast' +import { + useCurrentProviderAndModel, + useModelListAndDefaultModelAndCurrentProviderAndModel, +} from '@/app/components/header/account-setting/model-provider-page/hooks' + +vi.mock('@headlessui/react', () => ({ + Dialog: ({ children, className }: { children: React.ReactNode; className?: string }) => ( +
+ {children} +
+ ), + DialogPanel: ({ children, className, ...props }: { children: React.ReactNode; className?: string }) => ( +
+ {children} +
+ ), + DialogTitle: ({ children, className, ...props }: { children: React.ReactNode; className?: string }) => ( +
+ {children} +
+ ), + Transition: ({ show, children }: { show: boolean; children: React.ReactNode }) => (show ? <>{children} : null), + TransitionChild: ({ children }: { children: React.ReactNode }) => <>{children}, + Switch: ({ checked, onChange, children, ...props }: { checked: boolean; onChange?: (value: boolean) => void; children?: React.ReactNode }) => ( + + ), +})) + +vi.mock('@/app/components/header/account-setting/model-provider-page/hooks', () => ({ + useModelListAndDefaultModelAndCurrentProviderAndModel: vi.fn(), + useCurrentProviderAndModel: vi.fn(), +})) + +vi.mock('@/app/components/header/account-setting/model-provider-page/model-selector', () => { + type Props = { + defaultModel?: { provider: string; model: string } + onSelect?: (model: { provider: string; model: string }) => void + } + + const MockModelSelector = ({ defaultModel, onSelect }: Props) => ( + + ) + + return { + __esModule: true, + default: MockModelSelector, + } +}) + +vi.mock('@/app/components/header/account-setting/model-provider-page/model-parameter-modal', () => ({ + __esModule: true, + default: () =>
, +})) + +const mockedUseModelListAndDefaultModelAndCurrentProviderAndModel = useModelListAndDefaultModelAndCurrentProviderAndModel as MockedFunction +const mockedUseCurrentProviderAndModel = useCurrentProviderAndModel as MockedFunction +let toastNotifySpy: MockInstance + +const createDatasetConfigs = (overrides: Partial = {}): DatasetConfigs => { + return { + retrieval_model: RETRIEVE_TYPE.multiWay, + reranking_model: { + reranking_provider_name: 'provider', + reranking_model_name: 'rerank-model', + }, + top_k: 4, + score_threshold_enabled: false, + score_threshold: 0, + datasets: { + datasets: [], + }, + reranking_enable: false, + reranking_mode: RerankingModeEnum.RerankingModel, + ...overrides, + } +} + +const renderParamsConfig = ({ + datasetConfigs = createDatasetConfigs(), + initialModalOpen = false, + disabled, +}: { + datasetConfigs?: DatasetConfigs + initialModalOpen?: boolean + disabled?: boolean +} = {}) => { + const Wrapper = ({ children }: { children: React.ReactNode }) => { + const [datasetConfigsState, setDatasetConfigsState] = React.useState(datasetConfigs) + const [modalOpen, setModalOpen] = React.useState(initialModalOpen) + + const contextValue = { + datasetConfigs: datasetConfigsState, + setDatasetConfigs: (next: DatasetConfigs) => { + setDatasetConfigsState(next) + }, + rerankSettingModalOpen: modalOpen, + setRerankSettingModalOpen: (open: boolean) => { + setModalOpen(open) + }, + } as unknown as React.ComponentProps['value'] + + return ( + + {children} + + ) + } + + return render( + , + { wrapper: Wrapper }, + ) +} + +describe('dataset-config/params-config', () => { + beforeEach(() => { + vi.clearAllMocks() + vi.useRealTimers() + toastNotifySpy = vi.spyOn(Toast, 'notify').mockImplementation(() => ({})) + mockedUseModelListAndDefaultModelAndCurrentProviderAndModel.mockReturnValue({ + modelList: [], + defaultModel: undefined, + currentProvider: undefined, + currentModel: undefined, + }) + mockedUseCurrentProviderAndModel.mockReturnValue({ + currentProvider: undefined, + currentModel: undefined, + }) + }) + + afterEach(() => { + toastNotifySpy.mockRestore() + }) + + // Rendering tests (REQUIRED) + describe('Rendering', () => { + it('should disable settings trigger when disabled is true', () => { + // Arrange + renderParamsConfig({ disabled: true }) + + // Assert + expect(screen.getByRole('button', { name: 'dataset.retrievalSettings' })).toBeDisabled() + }) + }) + + // User Interactions + describe('User Interactions', () => { + it('should open modal and persist changes when save is clicked', async () => { + // Arrange + renderParamsConfig() + const user = userEvent.setup() + + // Act + await user.click(screen.getByRole('button', { name: 'dataset.retrievalSettings' })) + const dialog = await screen.findByRole('dialog', {}, { timeout: 3000 }) + const dialogScope = within(dialog) + + const incrementButtons = dialogScope.getAllByRole('button', { name: 'increment' }) + await user.click(incrementButtons[0]) + + await waitFor(() => { + const [topKInput] = dialogScope.getAllByRole('spinbutton') + expect(topKInput).toHaveValue(5) + }) + + await user.click(dialogScope.getByRole('button', { name: 'common.operation.save' })) + + await waitFor(() => { + expect(screen.queryByRole('dialog')).not.toBeInTheDocument() + }) + + await user.click(screen.getByRole('button', { name: 'dataset.retrievalSettings' })) + const reopenedDialog = await screen.findByRole('dialog', {}, { timeout: 3000 }) + const reopenedScope = within(reopenedDialog) + const [reopenedTopKInput] = reopenedScope.getAllByRole('spinbutton') + + // Assert + expect(reopenedTopKInput).toHaveValue(5) + }) + + it('should discard changes when cancel is clicked', async () => { + // Arrange + renderParamsConfig() + const user = userEvent.setup() + + // Act + await user.click(screen.getByRole('button', { name: 'dataset.retrievalSettings' })) + const dialog = await screen.findByRole('dialog', {}, { timeout: 3000 }) + const dialogScope = within(dialog) + + const incrementButtons = dialogScope.getAllByRole('button', { name: 'increment' }) + await user.click(incrementButtons[0]) + + await waitFor(() => { + const [topKInput] = dialogScope.getAllByRole('spinbutton') + expect(topKInput).toHaveValue(5) + }) + + const cancelButton = await dialogScope.findByRole('button', { name: 'common.operation.cancel' }) + await user.click(cancelButton) + await waitFor(() => { + expect(screen.queryByRole('dialog')).not.toBeInTheDocument() + }) + + // Re-open and verify the original value remains. + await user.click(screen.getByRole('button', { name: 'dataset.retrievalSettings' })) + const reopenedDialog = await screen.findByRole('dialog', {}, { timeout: 3000 }) + const reopenedScope = within(reopenedDialog) + const [reopenedTopKInput] = reopenedScope.getAllByRole('spinbutton') + + // Assert + expect(reopenedTopKInput).toHaveValue(4) + }) + + it('should prevent saving when rerank model is required but invalid', async () => { + // Arrange + renderParamsConfig({ + datasetConfigs: createDatasetConfigs({ + reranking_enable: true, + reranking_mode: RerankingModeEnum.RerankingModel, + }), + initialModalOpen: true, + }) + const user = userEvent.setup() + + // Act + const dialog = await screen.findByRole('dialog', {}, { timeout: 3000 }) + const dialogScope = within(dialog) + await user.click(dialogScope.getByRole('button', { name: 'common.operation.save' })) + + // Assert + expect(toastNotifySpy).toHaveBeenCalledWith({ + type: 'error', + message: 'appDebug.datasetConfig.rerankModelRequired', + }) + expect(screen.getByRole('dialog')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/app/configuration/dataset-config/params-config/index.tsx b/web/app/components/app/configuration/dataset-config/params-config/index.tsx index df2b4293c4..24da958217 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/index.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/index.tsx @@ -4,7 +4,7 @@ import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import { RiEqualizer2Line } from '@remixicon/react' import ConfigContent from './config-content' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import ConfigContext from '@/context/debug-configuration' import Modal from '@/app/components/base/modal' import Button from '@/app/components/base/button' diff --git a/web/app/components/app/configuration/dataset-config/params-config/weighted-score.spec.tsx b/web/app/components/app/configuration/dataset-config/params-config/weighted-score.spec.tsx new file mode 100644 index 0000000000..7729830348 --- /dev/null +++ b/web/app/components/app/configuration/dataset-config/params-config/weighted-score.spec.tsx @@ -0,0 +1,81 @@ +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import WeightedScore from './weighted-score' + +describe('WeightedScore', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + // Rendering tests (REQUIRED) + describe('Rendering', () => { + it('should render semantic and keyword weights', () => { + // Arrange + const onChange = vi.fn<(arg: { value: number[] }) => void>() + const value = { value: [0.3, 0.7] } + + // Act + render() + + // Assert + expect(screen.getByText('dataset.weightedScore.semantic')).toBeInTheDocument() + expect(screen.getByText('dataset.weightedScore.keyword')).toBeInTheDocument() + expect(screen.getByText('0.3')).toBeInTheDocument() + expect(screen.getByText('0.7')).toBeInTheDocument() + }) + + it('should format a weight of 1 as 1.0', () => { + // Arrange + const onChange = vi.fn<(arg: { value: number[] }) => void>() + const value = { value: [1, 0] } + + // Act + render() + + // Assert + expect(screen.getByText('1.0')).toBeInTheDocument() + expect(screen.getByText('0')).toBeInTheDocument() + }) + }) + + // User Interactions + describe('User Interactions', () => { + it('should emit complementary weights when the slider value changes', async () => { + // Arrange + const onChange = vi.fn<(arg: { value: number[] }) => void>() + const value = { value: [0.5, 0.5] } + const user = userEvent.setup() + render() + + // Act + await user.tab() + const slider = screen.getByRole('slider') + expect(slider).toHaveFocus() + const callsBefore = onChange.mock.calls.length + await user.keyboard('{ArrowRight}') + + // Assert + expect(onChange.mock.calls.length).toBeGreaterThan(callsBefore) + const lastCall = onChange.mock.calls.at(-1)?.[0] + expect(lastCall?.value[0]).toBeCloseTo(0.6, 5) + expect(lastCall?.value[1]).toBeCloseTo(0.4, 5) + }) + + it('should not call onChange when readonly is true', async () => { + // Arrange + const onChange = vi.fn<(arg: { value: number[] }) => void>() + const value = { value: [0.5, 0.5] } + const user = userEvent.setup() + render() + + // Act + await user.tab() + const slider = screen.getByRole('slider') + expect(slider).toHaveFocus() + await user.keyboard('{ArrowRight}') + + // Assert + expect(onChange).not.toHaveBeenCalled() + }) + }) +}) diff --git a/web/app/components/app/configuration/dataset-config/params-config/weighted-score.tsx b/web/app/components/app/configuration/dataset-config/params-config/weighted-score.tsx index ebfa3b1e12..459623104d 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/weighted-score.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/weighted-score.tsx @@ -2,7 +2,7 @@ import { memo } from 'react' import { useTranslation } from 'react-i18next' import './weighted-score.css' import Slider from '@/app/components/base/slider' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { noop } from 'lodash-es' const formatNumber = (value: number) => { diff --git a/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx b/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx index 6857c38e1e..f02fdcb5d7 100644 --- a/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx +++ b/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx @@ -10,7 +10,7 @@ import Button from '@/app/components/base/button' import Loading from '@/app/components/base/loading' import Badge from '@/app/components/base/badge' import { useKnowledge } from '@/hooks/use-knowledge' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import AppIcon from '@/app/components/base/app-icon' import { useInfiniteDatasets } from '@/service/knowledge/use-dataset' import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' diff --git a/web/app/components/app/configuration/dataset-config/settings-modal/index.spec.tsx b/web/app/components/app/configuration/dataset-config/settings-modal/index.spec.tsx new file mode 100644 index 0000000000..f35b1b7def --- /dev/null +++ b/web/app/components/app/configuration/dataset-config/settings-modal/index.spec.tsx @@ -0,0 +1,538 @@ +import type { MockedFunction } from 'vitest' +import { render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import SettingsModal from './index' +import { ToastContext } from '@/app/components/base/toast' +import type { DataSet } from '@/models/datasets' +import { ChunkingMode, DataSourceType, DatasetPermission, RerankingModeEnum } from '@/models/datasets' +import { IndexingType } from '@/app/components/datasets/create/step-two' +import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import { updateDatasetSetting } from '@/service/datasets' +import { useMembers } from '@/service/use-common' +import { RETRIEVE_METHOD, type RetrievalConfig } from '@/types/app' +import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants' + +const mockNotify = vi.fn() +const mockOnCancel = vi.fn() +const mockOnSave = vi.fn() +const mockSetShowAccountSettingModal = vi.fn() +let mockIsWorkspaceDatasetOperator = false + +const mockUseModelList = vi.fn() +const mockUseModelListAndDefaultModel = vi.fn() +const mockUseModelListAndDefaultModelAndCurrentProviderAndModel = vi.fn() +const mockUseCurrentProviderAndModel = vi.fn() +const mockCheckShowMultiModalTip = vi.fn() + +vi.mock('ky', () => { + const ky = () => ky + ky.extend = () => ky + ky.create = () => ky + return { __esModule: true, default: ky } +}) + +vi.mock('@/app/components/datasets/create/step-two', () => ({ + __esModule: true, + IndexingType: { + QUALIFIED: 'high_quality', + ECONOMICAL: 'economy', + }, +})) + +vi.mock('@/service/datasets', () => ({ + updateDatasetSetting: vi.fn(), +})) + +vi.mock('@/service/use-common', async () => ({ + __esModule: true, + ...(await vi.importActual('@/service/use-common')), + useMembers: vi.fn(), +})) + +vi.mock('@/context/app-context', () => ({ + useAppContext: () => ({ isCurrentWorkspaceDatasetOperator: mockIsWorkspaceDatasetOperator }), + useSelector: (selector: (value: { userProfile: { id: string; name: string; email: string; avatar_url: string } }) => T) => selector({ + userProfile: { + id: 'user-1', + name: 'User One', + email: 'user@example.com', + avatar_url: 'avatar.png', + }, + }), +})) + +vi.mock('@/context/modal-context', () => ({ + useModalContext: () => ({ + setShowAccountSettingModal: mockSetShowAccountSettingModal, + }), +})) + +vi.mock('@/context/i18n', () => ({ + useDocLink: () => (path: string) => `https://docs${path}`, +})) + +vi.mock('@/context/provider-context', () => ({ + useProviderContext: () => ({ + modelProviders: [], + textGenerationModelList: [], + supportRetrievalMethods: [ + RETRIEVE_METHOD.semantic, + RETRIEVE_METHOD.fullText, + RETRIEVE_METHOD.hybrid, + RETRIEVE_METHOD.keywordSearch, + ], + }), +})) + +vi.mock('@/app/components/header/account-setting/model-provider-page/hooks', () => ({ + __esModule: true, + useModelList: (...args: unknown[]) => mockUseModelList(...args), + useModelListAndDefaultModel: (...args: unknown[]) => mockUseModelListAndDefaultModel(...args), + useModelListAndDefaultModelAndCurrentProviderAndModel: (...args: unknown[]) => + mockUseModelListAndDefaultModelAndCurrentProviderAndModel(...args), + useCurrentProviderAndModel: (...args: unknown[]) => mockUseCurrentProviderAndModel(...args), +})) + +vi.mock('@/app/components/header/account-setting/model-provider-page/model-selector', () => ({ + __esModule: true, + default: ({ defaultModel }: { defaultModel?: { provider: string; model: string } }) => ( +
+ {defaultModel ? `${defaultModel.provider}/${defaultModel.model}` : 'no-model'} +
+ ), +})) + +vi.mock('@/app/components/datasets/settings/utils', () => ({ + checkShowMultiModalTip: (...args: unknown[]) => mockCheckShowMultiModalTip(...args), +})) + +const mockUpdateDatasetSetting = updateDatasetSetting as MockedFunction +const mockUseMembers = useMembers as MockedFunction + +const createRetrievalConfig = (overrides: Partial = {}): RetrievalConfig => ({ + search_method: RETRIEVE_METHOD.semantic, + reranking_enable: false, + reranking_model: { + reranking_provider_name: '', + reranking_model_name: '', + }, + top_k: 2, + score_threshold_enabled: false, + score_threshold: 0.5, + reranking_mode: RerankingModeEnum.RerankingModel, + ...overrides, +}) + +const createDataset = (overrides: Partial = {}, retrievalOverrides: Partial = {}): DataSet => { + const retrievalConfig = createRetrievalConfig(retrievalOverrides) + return { + id: 'dataset-id', + name: 'Test Dataset', + indexing_status: 'completed', + icon_info: { + icon: 'icon', + icon_type: 'emoji', + }, + description: 'Description', + permission: DatasetPermission.allTeamMembers, + data_source_type: DataSourceType.FILE, + indexing_technique: IndexingType.QUALIFIED, + author_name: 'Author', + created_by: 'creator', + updated_by: 'updater', + updated_at: 1700000000, + app_count: 0, + doc_form: ChunkingMode.text, + document_count: 0, + total_document_count: 0, + total_available_documents: 0, + word_count: 0, + provider: 'internal', + embedding_model: 'embed-model', + embedding_model_provider: 'embed-provider', + embedding_available: true, + tags: [], + partial_member_list: [], + external_knowledge_info: { + external_knowledge_id: 'ext-id', + external_knowledge_api_id: 'ext-api-id', + external_knowledge_api_name: 'External API', + external_knowledge_api_endpoint: 'https://api.example.com', + }, + external_retrieval_model: { + top_k: 2, + score_threshold: 0.5, + score_threshold_enabled: false, + }, + built_in_field_enabled: false, + doc_metadata: [], + keyword_number: 10, + pipeline_id: 'pipeline-id', + is_published: false, + runtime_mode: 'general', + enable_api: true, + is_multimodal: false, + ...overrides, + retrieval_model_dict: { + ...retrievalConfig, + ...overrides.retrieval_model_dict, + }, + retrieval_model: { + ...retrievalConfig, + ...overrides.retrieval_model, + }, + } +} + +const renderWithProviders = (dataset: DataSet) => { + return render( + + + , + ) +} + +const createMemberList = (): DataSet['partial_member_list'] => ([ + 'member-2', +]) + +const renderSettingsModal = async (dataset: DataSet) => { + renderWithProviders(dataset) + await waitFor(() => expect(mockUseMembers).toHaveBeenCalled()) +} + +describe('SettingsModal', () => { + beforeEach(() => { + vi.clearAllMocks() + mockIsWorkspaceDatasetOperator = false + mockUseMembers.mockReturnValue({ + data: { + accounts: [ + { + id: 'user-1', + name: 'User One', + email: 'user@example.com', + avatar: 'avatar.png', + avatar_url: 'avatar.png', + status: 'active', + role: 'owner', + }, + { + id: 'member-2', + name: 'Member Two', + email: 'member@example.com', + avatar: 'avatar.png', + avatar_url: 'avatar.png', + status: 'active', + role: 'editor', + }, + ], + }, + } as ReturnType) + mockUseModelList.mockImplementation((type: ModelTypeEnum) => { + if (type === ModelTypeEnum.rerank) { + return { + data: [ + { + provider: 'rerank-provider', + models: [{ model: 'rerank-model' }], + }, + ], + } + } + return { data: [{ provider: 'embed-provider', models: [{ model: 'embed-model' }] }] } + }) + mockUseModelListAndDefaultModel.mockReturnValue({ modelList: [], defaultModel: null }) + mockUseModelListAndDefaultModelAndCurrentProviderAndModel.mockReturnValue({ defaultModel: null, currentModel: null }) + mockUseCurrentProviderAndModel.mockReturnValue({ currentProvider: null, currentModel: null }) + mockCheckShowMultiModalTip.mockReturnValue(false) + mockUpdateDatasetSetting.mockResolvedValue(createDataset()) + }) + + // Rendering and basic field bindings. + describe('Rendering', () => { + it('should render dataset details when dataset is provided', async () => { + // Arrange + const dataset = createDataset() + + // Act + await renderSettingsModal(dataset) + + // Assert + expect(screen.getByPlaceholderText('datasetSettings.form.namePlaceholder')).toHaveValue('Test Dataset') + expect(screen.getByPlaceholderText('datasetSettings.form.descPlaceholder')).toHaveValue('Description') + }) + + it('should show external knowledge info when dataset is external', async () => { + // Arrange + const dataset = createDataset({ + provider: 'external', + external_knowledge_info: { + external_knowledge_id: 'ext-id-123', + external_knowledge_api_id: 'ext-api-id-123', + external_knowledge_api_name: 'External Knowledge API', + external_knowledge_api_endpoint: 'https://api.external.com', + }, + }) + + // Act + await renderSettingsModal(dataset) + + // Assert + expect(screen.getByText('External Knowledge API')).toBeInTheDocument() + expect(screen.getByText('https://api.external.com')).toBeInTheDocument() + expect(screen.getByText('ext-id-123')).toBeInTheDocument() + }) + }) + + // User interactions that update visible state. + describe('Interactions', () => { + it('should call onCancel when cancel button is clicked', async () => { + // Arrange + const user = userEvent.setup() + + // Act + await renderSettingsModal(createDataset()) + await user.click(screen.getByRole('button', { name: 'common.operation.cancel' })) + + // Assert + expect(mockOnCancel).toHaveBeenCalledTimes(1) + }) + + it('should update name input when user types', async () => { + // Arrange + const user = userEvent.setup() + await renderSettingsModal(createDataset()) + + const nameInput = screen.getByPlaceholderText('datasetSettings.form.namePlaceholder') + + // Act + await user.clear(nameInput) + await user.type(nameInput, 'New Dataset Name') + + // Assert + expect(nameInput).toHaveValue('New Dataset Name') + }) + + it('should update description input when user types', async () => { + // Arrange + const user = userEvent.setup() + await renderSettingsModal(createDataset()) + + const descriptionInput = screen.getByPlaceholderText('datasetSettings.form.descPlaceholder') + + // Act + await user.clear(descriptionInput) + await user.type(descriptionInput, 'New description') + + // Assert + expect(descriptionInput).toHaveValue('New description') + }) + + it('should show and dismiss retrieval change tip when indexing method changes', async () => { + // Arrange + const user = userEvent.setup() + const dataset = createDataset({ indexing_technique: IndexingType.ECONOMICAL }) + + // Act + await renderSettingsModal(dataset) + await user.click(screen.getByText('datasetCreation.stepTwo.qualified')) + + // Assert + expect(await screen.findByText('appDebug.datasetConfig.retrieveChangeTip')).toBeInTheDocument() + + // Act + await user.click(screen.getByLabelText('close-retrieval-change-tip')) + + // Assert + await waitFor(() => { + expect(screen.queryByText('appDebug.datasetConfig.retrieveChangeTip')).not.toBeInTheDocument() + }) + }) + + it('should open account setting modal when embedding model tip is clicked', async () => { + // Arrange + const user = userEvent.setup() + + // Act + await renderSettingsModal(createDataset()) + await user.click(screen.getByText('datasetSettings.form.embeddingModelTipLink')) + + // Assert + expect(mockSetShowAccountSettingModal).toHaveBeenCalledWith({ payload: ACCOUNT_SETTING_TAB.PROVIDER }) + }) + }) + + // Validation guardrails before saving. + describe('Validation', () => { + it('should block save when dataset name is empty', async () => { + // Arrange + const user = userEvent.setup() + await renderSettingsModal(createDataset()) + + const nameInput = screen.getByPlaceholderText('datasetSettings.form.namePlaceholder') + + // Act + await user.clear(nameInput) + await user.click(screen.getByRole('button', { name: 'common.operation.save' })) + + // Assert + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'error', + message: 'datasetSettings.form.nameError', + })) + expect(mockUpdateDatasetSetting).not.toHaveBeenCalled() + }) + + it('should block save when reranking is enabled without model', async () => { + // Arrange + const user = userEvent.setup() + mockUseModelList.mockReturnValue({ data: [] }) + const dataset = createDataset({}, createRetrievalConfig({ + reranking_enable: true, + reranking_model: { + reranking_provider_name: '', + reranking_model_name: '', + }, + })) + + // Act + await renderSettingsModal(dataset) + await user.click(screen.getByRole('button', { name: 'common.operation.save' })) + + // Assert + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'error', + message: 'appDebug.datasetConfig.rerankModelRequired', + })) + expect(mockUpdateDatasetSetting).not.toHaveBeenCalled() + }) + }) + + // Save flows and side effects. + describe('Save', () => { + it('should save internal dataset changes when form is valid', async () => { + // Arrange + const user = userEvent.setup() + const rerankRetrieval = createRetrievalConfig({ + reranking_enable: true, + reranking_model: { + reranking_provider_name: 'rerank-provider', + reranking_model_name: 'rerank-model', + }, + }) + const dataset = createDataset({ + retrieval_model: rerankRetrieval, + retrieval_model_dict: rerankRetrieval, + }) + + // Act + await renderSettingsModal(dataset) + + const nameInput = screen.getByPlaceholderText('datasetSettings.form.namePlaceholder') + await user.clear(nameInput) + await user.type(nameInput, 'Updated Internal Dataset') + await user.click(screen.getByRole('button', { name: 'common.operation.save' })) + + // Assert + await waitFor(() => expect(mockUpdateDatasetSetting).toHaveBeenCalled()) + + expect(mockUpdateDatasetSetting).toHaveBeenCalledWith(expect.objectContaining({ + body: expect.objectContaining({ + name: 'Updated Internal Dataset', + permission: DatasetPermission.allTeamMembers, + }), + })) + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'success', + message: 'common.actionMsg.modifiedSuccessfully', + })) + expect(mockOnSave).toHaveBeenCalledWith(expect.objectContaining({ + name: 'Updated Internal Dataset', + retrieval_model_dict: expect.objectContaining({ + reranking_enable: true, + }), + })) + }) + + it('should save external dataset changes when partial members configured', async () => { + // Arrange + const user = userEvent.setup() + const dataset = createDataset({ + provider: 'external', + permission: DatasetPermission.partialMembers, + partial_member_list: createMemberList(), + external_retrieval_model: { + top_k: 5, + score_threshold: 0.3, + score_threshold_enabled: true, + }, + }, { + score_threshold_enabled: true, + score_threshold: 0.8, + }) + + // Act + await renderSettingsModal(dataset) + await user.click(screen.getByRole('button', { name: 'common.operation.save' })) + + // Assert + await waitFor(() => expect(mockUpdateDatasetSetting).toHaveBeenCalled()) + + expect(mockUpdateDatasetSetting).toHaveBeenCalledWith(expect.objectContaining({ + body: expect.objectContaining({ + permission: DatasetPermission.partialMembers, + external_retrieval_model: expect.objectContaining({ + top_k: 5, + }), + partial_member_list: [ + { + user_id: 'member-2', + role: 'editor', + }, + ], + }), + })) + expect(mockOnSave).toHaveBeenCalledWith(expect.objectContaining({ + retrieval_model_dict: expect.objectContaining({ + score_threshold_enabled: true, + score_threshold: 0.8, + }), + })) + }) + + it('should disable save button while saving', async () => { + // Arrange + const user = userEvent.setup() + mockUpdateDatasetSetting.mockImplementation(() => new Promise(resolve => setTimeout(resolve, 100))) + + // Act + await renderSettingsModal(createDataset()) + + const saveButton = screen.getByRole('button', { name: 'common.operation.save' }) + await user.click(saveButton) + + // Assert + expect(saveButton).toBeDisabled() + }) + + it('should show error toast when save fails', async () => { + // Arrange + const user = userEvent.setup() + mockUpdateDatasetSetting.mockRejectedValue(new Error('API Error')) + + // Act + await renderSettingsModal(createDataset()) + await user.click(screen.getByRole('button', { name: 'common.operation.save' })) + + // Assert + await waitFor(() => { + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ type: 'error' })) + }) + }) + }) +}) diff --git a/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx b/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx index cd6e39011e..c191ff5d46 100644 --- a/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx +++ b/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx @@ -1,13 +1,10 @@ import type { FC } from 'react' -import { useMemo, useRef, useState } from 'react' -import { useMount } from 'ahooks' +import { useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import { isEqual } from 'lodash-es' import { RiCloseLine } from '@remixicon/react' -import { ApiConnectionMod } from '@/app/components/base/icons/src/vender/solid/development' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import IndexMethod from '@/app/components/datasets/settings/index-method' -import Divider from '@/app/components/base/divider' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Textarea from '@/app/components/base/textarea' @@ -18,20 +15,17 @@ import { useAppContext } from '@/context/app-context' import { useModalContext } from '@/context/modal-context' import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants' import type { RetrievalConfig } from '@/types/app' -import RetrievalSettings from '@/app/components/datasets/external-knowledge-base/create/RetrievalSettings' -import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-method-config' -import EconomicalRetrievalMethodConfig from '@/app/components/datasets/common/economical-retrieval-method-config' import { isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model' -import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback' import PermissionSelector from '@/app/components/datasets/settings/permission-selector' import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector' import { useModelList } from '@/app/components/header/account-setting/model-provider-page/hooks' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' -import { fetchMembers } from '@/service/common' import type { Member } from '@/models/common' import { IndexingType } from '@/app/components/datasets/create/step-two' import { useDocLink } from '@/context/i18n' +import { useMembers } from '@/service/use-common' import { checkShowMultiModalTip } from '@/app/components/datasets/settings/utils' +import { RetrievalChangeTip, RetrievalSection } from './retrieval-section' type SettingsModalProps = { currentDataset: DataSet @@ -68,6 +62,7 @@ const SettingsModal: FC = ({ const [scoreThresholdEnabled, setScoreThresholdEnabled] = useState(localeCurrentDataset?.external_retrieval_model.score_threshold_enabled ?? false) const [selectedMemberIDs, setSelectedMemberIDs] = useState(currentDataset.partial_member_list || []) const [memberList, setMemberList] = useState([]) + const { data: membersData } = useMembers() const [indexMethod, setIndexMethod] = useState(currentDataset.indexing_technique) const [retrievalConfig, setRetrievalConfig] = useState(localeCurrentDataset?.retrieval_model_dict as RetrievalConfig) @@ -165,17 +160,12 @@ const SettingsModal: FC = ({ } } - const getMembers = async () => { - const { accounts } = await fetchMembers({ url: '/workspaces/current/members', params: {} }) - if (!accounts) + useEffect(() => { + if (!membersData?.accounts) setMemberList([]) else - setMemberList(accounts) - } - - useMount(() => { - getMembers() - }) + setMemberList(membersData.accounts) + }, [membersData]) const showMultiModalTip = useMemo(() => { return checkShowMultiModalTip({ @@ -298,92 +288,37 @@ const SettingsModal: FC = ({ )} {/* Retrieval Method Config */} - {currentDataset?.provider === 'external' - ? <> -
-
-
-
{t('datasetSettings.form.retrievalSetting.title')}
-
- -
-
-
-
-
{t('datasetSettings.form.externalKnowledgeAPI')}
-
-
-
- -
- {currentDataset?.external_knowledge_info.external_knowledge_api_name} -
-
·
-
{currentDataset?.external_knowledge_info.external_knowledge_api_endpoint}
-
-
-
-
-
-
{t('datasetSettings.form.externalKnowledgeID')}
-
-
-
-
{currentDataset?.external_knowledge_info.external_knowledge_id}
-
-
-
-
- - :
-
-
-
{t('datasetSettings.form.retrievalSetting.title')}
-
- {t('datasetSettings.form.retrievalSetting.learnMore')} - {t('datasetSettings.form.retrievalSetting.description')} -
-
-
-
- {indexMethod === IndexingType.QUALIFIED - ? ( - - ) - : ( - - )} -
-
} + {isExternal ? ( + + ) : ( + + )}
- {isRetrievalChanged && !isHideChangedTip && ( -
-
- -
{t('appDebug.datasetConfig.retrieveChangeTip')}
-
-
{ - setIsHideChangedTip(true) - e.stopPropagation() - e.nativeEvent.stopImmediatePropagation() - }}> - -
-
- )} + setIsHideChangedTip(true)} + />
{ + const ky = () => ky + ky.extend = () => ky + ky.create = () => ky + return { __esModule: true, default: ky } +}) + +vi.mock('@/context/provider-context', () => ({ + useProviderContext: () => ({ + modelProviders: [], + textGenerationModelList: [], + supportRetrievalMethods: [ + RETRIEVE_METHOD.semantic, + RETRIEVE_METHOD.fullText, + RETRIEVE_METHOD.hybrid, + RETRIEVE_METHOD.keywordSearch, + ], + }), +})) + +vi.mock('@/app/components/header/account-setting/model-provider-page/hooks', () => ({ + __esModule: true, + useModelListAndDefaultModelAndCurrentProviderAndModel: (...args: unknown[]) => + mockUseModelListAndDefaultModelAndCurrentProviderAndModel(...args), + useModelListAndDefaultModel: (...args: unknown[]) => mockUseModelListAndDefaultModel(...args), + useModelList: (...args: unknown[]) => mockUseModelList(...args), + useCurrentProviderAndModel: (...args: unknown[]) => mockUseCurrentProviderAndModel(...args), +})) + +vi.mock('@/app/components/header/account-setting/model-provider-page/model-selector', () => ({ + __esModule: true, + default: ({ defaultModel }: { defaultModel?: { provider: string; model: string } }) => ( +
+ {defaultModel ? `${defaultModel.provider}/${defaultModel.model}` : 'no-model'} +
+ ), +})) + +vi.mock('@/app/components/datasets/create/step-two', () => ({ + __esModule: true, + IndexingType: { + QUALIFIED: 'high_quality', + ECONOMICAL: 'economy', + }, +})) + +const createRetrievalConfig = (overrides: Partial = {}): RetrievalConfig => ({ + search_method: RETRIEVE_METHOD.semantic, + reranking_enable: false, + reranking_model: { + reranking_provider_name: '', + reranking_model_name: '', + }, + top_k: 2, + score_threshold_enabled: false, + score_threshold: 0.5, + reranking_mode: RerankingModeEnum.RerankingModel, + ...overrides, +}) + +const createDataset = (overrides: Partial = {}, retrievalOverrides: Partial = {}): DataSet => { + const retrievalConfig = createRetrievalConfig(retrievalOverrides) + return { + id: 'dataset-id', + name: 'Test Dataset', + indexing_status: 'completed', + icon_info: { + icon: 'icon', + icon_type: 'emoji', + }, + description: 'Description', + permission: DatasetPermission.allTeamMembers, + data_source_type: DataSourceType.FILE, + indexing_technique: IndexingType.QUALIFIED, + author_name: 'Author', + created_by: 'creator', + updated_by: 'updater', + updated_at: 1700000000, + app_count: 0, + doc_form: ChunkingMode.text, + document_count: 0, + total_document_count: 0, + total_available_documents: 0, + word_count: 0, + provider: 'internal', + embedding_model: 'embed-model', + embedding_model_provider: 'embed-provider', + embedding_available: true, + tags: [], + partial_member_list: [], + external_knowledge_info: { + external_knowledge_id: 'ext-id', + external_knowledge_api_id: 'ext-api-id', + external_knowledge_api_name: 'External API', + external_knowledge_api_endpoint: 'https://api.example.com', + }, + external_retrieval_model: { + top_k: 2, + score_threshold: 0.5, + score_threshold_enabled: false, + }, + built_in_field_enabled: false, + doc_metadata: [], + keyword_number: 10, + pipeline_id: 'pipeline-id', + is_published: false, + runtime_mode: 'general', + enable_api: true, + is_multimodal: false, + ...overrides, + retrieval_model_dict: { + ...retrievalConfig, + ...overrides.retrieval_model_dict, + }, + retrieval_model: { + ...retrievalConfig, + ...overrides.retrieval_model, + }, + } +} + +describe('RetrievalChangeTip', () => { + const defaultProps = { + visible: true, + message: 'Test message', + onDismiss: vi.fn(), + } + + beforeEach(() => { + vi.clearAllMocks() + }) + + it('renders and supports dismiss', async () => { + // Arrange + const onDismiss = vi.fn() + render() + + // Act + await userEvent.click(screen.getByRole('button', { name: 'close-retrieval-change-tip' })) + + // Assert + expect(screen.getByText('Test message')).toBeInTheDocument() + expect(onDismiss).toHaveBeenCalledTimes(1) + }) + + it('does not render when hidden', () => { + // Arrange & Act + render() + + // Assert + expect(screen.queryByText('Test message')).not.toBeInTheDocument() + }) +}) + +describe('RetrievalSection', () => { + const t = (key: string) => key + const rowClass = 'row' + const labelClass = 'label' + + beforeEach(() => { + vi.clearAllMocks() + mockUseModelList.mockImplementation((type: ModelTypeEnum) => { + if (type === ModelTypeEnum.rerank) + return { data: [{ provider: 'rerank-provider', models: [{ model: 'rerank-model' }] }] } + return { data: [] } + }) + mockUseModelListAndDefaultModel.mockReturnValue({ modelList: [], defaultModel: null }) + mockUseModelListAndDefaultModelAndCurrentProviderAndModel.mockReturnValue({ defaultModel: null, currentModel: null }) + mockUseCurrentProviderAndModel.mockReturnValue({ currentProvider: null, currentModel: null }) + }) + + it('renders external retrieval details and propagates changes', async () => { + // Arrange + const dataset = createDataset({ + provider: 'external', + external_knowledge_info: { + external_knowledge_id: 'ext-id-999', + external_knowledge_api_id: 'ext-api-id-999', + external_knowledge_api_name: 'External API', + external_knowledge_api_endpoint: 'https://api.external.com', + }, + }) + const handleExternalChange = vi.fn() + + // Act + render( + , + ) + const [topKIncrement] = screen.getAllByLabelText('increment') + await userEvent.click(topKIncrement) + + // Assert + expect(screen.getByText('External API')).toBeInTheDocument() + expect(screen.getByText('https://api.external.com')).toBeInTheDocument() + expect(screen.getByText('ext-id-999')).toBeInTheDocument() + expect(handleExternalChange).toHaveBeenCalledWith(expect.objectContaining({ top_k: 4 })) + }) + + it('renders internal retrieval config with doc link', () => { + // Arrange + const docLink = vi.fn((path: string) => `https://docs.example${path}`) + const retrievalConfig = createRetrievalConfig() + + // Act + render( + , + ) + + // Assert + expect(screen.getByText('dataset.retrieval.semantic_search.title')).toBeInTheDocument() + const learnMoreLink = screen.getByRole('link', { name: 'datasetSettings.form.retrievalSetting.learnMore' }) + expect(learnMoreLink).toHaveAttribute('href', 'https://docs.example/guides/knowledge-base/create-knowledge-and-upload-documents/setting-indexing-methods#setting-the-retrieval-setting') + expect(docLink).toHaveBeenCalledWith('/guides/knowledge-base/create-knowledge-and-upload-documents/setting-indexing-methods#setting-the-retrieval-setting') + }) + + it('propagates retrieval config changes for economical indexing', async () => { + // Arrange + const handleRetrievalChange = vi.fn() + + // Act + render( + path} + />, + ) + const [topKIncrement] = screen.getAllByLabelText('increment') + await userEvent.click(topKIncrement) + + // Assert + expect(screen.getByText('dataset.retrieval.keyword_search.title')).toBeInTheDocument() + expect(handleRetrievalChange).toHaveBeenCalledWith(expect.objectContaining({ + top_k: 3, + })) + }) +}) diff --git a/web/app/components/app/configuration/dataset-config/settings-modal/retrieval-section.tsx b/web/app/components/app/configuration/dataset-config/settings-modal/retrieval-section.tsx new file mode 100644 index 0000000000..99d042f681 --- /dev/null +++ b/web/app/components/app/configuration/dataset-config/settings-modal/retrieval-section.tsx @@ -0,0 +1,218 @@ +import { RiCloseLine } from '@remixicon/react' +import type { FC } from 'react' +import { cn } from '@/utils/classnames' +import Divider from '@/app/components/base/divider' +import { ApiConnectionMod } from '@/app/components/base/icons/src/vender/solid/development' +import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback' +import RetrievalSettings from '@/app/components/datasets/external-knowledge-base/create/RetrievalSettings' +import type { DataSet } from '@/models/datasets' +import { IndexingType } from '@/app/components/datasets/create/step-two' +import type { RetrievalConfig } from '@/types/app' +import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-method-config' +import EconomicalRetrievalMethodConfig from '@/app/components/datasets/common/economical-retrieval-method-config' + +type CommonSectionProps = { + rowClass: string + labelClass: string + t: (key: string, options?: any) => string +} + +type ExternalRetrievalSectionProps = CommonSectionProps & { + topK: number + scoreThreshold: number + scoreThresholdEnabled: boolean + onExternalSettingChange: (data: { top_k?: number; score_threshold?: number; score_threshold_enabled?: boolean }) => void + currentDataset: DataSet +} + +const ExternalRetrievalSection: FC = ({ + rowClass, + labelClass, + t, + topK, + scoreThreshold, + scoreThresholdEnabled, + onExternalSettingChange, + currentDataset, +}) => ( + <> +
+
+
+
{t('datasetSettings.form.retrievalSetting.title')}
+
+ +
+
+
+
+
{t('datasetSettings.form.externalKnowledgeAPI')}
+
+
+
+ +
+ {currentDataset?.external_knowledge_info.external_knowledge_api_name} +
+
·
+
{currentDataset?.external_knowledge_info.external_knowledge_api_endpoint}
+
+
+
+
+
+
{t('datasetSettings.form.externalKnowledgeID')}
+
+
+
+
{currentDataset?.external_knowledge_info.external_knowledge_id}
+
+
+
+
+ +) + +type InternalRetrievalSectionProps = CommonSectionProps & { + indexMethod: IndexingType + retrievalConfig: RetrievalConfig + showMultiModalTip: boolean + onRetrievalConfigChange: (value: RetrievalConfig) => void + docLink: (path: string) => string +} + +const InternalRetrievalSection: FC = ({ + rowClass, + labelClass, + t, + indexMethod, + retrievalConfig, + showMultiModalTip, + onRetrievalConfigChange, + docLink, +}) => ( +
+
+
+
{t('datasetSettings.form.retrievalSetting.title')}
+
+ {t('datasetSettings.form.retrievalSetting.learnMore')} + {t('datasetSettings.form.retrievalSetting.description')} +
+
+
+
+ {indexMethod === IndexingType.QUALIFIED + ? ( + + ) + : ( + + )} +
+
+) + +type RetrievalSectionProps + = | (ExternalRetrievalSectionProps & { isExternal: true }) + | (InternalRetrievalSectionProps & { isExternal: false }) + +export const RetrievalSection: FC = (props) => { + if (props.isExternal) { + const { + rowClass, + labelClass, + t, + topK, + scoreThreshold, + scoreThresholdEnabled, + onExternalSettingChange, + currentDataset, + } = props + + return ( + + ) + } + + const { + rowClass, + labelClass, + t, + indexMethod, + retrievalConfig, + showMultiModalTip, + onRetrievalConfigChange, + docLink, + } = props + + return ( + + ) +} + +type RetrievalChangeTipProps = { + visible: boolean + message: string + onDismiss: () => void +} + +export const RetrievalChangeTip: FC = ({ + visible, + message, + onDismiss, +}) => { + if (!visible) + return null + + return ( +
+
+ +
{message}
+
+ +
+ ) +} diff --git a/web/app/components/app/configuration/debug/chat-user-input.tsx b/web/app/components/app/configuration/debug/chat-user-input.tsx index 16666d514e..c25bed548c 100644 --- a/web/app/components/app/configuration/debug/chat-user-input.tsx +++ b/web/app/components/app/configuration/debug/chat-user-input.tsx @@ -7,7 +7,7 @@ import Select from '@/app/components/base/select' import Textarea from '@/app/components/base/textarea' import { DEFAULT_VALUE_MAX_LEN } from '@/config' import type { Inputs } from '@/models/debug' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import BoolInput from '@/app/components/workflow/nodes/_base/components/before-run-form/bool-input' type Props = { diff --git a/web/app/components/app/configuration/debug/debug-with-multiple-model/index.spec.tsx b/web/app/components/app/configuration/debug/debug-with-multiple-model/index.spec.tsx index 140a6c2e6e..b05f33faff 100644 --- a/web/app/components/app/configuration/debug/debug-with-multiple-model/index.spec.tsx +++ b/web/app/components/app/configuration/debug/debug-with-multiple-model/index.spec.tsx @@ -1,4 +1,3 @@ -import '@testing-library/jest-dom' import type { CSSProperties } from 'react' import { fireEvent, render, screen } from '@testing-library/react' import DebugWithMultipleModel from './index' @@ -18,12 +17,12 @@ type PromptVariableWithMeta = Omit & { hide?: boolean } -const mockUseDebugConfigurationContext = jest.fn() -const mockUseFeaturesSelector = jest.fn() -const mockUseEventEmitterContext = jest.fn() -const mockUseAppStoreSelector = jest.fn() -const mockEventEmitter = { emit: jest.fn() } -const mockSetShowAppConfigureFeaturesModal = jest.fn() +const mockUseDebugConfigurationContext = vi.fn() +const mockUseFeaturesSelector = vi.fn() +const mockUseEventEmitterContext = vi.fn() +const mockUseAppStoreSelector = vi.fn() +const mockEventEmitter = { emit: vi.fn() } +const mockSetShowAppConfigureFeaturesModal = vi.fn() let capturedChatInputProps: MockChatInputAreaProps | null = null let modelIdCounter = 0 let featureState: FeatureStoreState @@ -51,27 +50,27 @@ const mockFiles: FileEntity[] = [ }, ] -jest.mock('@/context/debug-configuration', () => ({ +vi.mock('@/context/debug-configuration', () => ({ __esModule: true, useDebugConfigurationContext: () => mockUseDebugConfigurationContext(), })) -jest.mock('@/app/components/base/features/hooks', () => ({ +vi.mock('@/app/components/base/features/hooks', () => ({ __esModule: true, useFeatures: (selector: (state: FeatureStoreState) => unknown) => mockUseFeaturesSelector(selector), })) -jest.mock('@/context/event-emitter', () => ({ +vi.mock('@/context/event-emitter', () => ({ __esModule: true, useEventEmitterContextContext: () => mockUseEventEmitterContext(), })) -jest.mock('@/app/components/app/store', () => ({ +vi.mock('@/app/components/app/store', () => ({ __esModule: true, useStore: (selector: (state: { setShowAppConfigureFeaturesModal: typeof mockSetShowAppConfigureFeaturesModal }) => unknown) => mockUseAppStoreSelector(selector), })) -jest.mock('./debug-item', () => ({ +vi.mock('./debug-item', () => ({ __esModule: true, default: ({ modelAndParameter, @@ -93,7 +92,7 @@ jest.mock('./debug-item', () => ({ ), })) -jest.mock('@/app/components/base/chat/chat/chat-input-area', () => ({ +vi.mock('@/app/components/base/chat/chat/chat-input-area', () => ({ __esModule: true, default: (props: MockChatInputAreaProps) => { capturedChatInputProps = props @@ -118,9 +117,9 @@ const createFeatureState = (): FeatureStoreState => ({ }, }, }, - setFeatures: jest.fn(), + setFeatures: vi.fn(), showFeaturesModal: false, - setShowFeaturesModal: jest.fn(), + setShowFeaturesModal: vi.fn(), }) const createModelConfig = (promptVariables: PromptVariableWithMeta[] = []): ModelConfig => ({ @@ -178,8 +177,8 @@ const createModelAndParameter = (overrides: Partial = {}): Mo const createProps = (overrides: Partial = {}): DebugWithMultipleModelContextType => ({ multipleModelConfigs: [createModelAndParameter()], - onMultipleModelConfigsChange: jest.fn(), - onDebugWithMultipleModelChange: jest.fn(), + onMultipleModelConfigsChange: vi.fn(), + onDebugWithMultipleModelChange: vi.fn(), ...overrides, }) @@ -190,7 +189,7 @@ const renderComponent = (props?: Partial) => describe('DebugWithMultipleModel', () => { beforeEach(() => { - jest.clearAllMocks() + vi.clearAllMocks() capturedChatInputProps = null modelIdCounter = 0 featureState = createFeatureState() @@ -274,7 +273,7 @@ describe('DebugWithMultipleModel', () => { describe('props and callbacks', () => { it('should call onMultipleModelConfigsChange when provided', () => { - const onMultipleModelConfigsChange = jest.fn() + const onMultipleModelConfigsChange = vi.fn() renderComponent({ onMultipleModelConfigsChange }) // Context provider should pass through the callback @@ -282,7 +281,7 @@ describe('DebugWithMultipleModel', () => { }) it('should call onDebugWithMultipleModelChange when provided', () => { - const onDebugWithMultipleModelChange = jest.fn() + const onDebugWithMultipleModelChange = vi.fn() renderComponent({ onDebugWithMultipleModelChange }) // Context provider should pass through the callback @@ -478,7 +477,7 @@ describe('DebugWithMultipleModel', () => { describe('sending flow', () => { it('should emit chat event when allowed to send', () => { // Arrange - const checkCanSend = jest.fn(() => true) + const checkCanSend = vi.fn(() => true) const multipleModelConfigs = [createModelAndParameter(), createModelAndParameter()] renderComponent({ multipleModelConfigs, checkCanSend }) @@ -512,7 +511,7 @@ describe('DebugWithMultipleModel', () => { it('should block sending when checkCanSend returns false', () => { // Arrange - const checkCanSend = jest.fn(() => false) + const checkCanSend = vi.fn(() => false) renderComponent({ checkCanSend }) // Act @@ -564,8 +563,8 @@ describe('DebugWithMultipleModel', () => { })} />) const twoItems = screen.getAllByTestId('debug-item') - expect(twoItems[0].style.width).toBe('calc(50% - 4px - 24px)') - expect(twoItems[1].style.width).toBe('calc(50% - 4px - 24px)') + expect(twoItems[0].style.width).toBe('calc(50% - 28px)') + expect(twoItems[1].style.width).toBe('calc(50% - 28px)') }) }) @@ -604,13 +603,13 @@ describe('DebugWithMultipleModel', () => { // Assert expect(items).toHaveLength(2) expectItemLayout(items[0], { - width: 'calc(50% - 4px - 24px)', + width: 'calc(50% - 28px)', height: '100%', transform: 'translateX(0) translateY(0)', classes: ['mr-2'], }) expectItemLayout(items[1], { - width: 'calc(50% - 4px - 24px)', + width: 'calc(50% - 28px)', height: '100%', transform: 'translateX(calc(100% + 8px)) translateY(0)', classes: [], @@ -628,19 +627,19 @@ describe('DebugWithMultipleModel', () => { // Assert expect(items).toHaveLength(3) expectItemLayout(items[0], { - width: 'calc(33.3% - 5.33px - 16px)', + width: 'calc(33.3% - 21.33px)', height: '100%', transform: 'translateX(0) translateY(0)', classes: ['mr-2'], }) expectItemLayout(items[1], { - width: 'calc(33.3% - 5.33px - 16px)', + width: 'calc(33.3% - 21.33px)', height: '100%', transform: 'translateX(calc(100% + 8px)) translateY(0)', classes: ['mr-2'], }) expectItemLayout(items[2], { - width: 'calc(33.3% - 5.33px - 16px)', + width: 'calc(33.3% - 21.33px)', height: '100%', transform: 'translateX(calc(200% + 16px)) translateY(0)', classes: [], @@ -663,25 +662,25 @@ describe('DebugWithMultipleModel', () => { // Assert expect(items).toHaveLength(4) expectItemLayout(items[0], { - width: 'calc(50% - 4px - 24px)', + width: 'calc(50% - 28px)', height: 'calc(50% - 4px)', transform: 'translateX(0) translateY(0)', classes: ['mr-2', 'mb-2'], }) expectItemLayout(items[1], { - width: 'calc(50% - 4px - 24px)', + width: 'calc(50% - 28px)', height: 'calc(50% - 4px)', transform: 'translateX(calc(100% + 8px)) translateY(0)', classes: ['mb-2'], }) expectItemLayout(items[2], { - width: 'calc(50% - 4px - 24px)', + width: 'calc(50% - 28px)', height: 'calc(50% - 4px)', transform: 'translateX(0) translateY(calc(100% + 8px))', classes: ['mr-2'], }) expectItemLayout(items[3], { - width: 'calc(50% - 4px - 24px)', + width: 'calc(50% - 28px)', height: 'calc(50% - 4px)', transform: 'translateX(calc(100% + 8px)) translateY(calc(100% + 8px))', classes: [], diff --git a/web/app/components/app/configuration/debug/debug-with-single-model/index.spec.tsx b/web/app/components/app/configuration/debug/debug-with-single-model/index.spec.tsx new file mode 100644 index 0000000000..bca65387e7 --- /dev/null +++ b/web/app/components/app/configuration/debug/debug-with-single-model/index.spec.tsx @@ -0,0 +1,1013 @@ +import { act, fireEvent, render, screen, waitFor } from '@testing-library/react' +import { type ReactNode, type RefObject, createRef } from 'react' +import DebugWithSingleModel from './index' +import type { DebugWithSingleModelRefType } from './index' +import type { ChatItem } from '@/app/components/base/chat/types' +import { ConfigurationMethodEnum, ModelFeatureEnum, ModelStatusEnum, ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import type { ProviderContextState } from '@/context/provider-context' +import type { DatasetConfigs, ModelConfig } from '@/models/debug' +import { PromptMode } from '@/models/debug' +import { type Collection, CollectionType } from '@/app/components/tools/types' +import type { FileEntity } from '@/app/components/base/file-uploader/types' +import { AgentStrategy, AppModeEnum, ModelModeType, Resolution, TransferMethod } from '@/types/app' + +// ============================================================================ +// Test Data Factories (Following testing.md guidelines) +// ============================================================================ + +/** + * Factory function for creating mock ModelConfig with type safety + */ +function createMockModelConfig(overrides: Partial = {}): ModelConfig { + return { + provider: 'openai', + model_id: 'gpt-3.5-turbo', + mode: ModelModeType.chat, + configs: { + prompt_template: 'Test template', + prompt_variables: [ + { key: 'var1', name: 'Variable 1', type: 'text', required: false }, + ], + }, + chat_prompt_config: { + prompt: [], + }, + completion_prompt_config: { + prompt: { text: '' }, + conversation_histories_role: { + user_prefix: 'user', + assistant_prefix: 'assistant', + }, + }, + more_like_this: null, + opening_statement: '', + suggested_questions: [], + sensitive_word_avoidance: null, + speech_to_text: null, + text_to_speech: null, + file_upload: null, + suggested_questions_after_answer: null, + retriever_resource: null, + annotation_reply: null, + external_data_tools: [], + system_parameters: { + audio_file_size_limit: 0, + file_size_limit: 0, + image_file_size_limit: 0, + video_file_size_limit: 0, + workflow_file_upload_limit: 0, + }, + dataSets: [], + agentConfig: { + enabled: false, + max_iteration: 5, + tools: [], + strategy: AgentStrategy.react, + }, + ...overrides, + } +} + +/** + * Factory function for creating mock Collection list + */ +function createMockCollections(collections: Partial[] = []): Collection[] { + return collections.map((collection, index) => ({ + id: `collection-${index}`, + name: `Collection ${index}`, + icon: 'icon-url', + type: 'tool', + ...collection, + } as Collection)) +} + +/** + * Factory function for creating mock Provider Context + */ +function createMockProviderContext(overrides: Partial = {}): ProviderContextState { + return { + textGenerationModelList: [ + { + provider: 'openai', + label: { en_US: 'OpenAI', zh_Hans: 'OpenAI' }, + icon_small: { en_US: 'icon', zh_Hans: 'icon' }, + icon_large: { en_US: 'icon', zh_Hans: 'icon' }, + status: ModelStatusEnum.active, + models: [ + { + model: 'gpt-3.5-turbo', + label: { en_US: 'GPT-3.5', zh_Hans: 'GPT-3.5' }, + model_type: ModelTypeEnum.textGeneration, + features: [ModelFeatureEnum.vision], + fetch_from: ConfigurationMethodEnum.predefinedModel, + model_properties: {}, + deprecated: false, + }, + ], + }, + ], + hasSettedApiKey: true, + modelProviders: [], + speech2textDefaultModel: null, + ttsDefaultModel: null, + agentThoughtDefaultModel: null, + updateModelList: vi.fn(), + onPlanInfoChanged: vi.fn(), + refreshModelProviders: vi.fn(), + refreshLicenseLimit: vi.fn(), + ...overrides, + } as ProviderContextState +} + +// ============================================================================ +// Mock External Dependencies ONLY (Following testing.md guidelines) +// ============================================================================ + +// Mock service layer (API calls) +const { mockSsePost } = vi.hoisted(() => ({ + mockSsePost: vi.fn<(...args: any[]) => Promise>(() => Promise.resolve()), +})) + +vi.mock('@/service/base', () => ({ + ssePost: mockSsePost, + post: vi.fn(() => Promise.resolve({ data: {} })), + get: vi.fn(() => Promise.resolve({ data: {} })), + del: vi.fn(() => Promise.resolve({ data: {} })), + patch: vi.fn(() => Promise.resolve({ data: {} })), + put: vi.fn(() => Promise.resolve({ data: {} })), +})) + +vi.mock('@/service/fetch', () => ({ + fetch: vi.fn(() => Promise.resolve({ ok: true, json: () => Promise.resolve({}) })), +})) + +const { mockFetchConversationMessages, mockFetchSuggestedQuestions, mockStopChatMessageResponding } = vi.hoisted(() => ({ + mockFetchConversationMessages: vi.fn(), + mockFetchSuggestedQuestions: vi.fn(), + mockStopChatMessageResponding: vi.fn(), +})) + +vi.mock('@/service/debug', () => ({ + fetchConversationMessages: mockFetchConversationMessages, + fetchSuggestedQuestions: mockFetchSuggestedQuestions, + stopChatMessageResponding: mockStopChatMessageResponding, +})) + +vi.mock('next/navigation', () => ({ + useRouter: () => ({ push: vi.fn() }), + usePathname: () => '/test', + useParams: () => ({}), +})) + +// Mock complex context providers +const mockDebugConfigContext = { + appId: 'test-app-id', + isAPIKeySet: true, + isTrailFinished: false, + mode: AppModeEnum.CHAT, + modelModeType: ModelModeType.chat, + promptMode: PromptMode.simple, + setPromptMode: vi.fn(), + isAdvancedMode: false, + isAgent: false, + isFunctionCall: false, + isOpenAI: true, + collectionList: createMockCollections([ + { id: 'test-provider', name: 'Test Tool', icon: 'icon-url' }, + ]), + canReturnToSimpleMode: false, + setCanReturnToSimpleMode: vi.fn(), + chatPromptConfig: {}, + completionPromptConfig: {}, + currentAdvancedPrompt: [], + showHistoryModal: vi.fn(), + conversationHistoriesRole: { user_prefix: 'user', assistant_prefix: 'assistant' }, + setConversationHistoriesRole: vi.fn(), + setCurrentAdvancedPrompt: vi.fn(), + hasSetBlockStatus: { context: false, history: false, query: false }, + conversationId: null, + setConversationId: vi.fn(), + introduction: '', + setIntroduction: vi.fn(), + suggestedQuestions: [], + setSuggestedQuestions: vi.fn(), + controlClearChatMessage: 0, + setControlClearChatMessage: vi.fn(), + prevPromptConfig: { prompt_template: '', prompt_variables: [] }, + setPrevPromptConfig: vi.fn(), + moreLikeThisConfig: { enabled: false }, + setMoreLikeThisConfig: vi.fn(), + suggestedQuestionsAfterAnswerConfig: { enabled: false }, + setSuggestedQuestionsAfterAnswerConfig: vi.fn(), + speechToTextConfig: { enabled: false }, + setSpeechToTextConfig: vi.fn(), + textToSpeechConfig: { enabled: false, voice: '', language: '' }, + setTextToSpeechConfig: vi.fn(), + citationConfig: { enabled: false }, + setCitationConfig: vi.fn(), + moderationConfig: { enabled: false }, + annotationConfig: { id: '', enabled: false, score_threshold: 0.7, embedding_model: { embedding_model_name: '', embedding_provider_name: '' } }, + setAnnotationConfig: vi.fn(), + setModerationConfig: vi.fn(), + externalDataToolsConfig: [], + setExternalDataToolsConfig: vi.fn(), + formattingChanged: false, + setFormattingChanged: vi.fn(), + inputs: { var1: 'test input' }, + setInputs: vi.fn(), + query: '', + setQuery: vi.fn(), + completionParams: { max_tokens: 100, temperature: 0.7 }, + setCompletionParams: vi.fn(), + modelConfig: createMockModelConfig({ + agentConfig: { + enabled: false, + max_iteration: 5, + tools: [{ + tool_name: 'test-tool', + provider_id: 'test-provider', + provider_type: CollectionType.builtIn, + provider_name: 'test-provider', + tool_label: 'Test Tool', + tool_parameters: {}, + enabled: true, + }], + strategy: AgentStrategy.react, + }, + }), + setModelConfig: vi.fn(), + dataSets: [], + showSelectDataSet: vi.fn(), + setDataSets: vi.fn(), + datasetConfigs: { + retrieval_model: 'single', + reranking_model: { reranking_provider_name: '', reranking_model_name: '' }, + top_k: 4, + score_threshold_enabled: false, + score_threshold: 0.7, + datasets: { datasets: [] }, + } as DatasetConfigs, + datasetConfigsRef: createRef(), + setDatasetConfigs: vi.fn(), + hasSetContextVar: false, + isShowVisionConfig: false, + visionConfig: { enabled: false, number_limits: 2, detail: Resolution.low, transfer_methods: [] }, + setVisionConfig: vi.fn(), + isAllowVideoUpload: false, + isShowDocumentConfig: false, + isShowAudioConfig: false, + rerankSettingModalOpen: false, + setRerankSettingModalOpen: vi.fn(), +} + +const { mockUseDebugConfigurationContext } = vi.hoisted(() => ({ + mockUseDebugConfigurationContext: vi.fn(), +})) + +// Set up the default implementation after mockDebugConfigContext is defined +mockUseDebugConfigurationContext.mockReturnValue(mockDebugConfigContext) + +vi.mock('@/context/debug-configuration', () => ({ + useDebugConfigurationContext: mockUseDebugConfigurationContext, +})) + +const mockProviderContext = createMockProviderContext() + +const { mockUseProviderContext } = vi.hoisted(() => ({ + mockUseProviderContext: vi.fn(), +})) + +mockUseProviderContext.mockReturnValue(mockProviderContext) + +vi.mock('@/context/provider-context', () => ({ + useProviderContext: mockUseProviderContext, +})) + +const mockAppContext = { + userProfile: { + id: 'user-1', + avatar_url: 'https://example.com/avatar.png', + name: 'Test User', + email: 'test@example.com', + }, + isCurrentWorkspaceManager: false, + isCurrentWorkspaceOwner: false, + isCurrentWorkspaceDatasetOperator: false, + mutateUserProfile: vi.fn(), +} + +const { mockUseAppContext } = vi.hoisted(() => ({ + mockUseAppContext: vi.fn(), +})) + +mockUseAppContext.mockReturnValue(mockAppContext) + +vi.mock('@/context/app-context', () => ({ + useAppContext: mockUseAppContext, +})) + +type FeatureState = { + moreLikeThis: { enabled: boolean } + opening: { enabled: boolean; opening_statement: string; suggested_questions: string[] } + moderation: { enabled: boolean } + speech2text: { enabled: boolean } + text2speech: { enabled: boolean } + file: { enabled: boolean } + suggested: { enabled: boolean } + citation: { enabled: boolean } + annotationReply: { enabled: boolean } +} + +const defaultFeatures: FeatureState = { + moreLikeThis: { enabled: false }, + opening: { enabled: false, opening_statement: '', suggested_questions: [] }, + moderation: { enabled: false }, + speech2text: { enabled: false }, + text2speech: { enabled: false }, + file: { enabled: false }, + suggested: { enabled: false }, + citation: { enabled: false }, + annotationReply: { enabled: false }, +} +type FeatureSelector = (state: { features: FeatureState }) => unknown + +let mockFeaturesState: FeatureState = { ...defaultFeatures } + +const { mockUseFeatures } = vi.hoisted(() => ({ + mockUseFeatures: vi.fn(), +})) + +vi.mock('@/app/components/base/features/hooks', () => ({ + useFeatures: mockUseFeatures, +})) + +const mockConfigFromDebugContext = { + pre_prompt: 'Test prompt', + prompt_type: 'simple', + user_input_form: [], + dataset_query_variable: '', + opening_statement: '', + more_like_this: { enabled: false }, + suggested_questions: [], + suggested_questions_after_answer: { enabled: false }, + text_to_speech: { enabled: false }, + speech_to_text: { enabled: false }, + retriever_resource: { enabled: false }, + sensitive_word_avoidance: { enabled: false }, + agent_mode: {}, + dataset_configs: {}, + file_upload: { enabled: false }, + annotation_reply: { enabled: false }, + supportAnnotation: true, + appId: 'test-app-id', + supportCitationHitInfo: true, +} + +const { mockUseConfigFromDebugContext, mockUseFormattingChangedSubscription } = vi.hoisted(() => ({ + mockUseConfigFromDebugContext: vi.fn(), + mockUseFormattingChangedSubscription: vi.fn(), +})) + +mockUseConfigFromDebugContext.mockReturnValue(mockConfigFromDebugContext) + +vi.mock('../hooks', () => ({ + useConfigFromDebugContext: mockUseConfigFromDebugContext, + useFormattingChangedSubscription: mockUseFormattingChangedSubscription, +})) + +const mockSetShowAppConfigureFeaturesModal = vi.fn() + +vi.mock('@/app/components/app/store', () => ({ + useStore: vi.fn((selector?: (state: { setShowAppConfigureFeaturesModal: typeof mockSetShowAppConfigureFeaturesModal }) => unknown) => { + if (typeof selector === 'function') + return selector({ setShowAppConfigureFeaturesModal: mockSetShowAppConfigureFeaturesModal }) + return mockSetShowAppConfigureFeaturesModal + }), +})) + +// Mock event emitter context +vi.mock('@/context/event-emitter', () => ({ + useEventEmitterContextContext: vi.fn(() => ({ + eventEmitter: null, + })), +})) + +// Mock toast context +vi.mock('@/app/components/base/toast', () => ({ + useToastContext: vi.fn(() => ({ + notify: vi.fn(), + })), +})) + +// Mock hooks/use-timestamp +vi.mock('@/hooks/use-timestamp', () => ({ + __esModule: true, + default: vi.fn(() => ({ + formatTime: vi.fn((timestamp: number) => new Date(timestamp).toLocaleString()), + })), +})) + +// Mock audio player manager +vi.mock('@/app/components/base/audio-btn/audio.player.manager', () => ({ + AudioPlayerManager: { + getInstance: vi.fn(() => ({ + getAudioPlayer: vi.fn(), + resetAudioPlayer: vi.fn(), + })), + }, +})) + +type MockChatProps = { + chatList?: ChatItem[] + isResponding?: boolean + onSend?: (message: string, files?: FileEntity[]) => void + onRegenerate?: (chatItem: ChatItem, editedQuestion?: { message: string; files?: FileEntity[] }) => void + onStopResponding?: () => void + suggestedQuestions?: string[] + questionIcon?: ReactNode + answerIcon?: ReactNode + onAnnotationAdded?: (annotationId: string, authorName: string, question: string, answer: string, index: number) => void + onAnnotationEdited?: (question: string, answer: string, index: number) => void + onAnnotationRemoved?: (index: number) => void + switchSibling?: (siblingMessageId: string) => void + onFeatureBarClick?: (state: boolean) => void +} + +const mockFile: FileEntity = { + id: 'file-1', + name: 'test.png', + size: 123, + type: 'image/png', + progress: 100, + transferMethod: TransferMethod.local_file, + supportFileType: 'image', +} + +// Mock Chat component (complex with many dependencies) +// This is a pragmatic mock that tests the integration at DebugWithSingleModel level +vi.mock('@/app/components/base/chat/chat', () => ({ + default: function MockChat({ + chatList, + isResponding, + onSend, + onRegenerate, + onStopResponding, + suggestedQuestions, + questionIcon, + answerIcon, + onAnnotationAdded, + onAnnotationEdited, + onAnnotationRemoved, + switchSibling, + onFeatureBarClick, + }: MockChatProps) { + const items = chatList || [] + const suggested = suggestedQuestions ?? [] + return ( +
+
+ {items.map((item: ChatItem) => ( +
+ {item.content} +
+ ))} +
+ {questionIcon &&
{questionIcon}
} + {answerIcon &&
{answerIcon}
} +